A Regression Tree in Dungeons and Dragons

Data for demo

Back to the spellbook

1. Libraries

library(rpart)
library(rpart.plot)
library(forecast)
## Registered S3 method overwritten by 'quantmod':
##   method            from
##   as.zoo.data.frame zoo
library(caret)
## Loading required package: ggplot2
## Loading required package: lattice

2. Load data

# Read the full hero roster (all 26 columns) from disk.
hogwarts <- read.csv(file = "super_heroes_hogwarts_v3a.csv", header = TRUE)

# Preview the first ten characters.
head(hogwarts, n = 10)
##      ID         Name Gender              Race Height         Publisher
## 1  A001       A-Bomb   Male             Human    203     Marvel Comics
## 2  A002   Abe Sapien   Male     Icthyo Sapien    191 Dark Horse Comics
## 3  A004  Abomination   Male Human / Radiation    203     Marvel Comics
## 4  A009     Agent 13 Female              <NA>    173     Marvel Comics
## 5  A015  Alex Mercer   Male             Human     NA         Wildstorm
## 6  A016 Alex Woolsly   Male              <NA>     NA      NBC - Heroes
## 7  A024        Angel   Male           Vampire     NA Dark Horse Comics
## 8  A025   Angel Dust Female            Mutant    165     Marvel Comics
## 9  A028   Animal Man   Male             Human    183         DC Comics
## 10 A032 Anti-Monitor   Male     God / Eternal     61         DC Comics
##    Alignment Weight Manipulative Resourceful Dismissive Intelligent Trusting
## 1       good    441           10          10          7           6        7
## 2       good     65            7           7          6           8        6
## 3        bad    441            6           8          1           6        3
## 4       good     61            7           7          1           9        7
## 5        bad     NA           10           6          8           3        4
## 6       good     NA            8          10          5           5        6
## 7       good     NA            8           6          8           7        4
## 8       good     57            9           8          9           4        1
## 9       good     83            7           6          6           5        8
## 10       bad     NA            7           7          7           1        9
##    Loyal Stubborn Brave HouseID     House STR DEX CON INT WIS CHA Level  HP
## 1      7        7     9       1 Slytherin  18  11  17  12  13  11     1   7
## 2      7        6     9       1 Slytherin  16  17  10  13  15  11     8  72
## 3      3        5     2       1 Slytherin  13  14  13  10  18  15    15 135
## 4      4        6     6       1 Slytherin  15  18  16  16  17  10    14 140
## 5      4        1     8       1 Slytherin  14  17  13  12  10  11     9  72
## 6      7        7     6       1 Slytherin  14  14  11  13  12  12     1   8
## 7      1        5     2       1 Slytherin  15  17  15  18  13  18    11  88
## 8      6        5     4       1 Slytherin   8  17  12  15  17  18     1   8
## 9      3        3     2       1 Slytherin  10  17  15  18  13  14     8  56
## 10     1        6     5       1 Slytherin   8  10  11  16  12  11     7  63

# Column types plus the first few values of each column.
str(object = hogwarts)
## 'data.frame':    734 obs. of  26 variables:
##  $ ID          : chr  "A001" "A002" "A004" "A009" ...
##  $ Name        : chr  "A-Bomb" "Abe Sapien" "Abomination" "Agent 13" ...
##  $ Gender      : chr  "Male" "Male" "Male" "Female" ...
##  $ Race        : chr  "Human" "Icthyo Sapien" "Human / Radiation" NA ...
##  $ Height      : num  203 191 203 173 NA NA NA 165 183 61 ...
##  $ Publisher   : chr  "Marvel Comics" "Dark Horse Comics" "Marvel Comics" "Marvel Comics" ...
##  $ Alignment   : chr  "good" "good" "bad" "good" ...
##  $ Weight      : int  441 65 441 61 NA NA NA 57 83 NA ...
##  $ Manipulative: int  10 7 6 7 10 8 8 9 7 7 ...
##  $ Resourceful : int  10 7 8 7 6 10 6 8 6 7 ...
##  $ Dismissive  : int  7 6 1 1 8 5 8 9 6 7 ...
##  $ Intelligent : int  6 8 6 9 3 5 7 4 5 1 ...
##  $ Trusting    : int  7 6 3 7 4 6 4 1 8 9 ...
##  $ Loyal       : int  7 7 3 4 4 7 1 6 3 1 ...
##  $ Stubborn    : int  7 6 5 6 1 7 5 5 3 6 ...
##  $ Brave       : int  9 9 2 6 8 6 2 4 2 5 ...
##  $ HouseID     : int  1 1 1 1 1 1 1 1 1 1 ...
##  $ House       : chr  "Slytherin" "Slytherin" "Slytherin" "Slytherin" ...
##  $ STR         : int  18 16 13 15 14 14 15 8 10 8 ...
##  $ DEX         : int  11 17 14 18 17 14 17 17 17 10 ...
##  $ CON         : int  17 10 13 16 13 11 15 12 15 11 ...
##  $ INT         : int  12 13 10 16 12 13 18 15 18 16 ...
##  $ WIS         : int  13 15 18 17 10 12 13 17 13 12 ...
##  $ CHA         : int  11 11 15 10 11 12 18 18 14 11 ...
##  $ Level       : int  1 8 15 14 9 1 11 1 8 7 ...
##  $ HP          : int  7 72 135 140 72 8 88 8 56 63 ...

