A Regression Tree in Dungeons and Dragons
1. Libraries
## Registered S3 method overwritten by 'quantmod':
## method from
## as.zoo.data.frame zoo
## Loading required package: ggplot2
## Loading required package: lattice
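The library() calls themselves are not shown in the rendered output, but the startup messages above are consistent with loading something like the following. The exact package set is an inference, not confirmed by the source: forecast triggers the quantmod as.zoo.data.frame message, and caret loads ggplot2 and lattice.

```r
# Inferred from the startup messages above; the package set is an assumption.
library(rpart)        # regression trees
library(rpart.plot)   # tree plotting
library(forecast)     # accuracy(); loading it emits the quantmod S3-override message
library(caret)        # emits "Loading required package: ggplot2" and "lattice"
```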
2. Load data
## ID Name Gender Race Height Publisher
## 1 A001 A-Bomb Male Human 203 Marvel Comics
## 2 A002 Abe Sapien Male Icthyo Sapien 191 Dark Horse Comics
## 3 A004 Abomination Male Human / Radiation 203 Marvel Comics
## 4 A009 Agent 13 Female <NA> 173 Marvel Comics
## 5 A015 Alex Mercer Male Human NA Wildstorm
## 6 A016 Alex Woolsly Male <NA> NA NBC - Heroes
## 7 A024 Angel Male Vampire NA Dark Horse Comics
## 8 A025 Angel Dust Female Mutant 165 Marvel Comics
## 9 A028 Animal Man Male Human 183 DC Comics
## 10 A032 Anti-Monitor Male God / Eternal 61 DC Comics
## Alignment Weight Manipulative Resourceful Dismissive Intelligent Trusting
## 1 good 441 10 10 7 6 7
## 2 good 65 7 7 6 8 6
## 3 bad 441 6 8 1 6 3
## 4 good 61 7 7 1 9 7
## 5 bad NA 10 6 8 3 4
## 6 good NA 8 10 5 5 6
## 7 good NA 8 6 8 7 4
## 8 good 57 9 8 9 4 1
## 9 good 83 7 6 6 5 8
## 10 bad NA 7 7 7 1 9
## Loyal Stubborn Brave HouseID House STR DEX CON INT WIS CHA Level HP
## 1 7 7 9 1 Slytherin 18 11 17 12 13 11 1 7
## 2 7 6 9 1 Slytherin 16 17 10 13 15 11 8 72
## 3 3 5 2 1 Slytherin 13 14 13 10 18 15 15 135
## 4 4 6 6 1 Slytherin 15 18 16 16 17 10 14 140
## 5 4 1 8 1 Slytherin 14 17 13 12 10 11 9 72
## 6 7 7 6 1 Slytherin 14 14 11 13 12 12 1 8
## 7 1 5 2 1 Slytherin 15 17 15 18 13 18 11 88
## 8 6 5 4 1 Slytherin 8 17 12 15 17 18 1 8
## 9 3 3 2 1 Slytherin 10 17 15 18 13 14 8 56
## 10 1 6 5 1 Slytherin 8 10 11 16 12 11 7 63
## 'data.frame': 734 obs. of 26 variables:
## $ ID : chr "A001" "A002" "A004" "A009" ...
## $ Name : chr "A-Bomb" "Abe Sapien" "Abomination" "Agent 13" ...
## $ Gender : chr "Male" "Male" "Male" "Female" ...
## $ Race : chr "Human" "Icthyo Sapien" "Human / Radiation" NA ...
## $ Height : num 203 191 203 173 NA NA NA 165 183 61 ...
## $ Publisher : chr "Marvel Comics" "Dark Horse Comics" "Marvel Comics" "Marvel Comics" ...
## $ Alignment : chr "good" "good" "bad" "good" ...
## $ Weight : int 441 65 441 61 NA NA NA 57 83 NA ...
## $ Manipulative: int 10 7 6 7 10 8 8 9 7 7 ...
## $ Resourceful : int 10 7 8 7 6 10 6 8 6 7 ...
## $ Dismissive : int 7 6 1 1 8 5 8 9 6 7 ...
## $ Intelligent : int 6 8 6 9 3 5 7 4 5 1 ...
## $ Trusting : int 7 6 3 7 4 6 4 1 8 9 ...
## $ Loyal : int 7 7 3 4 4 7 1 6 3 1 ...
## $ Stubborn : int 7 6 5 6 1 7 5 5 3 6 ...
## $ Brave : int 9 9 2 6 8 6 2 4 2 5 ...
## $ HouseID : int 1 1 1 1 1 1 1 1 1 1 ...
## $ House : chr "Slytherin" "Slytherin" "Slytherin" "Slytherin" ...
## $ STR : int 18 16 13 15 14 14 15 8 10 8 ...
## $ DEX : int 11 17 14 18 17 14 17 17 17 10 ...
## $ CON : int 17 10 13 16 13 11 15 12 15 11 ...
## $ INT : int 12 13 10 16 12 13 18 15 18 16 ...
## $ WIS : int 13 15 18 17 10 12 13 17 13 12 ...
## $ CHA : int 11 11 15 10 11 12 18 18 14 11 ...
## $ Level : int 1 8 15 14 9 1 11 1 8 7 ...
## $ HP : int 7 72 135 140 72 8 88 8 56 63 ...
## [1] "ID" "Name" "Gender" "Race" "Height"
## [6] "Publisher" "Alignment" "Weight" "Manipulative" "Resourceful"
## [11] "Dismissive" "Intelligent" "Trusting" "Loyal" "Stubborn"
## [16] "Brave" "HouseID" "House" "STR" "DEX"
## [21] "CON" "INT" "WIS" "CHA" "Level"
## [26] "HP"
## [1] 734
Remove unnecessary variables for this model
## [1] "STR" "DEX" "CON" "INT" "WIS" "CHA" "HP"
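A minimal sketch of the column selection that could produce the names above, assuming the data frame is called hogwarts (the name used later in the split); the column list is taken from the printed output:

```r
# Keep only the ability scores and the outcome variable HP.
hogwarts <- hogwarts[ , c("STR", "DEX", "CON", "INT", "WIS", "CHA", "HP")]
names(hogwarts)
```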
Look at the new order
## [,1]
## [1,] "STR"
## [2,] "DEX"
## [3,] "CON"
## [4,] "INT"
## [5,] "WIS"
## [6,] "CHA"
## [7,] "HP"
## 'data.frame': 734 obs. of 7 variables:
## $ STR: int 18 16 13 15 14 14 15 8 10 8 ...
## $ DEX: int 11 17 14 18 17 14 17 17 17 10 ...
## $ CON: int 17 10 13 16 13 11 15 12 15 11 ...
## $ INT: int 12 13 10 16 12 13 18 15 18 16 ...
## $ WIS: int 13 15 18 17 10 12 13 17 13 12 ...
## $ CHA: int 11 11 15 10 11 12 18 18 14 11 ...
## $ HP : int 7 72 135 140 72 8 88 8 56 63 ...
## < table of extent 0 >
## [1] 734
3. Training-validation split
We’re using our favourite seed number, but you can use any other seed. Note that your solutions may differ slightly with different seeds.
set.seed(666)
train_index <- sample(1:nrow(hogwarts), 0.6 * nrow(hogwarts))
valid_index <- setdiff(1:nrow(hogwarts), train_index)
train_df <- hogwarts[train_index, ]
valid_df <- hogwarts[valid_index, ]
nrow(train_df)
## [1] 440
## [1] 294
## STR DEX CON INT WIS CHA HP
## 574 13 11 12 18 12 13 72
## 638 8 13 18 18 18 15 120
## 608 15 13 17 10 10 12 98
## 123 13 10 11 11 10 17 63
## 540 8 10 18 12 18 17 40
## 654 14 12 12 13 18 11 54
## STR DEX CON INT WIS CHA HP
## 2 16 17 10 13 15 11 72
## 3 13 14 13 10 18 15 135
## 5 14 17 13 12 10 11 72
## 12 13 14 12 17 12 11 30
## 13 10 14 15 17 12 16 84
## 14 16 14 14 13 10 12 36
## 'data.frame': 440 obs. of 7 variables:
## $ STR: int 13 8 15 13 8 14 9 11 15 17 ...
## $ DEX: int 11 13 13 10 10 12 16 18 14 18 ...
## $ CON: int 12 18 17 11 18 12 13 15 11 11 ...
## $ INT: int 18 18 10 11 12 13 14 16 18 11 ...
## $ WIS: int 12 18 10 10 18 18 11 10 18 16 ...
## $ CHA: int 13 15 12 17 17 11 18 11 11 10 ...
## $ HP : int 72 120 98 63 40 54 140 66 56 56 ...
## 'data.frame': 294 obs. of 7 variables:
## $ STR: int 16 13 14 13 10 16 11 11 9 11 ...
## $ DEX: int 17 14 17 14 14 14 16 13 11 13 ...
## $ CON: int 10 13 13 12 15 14 18 10 15 12 ...
## $ INT: int 13 10 12 17 17 13 14 13 12 16 ...
## $ WIS: int 15 18 10 12 12 10 15 10 14 15 ...
## $ CHA: int 11 15 11 11 16 12 14 10 13 14 ...
## $ HP : int 72 135 72 30 84 36 140 140 72 21 ...
4. Regression tree
## [1] "STR" "DEX" "CON" "INT" "WIS" "CHA" "HP"
4.1 Large tree.
This tree is harder to read and may not be very useful.
regress_tr <- rpart(HP ~ STR + DEX + CON + INT + WIS + CHA,
data = train_df, method = "anova", maxdepth = 20)
rpart.plot(regress_tr, type = 4)
## ME RMSE MAE MPE MAPE
## Test set -2.391647e-15 35.81544 29.87884 -66.19501 93.24161
## ME RMSE MAE MPE MAPE
## Test set -8.012589 37.49316 30.60095 -82.68468 104.7522
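The two accuracy tables above (training set first, then validation set) were presumably produced by calls like the following; forecast::accuracy() is assumed, matching the ME/RMSE/MAE/MPE/MAPE columns, and the object names are hypothetical:

```r
# Training-set accuracy for the large tree
predict_train <- predict(regress_tr, train_df)
accuracy(predict_train, train_df$HP)

# Validation-set accuracy for the large tree
predict_valid <- predict(regress_tr, valid_df)
accuracy(predict_valid, valid_df$HP)
```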
4.2 Shallower tree.
A shallower tree may be more useful, but it may also be overly simplistic.
regress_tr_shallow <- rpart(HP ~ STR + DEX + CON + INT + WIS + CHA,
data = train_df, method = "anova",
minbucket = 2, maxdepth = 3)
rpart.plot(regress_tr_shallow, type = 4)
## 574 638 608 123 540 654
## 5 5 5 4 4 5
predict_train_shallow <- predict(regress_tr_shallow, train_df)
accuracy(predict_train_shallow, train_df$HP)
## ME RMSE MAE MPE MAPE
## Test set -2.083927e-15 37.61598 32.04895 -75.22096 103.9582
predict_valid_shallow <- predict(regress_tr_shallow, valid_df)
accuracy(predict_valid_shallow, valid_df$HP)
## ME RMSE MAE MPE MAPE
## Test set -7.670169 34.73232 28.2203 -81.03383 99.89169
4.3 Prediction range
Use training-set values, since the tree was fitted on the training set.
## Loading required package: cluster
which_node_train <- rpart.predict.leaves(regress_tr_shallow, newdata = train_df,
type = "where")
head(which_node_train)
## 574 638 608 123 540 654
## 5 5 5 4 4 5
Standard deviation of each terminal node:
sd_node <- aggregate(train_df$HP, list(which_node_train), FUN = sd)
names(sd_node) <- c("Node", "sd")
sd_node
## Node sd
## 1 3 27.41906
## 2 4 35.75392
## 3 5 38.45297
Minimum of each terminal node:
min_node <- aggregate(train_df$HP, list(which_node_train), FUN = min)
names(min_node) <- c("Node", "min")
min_node
## Node min
## 1 3 8
## 2 4 9
## 3 5 6
Maximum of each terminal node:
max_node <- aggregate(train_df$HP, list(which_node_train), FUN = max)
names(max_node) <- c("Node", "max")
max_node
## Node max
## 1 3 112
## 2 4 135
## 3 5 150
Mean of each terminal node:
mean_node <- aggregate(train_df$HP, list(which_node_train), FUN = mean)
names(mean_node) <- c("Node", "mean")
mean_node
## Node mean
## 1 3 42.73077
## 2 4 68.75000
## 3 5 71.62176
Terminal node statistics:
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
regress_tr_shallow_node_stats <- sd_node %>%
inner_join(min_node, by = "Node") %>%
inner_join(max_node, by = "Node") %>%
inner_join(mean_node, by = "Node")
regress_tr_shallow_node_stats
## Node sd min max mean
## 1 3 27.41906 8 112 42.73077
## 2 4 35.75392 9 135 68.75000
## 3 5 38.45297 6 150 71.62176
which_node_valid <- rpart.predict.leaves(regress_tr_shallow, newdata = valid_df,
type = "where")
head(which_node_valid, 20)
## 2 3 5 12 13 14 15 16 17 18 20 23 24 31 33 40 44 47 48 49
## 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 4 5 5 5 5
which_node_valid_df <- as.data.frame(which_node_valid)
names(which_node_valid_df)[1] <- "node"
head(which_node_valid_df)
## node
## 2 5
## 3 5
## 5 5
## 12 5
## 13 5
## 14 5
5. Predict new record
5.1 One new record
## [1] "STR" "DEX" "CON" "INT" "WIS" "CHA" "HP"
Using the large tree.
padawan_1 <- data.frame(STR = 11,
DEX = 17,
CON = 14,
INT = 17,
WIS = 17,
CHA = 13)
regress_tr_pred <- predict(regress_tr, newdata = padawan_1)
regress_tr_pred
## 1
## 57.18182
Using the shallow tree.
## 1
## 71.62176
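The code for this prediction is not shown, but the object regress_tr_shallow_pred is used in the data frame below, so it was presumably created with something like:

```r
regress_tr_shallow_pred <- predict(regress_tr_shallow, newdata = padawan_1)
regress_tr_shallow_pred
```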
padawan_1_node <- rpart.predict.leaves(regress_tr_shallow, newdata = padawan_1,
type = "where")
padawan_1_node
## 1
## 5
padawan_1_pred <- data.frame(Node = padawan_1_node,
Prediction = regress_tr_shallow_pred)
padawan_1_pred
## Node Prediction
## 1 5 71.62176
Prediction range for padawan_1:
padawan_1_pred_range <- padawan_1_pred %>%
inner_join(min_node, by = "Node") %>%
inner_join(max_node, by = "Node") %>%
inner_join(sd_node, by = "Node")
padawan_1_pred_range
## Node Prediction min max sd
## 1 5 71.62176 6 150 38.45297
5.2 Multiple new records
## STR DEX CON INT WIS CHA
## 1 11 17 14 17 17 13
## 2 6 18 9 10 9 9
## 3 18 17 18 10 10 9
## 4 8 18 13 18 18 16
## 5 18 6 6 6 10 18
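The new records printed above can be reconstructed as follows (the values are copied from the printed data frame; the construction itself is a sketch, since the original chunk is not shown):

```r
padawan_2 <- data.frame(STR = c(11, 6, 18, 8, 18),
                        DEX = c(17, 18, 17, 18, 6),
                        CON = c(14, 9, 18, 13, 6),
                        INT = c(17, 10, 10, 18, 6),
                        WIS = c(17, 9, 10, 18, 10),
                        CHA = c(13, 9, 9, 16, 18))
padawan_2
```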
Using the large tree.
## 1 2 3 4 5
## 57.18182 65.65385 93.18750 57.18182 68.75000
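The large-tree predictions above were presumably produced by a call like this (the object name is an assumption):

```r
regress_tr_pred_2 <- predict(regress_tr, newdata = padawan_2)
regress_tr_pred_2
```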
Using the shallow tree.
regress_tr_shallow_pred_2 <- predict(regress_tr_shallow, newdata = padawan_2)
regress_tr_shallow_pred_2
## 1 2 3 4 5
## 71.62176 71.62176 71.62176 71.62176 68.75000
padawan_2_node <- rpart.predict.leaves(regress_tr_shallow,
newdata = padawan_2,
type = "where")
padawan_2_node
## 1 2 3 4 5
## 5 5 5 5 4
As a data frame
padawan_2_pred <- data.frame(Node = padawan_2_node,
Prediction = regress_tr_shallow_pred_2)
padawan_2_pred
## Node Prediction
## 1 5 71.62176
## 2 5 71.62176
## 3 5 71.62176
## 4 5 71.62176
## 5 4 68.75000
padawan_2_pred_range <- padawan_2_pred %>%
inner_join(min_node, by = "Node") %>%
inner_join(max_node, by = "Node") %>%
inner_join(sd_node, by = "Node")
padawan_2_pred_range
## Node Prediction min max sd
## 1 5 71.62176 6 150 38.45297
## 2 5 71.62176 6 150 38.45297
## 3 5 71.62176 6 150 38.45297
## 4 5 71.62176 6 150 38.45297
## 5 4 68.75000 9 135 35.75392
6. Improved trees
## randomForest 4.7-1.2
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
rf <- randomForest(HP ~ STR + DEX + CON + INT + WIS + CHA,
data = train_df, ntree = 200,
nodesize = 5, importance = TRUE)
rf_pred_train <- predict(rf, train_df)
accuracy(rf_pred_train, train_df$HP)
## ME RMSE MAE MPE MAPE
## Test set -0.1780225 20.62762 17.25481 -40.95754 56.22122
## ME RMSE MAE MPE MAPE
## Test set -8.670002 36.74185 29.91955 -86.16483 106.0162
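The second accuracy table (validation set) was presumably produced by the validation-set counterpart of the training call above; the object name is an assumption:

```r
rf_pred_valid <- predict(rf, valid_df)
accuracy(rf_pred_valid, valid_df$HP)
```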
Prediction using the random forest
## 1 2 3 4 5
## 68.46083 71.59617 61.84683 79.77700 66.10200
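These values were presumably generated by predicting on the new records with the forest; the object name is hypothetical:

```r
rf_pred_padawan_2 <- predict(rf, newdata = padawan_2)
rf_pred_padawan_2
```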
7. Cross validation
Use caret to perform a 10-fold cross-validation, repeated 36 times. Note, however, that trControl and tuneLength are arguments to caret::train(), not to randomForest(); when passed directly to randomForest() as below they are silently absorbed by ... and ignored, so this call fits an ordinary random forest.
library(randomForest)
rf_2 <- randomForest(HP ~ STR + DEX + CON + INT + WIS + CHA,
data = train_df, ntree = 150,
nodesize = 19, importance = TRUE,
trControl = caret_control_k10,
tuneLength = 15)
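If the intent is genuine cross-validation with caret, the tuning should go through caret::train() rather than randomForest(). A sketch under stated assumptions (caret_control_k10 is not defined in the shown code, and repeats = 36 is taken from the section text):

```r
library(caret)

# Control object assumed from the text: 10-fold CV, repeated 36 times
caret_control_k10 <- trainControl(method = "repeatedcv",
                                  number = 10,
                                  repeats = 36)

# train() performs the resampling and tunes mtry over tuneLength candidates
rf_cv <- train(HP ~ STR + DEX + CON + INT + WIS + CHA,
               data = train_df,
               method = "rf",
               trControl = caret_control_k10,
               tuneLength = 15)
```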
## List of 18
## $ call : language randomForest(formula = HP ~ STR + DEX + CON + INT + WIS + CHA, data = train_df, ntree = 150, nodesize = 19, | __truncated__ ...
## $ type : chr "regression"
## $ predicted : Named num [1:440] 65.1 67.7 76.1 61.4 57.4 ...
## ..- attr(*, "names")= chr [1:440] "574" "638" "608" "123" ...
## $ mse : num [1:150] 1733 1815 1748 1738 1779 ...
## $ rsq : num [1:150] -0.186 -0.242 -0.196 -0.19 -0.218 ...
## $ oob.times : int [1:440] 63 55 63 51 63 51 47 63 49 62 ...
## $ importance : num [1:6, 1:2] -18.104 28.31 -0.241 1.642 -20.209 ...
## ..- attr(*, "dimnames")=List of 2
## .. ..$ : chr [1:6] "STR" "DEX" "CON" "INT" ...
## .. ..$ : chr [1:2] "%IncMSE" "IncNodePurity"
## $ importanceSD : Named num [1:6] 11.7 10.4 11.5 10.8 10.3 ...
## ..- attr(*, "names")= chr [1:6] "STR" "DEX" "CON" "INT" ...
## $ localImportance: NULL
## $ proximity : NULL
## $ ntree : num 150
## $ mtry : num 2
## $ forest :List of 11
## ..$ ndbigtree : int [1:150] 83 85 89 95 93 91 87 85 73 83 ...
## ..$ nodestatus : int [1:99, 1:150] -3 -3 -3 -3 -3 -3 -3 -1 -3 -3 ...
## ..$ leftDaughter : int [1:99, 1:150] 2 4 6 8 10 12 14 0 16 18 ...
## ..$ rightDaughter: int [1:99, 1:150] 3 5 7 9 11 13 15 0 17 19 ...
## ..$ nodepred : num [1:99, 1:150] 69.9 63.5 72.4 70.8 46.8 ...
## ..$ bestvar : int [1:99, 1:150] 1 1 2 3 4 6 5 0 6 6 ...
## ..$ xbestsplit : num [1:99, 1:150] 10.5 9.5 10.5 11.5 16.5 16.5 10.5 0 13.5 15 ...
## ..$ ncat : Named int [1:6] 1 1 1 1 1 1
## .. ..- attr(*, "names")= chr [1:6] "STR" "DEX" "CON" "INT" ...
## ..$ nrnodes : int 99
## ..$ ntree : num 150
## ..$ xlevels :List of 6
## .. ..$ STR: num 0
## .. ..$ DEX: num 0
## .. ..$ CON: num 0
## .. ..$ INT: num 0
## .. ..$ WIS: num 0
## .. ..$ CHA: num 0
## $ coefs : NULL
## $ y : Named num [1:440] 72 120 98 63 40 54 140 66 56 56 ...
## ..- attr(*, "names")= chr [1:440] "574" "638" "608" "123" ...
## $ test : NULL
## $ inbag : NULL
## $ terms :Classes 'terms', 'formula' language HP ~ STR + DEX + CON + INT + WIS + CHA
## .. ..- attr(*, "variables")= language list(HP, STR, DEX, CON, INT, WIS, CHA)
## .. ..- attr(*, "factors")= int [1:7, 1:6] 0 1 0 0 0 0 0 0 0 1 ...
## .. .. ..- attr(*, "dimnames")=List of 2
## .. .. .. ..$ : chr [1:7] "HP" "STR" "DEX" "CON" ...
## .. .. .. ..$ : chr [1:6] "STR" "DEX" "CON" "INT" ...
## .. ..- attr(*, "term.labels")= chr [1:6] "STR" "DEX" "CON" "INT" ...
## .. ..- attr(*, "order")= int [1:6] 1 1 1 1 1 1
## .. ..- attr(*, "intercept")= num 0
## .. ..- attr(*, "response")= int 1
## .. ..- attr(*, ".Environment")=<environment: R_GlobalEnv>
## .. ..- attr(*, "predvars")= language list(HP, STR, DEX, CON, INT, WIS, CHA)
## .. ..- attr(*, "dataClasses")= Named chr [1:7] "numeric" "numeric" "numeric" "numeric" ...
## .. .. ..- attr(*, "names")= chr [1:7] "HP" "STR" "DEX" "CON" ...
## - attr(*, "class")= chr [1:2] "randomForest.formula" "randomForest"
## num [1:6, 1:2] -18.104 28.31 -0.241 1.642 -20.209 ...
## - attr(*, "dimnames")=List of 2
## ..$ : chr [1:6] "STR" "DEX" "CON" "INT" ...
## ..$ : chr [1:2] "%IncMSE" "IncNodePurity"
## %IncMSE IncNodePurity
## STR -18.1035332 43809.23
## DEX 28.3097894 45012.56
## CON -0.2405105 45793.86
## INT 1.6423710 40415.85
## WIS -20.2085810 42932.95
## CHA -11.1139418 44524.46
## [1] "matrix" "array"
## %IncMSE IncNodePurity
## STR -18.1035332 43809.23
## DEX 28.3097894 45012.56
## CON -0.2405105 45793.86
## INT 1.6423710 40415.85
## WIS -20.2085810 42932.95
## CHA -11.1139418 44524.46
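The data frame rf_2_impt_df plotted below is not constructed in the shown code. A plausible reconstruction from the importance matrix printed above (the PercentIncMSE column name is an assumption, renamed because %IncMSE is not a syntactic R name):

```r
rf_2_impt <- importance(rf_2)
rf_2_impt_df <- data.frame(Predictors = rownames(rf_2_impt),
                           PercentIncMSE = rf_2_impt[, "%IncMSE"],
                           IncNodePurity = rf_2_impt[, "IncNodePurity"])
```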
ggplot(rf_2_impt_df, aes(x = Predictors, y = PercentIncMSE)) +
geom_bar(stat = "identity", fill = "blue") +
labs(x = "Predictors", y = "PercentIncMSE",
title = "Variable Importance by MSE") +
theme_dark()
ggplot(rf_2_impt_df, aes(x = Predictors, y = IncNodePurity)) +
geom_bar(stat = "identity", fill = "green") +
labs(x = "Predictors", y = "IncNodePurity",
title = "Variable Importance by Node Purity") +
theme_dark()
## ME RMSE MAE MPE MAPE
## Test set -0.1484864 30.48996 25.68855 -61.22083 83.944
## ME RMSE MAE MPE MAPE
## Test set -8.061258 35.80621 29.09255 -83.58272 103.035
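The final two accuracy tables (training set, then validation set, for rf_2) were presumably produced by calls like the following; the object names are hypothetical:

```r
rf_2_pred_train <- predict(rf_2, train_df)
accuracy(rf_2_pred_train, train_df$HP)

rf_2_pred_valid <- predict(rf_2, valid_df)
accuracy(rf_2_pred_valid, valid_df$HP)
```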