# Column names, with their positions.
colnames(hogwarts)
##  [1] "ID"           "Name"         "Gender"       "Race"         "Height"      
##  [6] "Publisher"    "Alignment"    "Weight"       "Manipulative" "Resourceful" 
## [11] "Dismissive"   "Intelligent"  "Trusting"     "Loyal"        "Stubborn"    
## [16] "Brave"        "HouseID"      "House"        "STR"          "DEX"         
## [21] "CON"          "INT"          "WIS"          "CHA"          "Level"       
## [26] "HP"

# Number of characters (rows) in the roster.
dim(hogwarts)[1]
## [1] 734

Remove the variables this model does not need, keeping only the six ability scores (STR, DEX, CON, INT, WIS, CHA) and the outcome, HP.

# Keep only the six ability scores (predictors) and HP (outcome).
# Columns are selected by name rather than by position (formerly
# c(19:24, 26)), so this stays correct if the CSV gains, loses or
# reorders columns; a missing column raises "undefined columns selected".
model_vars <- c("STR", "DEX", "CON", "INT", "WIS", "CHA", "HP")
hogwarts <- hogwarts[, model_vars]
names(hogwarts)
## [1] "STR" "DEX" "CON" "INT" "WIS" "CHA" "HP"

Look at the new order

# Show the column names as a one-column matrix (one name per row) to
# confirm the new column order.
t(t(names(hogwarts)))
##      [,1] 
## [1,] "STR"
## [2,] "DEX"
## [3,] "CON"
## [4,] "INT"
## [5,] "WIS"
## [6,] "CHA"
## [7,] "HP"
str(hogwarts)
## 'data.frame':    734 obs. of  7 variables:
##  $ STR: int  18 16 13 15 14 14 15 8 10 8 ...
##  $ DEX: int  11 17 14 18 17 14 17 17 17 10 ...
##  $ CON: int  17 10 13 16 13 11 15 12 15 11 ...
##  $ INT: int  12 13 10 16 12 13 18 15 18 16 ...
##  $ WIS: int  13 15 18 17 10 12 13 17 13 12 ...
##  $ CHA: int  11 11 15 10 11 12 18 18 14 11 ...
##  $ HP : int  7 72 135 140 72 8 88 8 56 63 ...
# House is not tabulated here: it is no longer a column of hogwarts, so
# table(hogwarts$House) would only print "< table of extent 0 >".
# Tabulate House before the columns are subset if that summary is needed.
nrow(hogwarts)
## [1] 734

3. Training validation split

We’re using our favourite seed number, but you can use any other seed. Note that your solutions may differ slightly with different seeds.

set.seed(666)

# 60% training / 40% validation split by row index.
# - seq_len() is safe for an empty data frame, where 1:nrow() would yield
#   c(1, 0) and produce a spurious NA row in the validation set.
# - floor() makes explicit the truncation sample() already applies to a
#   fractional size (0.6 * 734 = 440.4 -> 440), so the draw is unchanged.
n_rows <- nrow(hogwarts)
train_index <- sample(seq_len(n_rows), floor(0.6 * n_rows))
# base:: guards against dplyr's masking setdiff() when this chunk is re-run
# after dplyr has been attached later in the script.
valid_index <- base::setdiff(seq_len(n_rows), train_index)

train_df <- hogwarts[train_index, ]
valid_df <- hogwarts[valid_index, ]
nrow(train_df)
## [1] 440
nrow(valid_df)
## [1] 294
head(train_df)
##     STR DEX CON INT WIS CHA  HP
## 574  13  11  12  18  12  13  72
## 638   8  13  18  18  18  15 120
## 608  15  13  17  10  10  12  98
## 123  13  10  11  11  10  17  63
## 540   8  10  18  12  18  17  40
## 654  14  12  12  13  18  11  54
head(valid_df)
##    STR DEX CON INT WIS CHA  HP
## 2   16  17  10  13  15  11  72
## 3   13  14  13  10  18  15 135
## 5   14  17  13  12  10  11  72
## 12  13  14  12  17  12  11  30
## 13  10  14  15  17  12  16  84
## 14  16  14  14  13  10  12  36
str(train_df)
## 'data.frame':    440 obs. of  7 variables:
##  $ STR: int  13 8 15 13 8 14 9 11 15 17 ...
##  $ DEX: int  11 13 13 10 10 12 16 18 14 18 ...
##  $ CON: int  12 18 17 11 18 12 13 15 11 11 ...
##  $ INT: int  18 18 10 11 12 13 14 16 18 11 ...
##  $ WIS: int  12 18 10 10 18 18 11 10 18 16 ...
##  $ CHA: int  13 15 12 17 17 11 18 11 11 10 ...
##  $ HP : int  72 120 98 63 40 54 140 66 56 56 ...
str(valid_df)
## 'data.frame':    294 obs. of  7 variables:
##  $ STR: int  16 13 14 13 10 16 11 11 9 11 ...
##  $ DEX: int  17 14 17 14 14 14 16 13 11 13 ...
##  $ CON: int  10 13 13 12 15 14 18 10 15 12 ...
##  $ INT: int  13 10 12 17 17 13 14 13 12 16 ...
##  $ WIS: int  15 18 10 12 12 10 15 10 14 15 ...
##  $ CHA: int  11 15 11 11 16 12 14 10 13 14 ...
##  $ HP : int  72 135 72 30 84 36 140 140 72 21 ...

4. Regression tree

# Columns available for modelling: six ability scores plus the HP outcome.
colnames(train_df)
## [1] "STR" "DEX" "CON" "INT" "WIS" "CHA" "HP"

4.1 Large tree.

This is harder to read, and may not be very useful.

# "Large" regression tree: all six ability scores as predictors, with the
# depth limit raised to 20. rpart.control's default complexity parameter
# (cp = 0.01) still applies and limits how many splits are kept.
regress_tr <- rpart(
  HP ~ STR + DEX + CON + INT + WIS + CHA,
  data = train_df,
  method = "anova",
  control = rpart.control(maxdepth = 20)
)
rpart.plot(regress_tr, type = 4)

# Error on the data the tree was fitted on ...
predict_train <- predict(regress_tr, newdata = train_df)
accuracy(predict_train, train_df$HP)
##                     ME     RMSE      MAE       MPE     MAPE
## Test set -2.391647e-15 35.81544 29.87884 -66.19501 93.24161
# ... and on the held-out validation data.
predict_valid <- predict(regress_tr, newdata = valid_df)
accuracy(predict_valid, valid_df$HP)
##                 ME     RMSE      MAE       MPE     MAPE
## Test set -8.012589 37.49316 30.60095 -82.68468 104.7522

4.2 Shallower tree.

A shallower tree is easier to read and may be more useful, but it may also be overly simplistic.

# Shallow regression tree: at most 3 levels deep, with at least 2
# observations allowed in each terminal node.
regress_tr_shallow <- rpart(
  HP ~ STR + DEX + CON + INT + WIS + CHA,
  data = train_df,
  method = "anova",
  control = rpart.control(minbucket = 2, maxdepth = 3)
)
rpart.plot(regress_tr_shallow, type = 4)

# Row of the tree frame (terminal node) each training record fell into.
head(regress_tr_shallow$where)
## 574 638 608 123 540 654 
##   5   5   5   4   4   5

# Training error ...
predict_train_shallow <- predict(regress_tr_shallow, newdata = train_df)
accuracy(predict_train_shallow, train_df$HP)
##                     ME     RMSE      MAE       MPE     MAPE
## Test set -2.083927e-15 37.61598 32.04895 -75.22096 103.9582
# ... and validation error.
predict_valid_shallow <- predict(regress_tr_shallow, newdata = valid_df)
accuracy(predict_valid_shallow, valid_df$HP)
##                 ME     RMSE     MAE       MPE     MAPE
## Test set -7.670169 34.73232 28.2203 -81.03383 99.89169

Use the training-set values, because the tree was fitted on the training set.

library(treeClust)
## Warning: package 'treeClust' was built under R version 4.2.3
## Loading required package: cluster
# Terminal node reached by each training record in the shallow tree. The
# first six values agree with head(regress_tr_shallow$where) shown earlier.
which_node_train <- rpart.predict.leaves(
  regress_tr_shallow,
  newdata = train_df,
  type = "where"
)
head(which_node_train)
## 574 638 608 123 540 654 
##   5   5   5   4   4   5

Standard deviation of HP in each terminal node

# Standard deviation of training HP within each terminal node of the shallow
# tree. Columns are named Node/sd so the table can be joined on Node.
# (Uses `<-` rather than `=` for assignment, per R style convention.)
sd_node <- setNames(
  aggregate(train_df$HP, by = list(which_node_train), FUN = sd),
  c("Node", "sd")
)
sd_node
##   Node       sd
## 1    3 27.41906
## 2    4 35.75392
## 3    5 38.45297

Minimum HP in each terminal node

# Minimum training HP within each terminal node of the shallow tree.
# Columns are named Node/min so the table can be joined on Node.
min_node <- setNames(
  aggregate(train_df$HP, by = list(which_node_train), FUN = min),
  c("Node", "min")
)
min_node
##   Node min
## 1    3   8
## 2    4   9
## 3    5   6

Maximum and mean HP in each terminal node

# Maximum training HP within each terminal node of the shallow tree.
max_node <- setNames(
  aggregate(train_df$HP, by = list(which_node_train), FUN = max),
  c("Node", "max")
)
max_node
##   Node max
## 1    3 112
## 2    4 135
## 3    5 150

# Mean training HP within each terminal node. For this anova tree it
# matches the shallow tree's predictions (e.g. 71.62176 for node 5).
mean_node <- setNames(
  aggregate(train_df$HP, by = list(which_node_train), FUN = mean),
  c("Node", "mean")
)
mean_node
##   Node     mean
## 1    3 42.73077
## 2    4 68.75000
## 3    5 71.62176

Terminal node statistics

library(dplyr)
## 
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
## 
##     filter, lag
## The following objects are masked from 'package:base':
## 
##     intersect, setdiff, setequal, union
# One row per terminal node with columns Node, sd, min, max, mean: the
# per-node tables are folded together with successive inner joins on Node.
regress_tr_shallow_node_stats <- Reduce(
  function(acc, stats) inner_join(acc, stats, by = "Node"),
  list(min_node, max_node, mean_node),
  sd_node
)

regress_tr_shallow_node_stats
##   Node       sd min max     mean
## 1    3 27.41906   8 112 42.73077
## 2    4 35.75392   9 135 68.75000
## 3    5 38.45297   6 150 71.62176

# Terminal node reached by each validation record.
which_node_valid <- rpart.predict.leaves(
  regress_tr_shallow,
  newdata = valid_df,
  type = "where"
)
head(which_node_valid, n = 20)
##  2  3  5 12 13 14 15 16 17 18 20 23 24 31 33 40 44 47 48 49 
##  5  5  5  5  5  5  5  5  5  5  5  5  5  5  5  4  5  5  5  5

# Same assignments as a one-column data frame (row names = validation rows).
which_node_valid_df <- setNames(as.data.frame(which_node_valid), "node")

head(which_node_valid_df)
##    node
## 2     5
## 3     5
## 5     5
## 12    5
## 13    5
## 14    5

5. Predict new record

5.1 One new record

# A new record needs exactly these predictor columns (HP is the outcome).
colnames(hogwarts)
## [1] "STR" "DEX" "CON" "INT" "WIS" "CHA" "HP"

Using the large tree

# A single new character described only by its six ability scores.
padawan_1 <- data.frame(
  STR = 11, DEX = 17, CON = 14,
  INT = 17, WIS = 17, CHA = 13
)

# Predicted HP from the large tree.
regress_tr_pred <- predict(regress_tr, newdata = padawan_1)
regress_tr_pred
##        1 
## 57.18182

Using the shallow tree.

# Predicted HP from the shallow tree; it equals the mean training HP of the
# leaf reached (node 5: 71.62176).
regress_tr_shallow_pred <- predict(regress_tr_shallow, newdata = padawan_1)
regress_tr_shallow_pred
##        1 
## 71.62176

# Terminal node padawan_1 lands in.
padawan_1_node <- rpart.predict.leaves(
  regress_tr_shallow,
  newdata = padawan_1,
  type = "where"
)
padawan_1_node
## 1 
## 5

# Pair the leaf with its prediction so node statistics can be joined on.
padawan_1_pred <- data.frame(
  Node = padawan_1_node,
  Prediction = regress_tr_shallow_pred
)
padawan_1_pred
##   Node Prediction
## 1    5   71.62176

Range of the prediction for padawan_1

# Attach the spread of training HP in padawan_1's leaf (min, max, sd) to its
# point prediction via successive inner joins on Node.
padawan_1_pred_range <- Reduce(
  function(acc, stats) inner_join(acc, stats, by = "Node"),
  list(min_node, max_node, sd_node),
  padawan_1_pred
)

padawan_1_pred_range
##   Node Prediction min max       sd
## 1    5   71.62176   6 150 38.45297

5.2 Multiple new records

New data

# Five new characters to score (ability scores only, no HP column).
# read.csv() treats the first line as a header by default.
padawan_2 <- read.csv("new_dnd_5.csv")
print(padawan_2)
##   STR DEX CON INT WIS CHA
## 1  11  17  14  17  17  13
## 2   6  18   9  10   9   9
## 3  18  17  18  10  10   9
## 4   8  18  13  18  18  16
## 5  18   6   6   6  10  18

Using the large tree

# Predicted HP for each new character from the large tree.
regress_tr_pred_2 <- predict(object = regress_tr, newdata = padawan_2)
regress_tr_pred_2
##        1        2        3        4        5 
## 57.18182 65.65385 93.18750 57.18182 68.75000

Using the shallow tree.

# Predicted HP for each new character from the shallow tree ...
regress_tr_shallow_pred_2 <- predict(object = regress_tr_shallow,
                                     newdata = padawan_2)
regress_tr_shallow_pred_2
##        1        2        3        4        5 
## 71.62176 71.62176 71.62176 71.62176 68.75000
# ... and the terminal node each one lands in.
padawan_2_node <- rpart.predict.leaves(
  regress_tr_shallow,
  newdata = padawan_2,
  type = "where"
)
padawan_2_node
## 1 2 3 4 5 
## 5 5 5 5 4

As a data frame

# One row per new character: its leaf and its shallow-tree prediction.
padawan_2_pred <- data.frame(
  Node = padawan_2_node,
  Prediction = regress_tr_shallow_pred_2
)
padawan_2_pred
##   Node Prediction
## 1    5   71.62176
## 2    5   71.62176
## 3    5   71.62176
## 4    5   71.62176
## 5    4   68.75000

# Add the training-HP spread (min, max, sd) of each character's leaf.
padawan_2_pred_range <- Reduce(
  function(acc, stats) inner_join(acc, stats, by = "Node"),
  list(min_node, max_node, sd_node),
  padawan_2_pred
)

padawan_2_pred_range
##   Node Prediction min max       sd
## 1    5   71.62176   6 150 38.45297
## 2    5   71.62176   6 150 38.45297
## 3    5   71.62176   6 150 38.45297
## 4    5   71.62176   6 150 38.45297
## 5    4   68.75000   9 135 35.75392

6. Improved trees: random forest

library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
## 
##     combine
## The following object is masked from 'package:ggplot2':
## 
##     margin
# Random forest of 200 regression trees on the same six predictors.
# nodesize = 5 (minimum terminal-node size) is randomForest's regression
# default; importance = TRUE also stores permutation variable importance.
rf <- randomForest(HP ~ STR + DEX + CON + INT + WIS + CHA,
                   data = train_df, ntree = 200,
                   nodesize = 5, importance = TRUE)

# In-sample predictions. Each training row is in the bootstrap sample of
# roughly two-thirds of the trees that predict it, so this error is
# optimistically low and should not be compared with the validation error.
rf_pred_train <- predict(rf, train_df)
accuracy(rf_pred_train, train_df$HP)
##                  ME     RMSE      MAE       MPE     MAPE
## Test set -0.1780225 20.62762 17.25481 -40.95754 56.22122

# Out-of-bag predictions (predict() with no newdata): each training row is
# predicted only by trees whose bootstrap sample left it out, giving an
# honest estimate of generalisation error from the training data alone.
# This draws no random numbers, so the results below are unaffected.
rf_pred_oob <- predict(rf)
accuracy(rf_pred_oob, train_df$HP)

rf_pred_valid <- predict(rf, valid_df)
accuracy(rf_pred_valid, valid_df$HP)
##                 ME     RMSE      MAE       MPE     MAPE
## Test set -8.670002 36.74185 29.91955 -86.16483 106.0162

Prediction using the random forest

# Random-forest HP predictions for the five new characters (the average of
# the 200 trees' predictions).
rf_pred_2 <- predict(object = rf, newdata = padawan_2)
rf_pred_2
##        1        2        3        4        5 
## 68.46083 71.59617 61.84683 79.77700 66.10200