A Regression Tree in Dungeons and Dragons
1. Libraries
library(rpart)
library(rpart.plot)
library(forecast)
## Registered S3 method overwritten by 'quantmod':
## method from
## as.zoo.data.frame zoo
library(caret)
## Loading required package: ggplot2
## Loading required package: lattice
2. Load data
# Load the superhero/Hogwarts dataset; first row is the header
hogwarts <- read.csv("super_heroes_hogwarts_v3a.csv", header = TRUE)
head(hogwarts, 10)
## ID Name Gender Race Height Publisher
## 1 A001 A-Bomb Male Human 203 Marvel Comics
## 2 A002 Abe Sapien Male Icthyo Sapien 191 Dark Horse Comics
## 3 A004 Abomination Male Human / Radiation 203 Marvel Comics
## 4 A009 Agent 13 Female <NA> 173 Marvel Comics
## 5 A015 Alex Mercer Male Human NA Wildstorm
## 6 A016 Alex Woolsly Male <NA> NA NBC - Heroes
## 7 A024 Angel Male Vampire NA Dark Horse Comics
## 8 A025 Angel Dust Female Mutant 165 Marvel Comics
## 9 A028 Animal Man Male Human 183 DC Comics
## 10 A032 Anti-Monitor Male God / Eternal 61 DC Comics
## Alignment Weight Manipulative Resourceful Dismissive Intelligent Trusting
## 1 good 441 10 10 7 6 7
## 2 good 65 7 7 6 8 6
## 3 bad 441 6 8 1 6 3
## 4 good 61 7 7 1 9 7
## 5 bad NA 10 6 8 3 4
## 6 good NA 8 10 5 5 6
## 7 good NA 8 6 8 7 4
## 8 good 57 9 8 9 4 1
## 9 good 83 7 6 6 5 8
## 10 bad NA 7 7 7 1 9
## Loyal Stubborn Brave HouseID House STR DEX CON INT WIS CHA Level HP
## 1 7 7 9 1 Slytherin 18 11 17 12 13 11 1 7
## 2 7 6 9 1 Slytherin 16 17 10 13 15 11 8 72
## 3 3 5 2 1 Slytherin 13 14 13 10 18 15 15 135
## 4 4 6 6 1 Slytherin 15 18 16 16 17 10 14 140
## 5 4 1 8 1 Slytherin 14 17 13 12 10 11 9 72
## 6 7 7 6 1 Slytherin 14 14 11 13 12 12 1 8
## 7 1 5 2 1 Slytherin 15 17 15 18 13 18 11 88
## 8 6 5 4 1 Slytherin 8 17 12 15 17 18 1 8
## 9 3 3 2 1 Slytherin 10 17 15 18 13 14 8 56
## 10 1 6 5 1 Slytherin 8 10 11 16 12 11 7 63
# Structure of the raw data: 734 observations of 26 variables
str(hogwarts)
## 'data.frame': 734 obs. of 26 variables:
## $ ID : chr "A001" "A002" "A004" "A009" ...
## $ Name : chr "A-Bomb" "Abe Sapien" "Abomination" "Agent 13" ...
## $ Gender : chr "Male" "Male" "Male" "Female" ...
## $ Race : chr "Human" "Icthyo Sapien" "Human / Radiation" NA ...
## $ Height : num 203 191 203 173 NA NA NA 165 183 61 ...
## $ Publisher : chr "Marvel Comics" "Dark Horse Comics" "Marvel Comics" "Marvel Comics" ...
## $ Alignment : chr "good" "good" "bad" "good" ...
## $ Weight : int 441 65 441 61 NA NA NA 57 83 NA ...
## $ Manipulative: int 10 7 6 7 10 8 8 9 7 7 ...
## $ Resourceful : int 10 7 8 7 6 10 6 8 6 7 ...
## $ Dismissive : int 7 6 1 1 8 5 8 9 6 7 ...
## $ Intelligent : int 6 8 6 9 3 5 7 4 5 1 ...
## $ Trusting : int 7 6 3 7 4 6 4 1 8 9 ...
## $ Loyal : int 7 7 3 4 4 7 1 6 3 1 ...
## $ Stubborn : int 7 6 5 6 1 7 5 5 3 6 ...
## $ Brave : int 9 9 2 6 8 6 2 4 2 5 ...
## $ HouseID : int 1 1 1 1 1 1 1 1 1 1 ...
## $ House : chr "Slytherin" "Slytherin" "Slytherin" "Slytherin" ...
## $ STR : int 18 16 13 15 14 14 15 8 10 8 ...
## $ DEX : int 11 17 14 18 17 14 17 17 17 10 ...
## $ CON : int 17 10 13 16 13 11 15 12 15 11 ...
## $ INT : int 12 13 10 16 12 13 18 15 18 16 ...
## $ WIS : int 13 15 18 17 10 12 13 17 13 12 ...
## $ CHA : int 11 11 15 10 11 12 18 18 14 11 ...
## $ Level : int 1 8 15 14 9 1 11 1 8 7 ...
## $ HP : int 7 72 135 140 72 8 88 8 56 63 ...
# Column names of the full dataset
names(hogwarts)
## [1] "ID" "Name" "Gender" "Race" "Height"
## [6] "Publisher" "Alignment" "Weight" "Manipulative" "Resourceful"
## [11] "Dismissive" "Intelligent" "Trusting" "Loyal" "Stubborn"
## [16] "Brave" "HouseID" "House" "STR" "DEX"
## [21] "CON" "INT" "WIS" "CHA" "Level"
## [26] "HP"
# Total number of records
nrow(hogwarts)
## [1] 734
Remove unnecessary variables for this model
# Keep only the ability scores (STR, DEX, CON, INT, WIS, CHA; columns 19-24)
# and the outcome HP (column 26); drop everything else for this model
hogwarts <- hogwarts[, c(19:24, 26)]
names(hogwarts)
## [1] "STR" "DEX" "CON" "INT" "WIS" "CHA" "HP"
Look at the new order
# Double transpose prints the names as a one-column matrix for easier reading
t(t(names(hogwarts)))
## [,1]
## [1,] "STR"
## [2,] "DEX"
## [3,] "CON"
## [4,] "INT"
## [5,] "WIS"
## [6,] "CHA"
## [7,] "HP"
# Structure after the column subset: 7 variables remain
str(hogwarts)
## 'data.frame': 734 obs. of 7 variables:
## $ STR: int 18 16 13 15 14 14 15 8 10 8 ...
## $ DEX: int 11 17 14 18 17 14 17 17 17 10 ...
## $ CON: int 17 10 13 16 13 11 15 12 15 11 ...
## $ INT: int 12 13 10 16 12 13 18 15 18 16 ...
## $ WIS: int 13 15 18 17 10 12 13 17 13 12 ...
## $ CHA: int 11 11 15 10 11 12 18 18 14 11 ...
## $ HP : int 7 72 135 140 72 8 88 8 56 63 ...
# NOTE(review): House was dropped by the column subset above, so this table
# is empty (extent 0) — likely a leftover from an earlier script version
table(hogwarts$House)
## < table of extent 0 >
# Row count is unchanged by the column subset
nrow(hogwarts)
## [1] 734
3. Training validation split
We’re using our favourite seed number, but you can use any other seed. Note that your solutions may differ slightly with different seeds.
# 60/40 training/validation split; fixed seed for reproducibility
set.seed(666)
train_index <- sample(seq_len(nrow(hogwarts)), 0.6 * nrow(hogwarts))
valid_index <- setdiff(seq_len(nrow(hogwarts)), train_index)

train_df <- hogwarts[train_index, ]
valid_df <- hogwarts[valid_index, ]
nrow(train_df)
## [1] 440
# Validation set size (734 - 440 = 294)
nrow(valid_df)
## [1] 294
# First few training rows
head(train_df)
## STR DEX CON INT WIS CHA HP
## 574 13 11 12 18 12 13 72
## 638 8 13 18 18 18 15 120
## 608 15 13 17 10 10 12 98
## 123 13 10 11 11 10 17 63
## 540 8 10 18 12 18 17 40
## 654 14 12 12 13 18 11 54
# First few validation rows
head(valid_df)
## STR DEX CON INT WIS CHA HP
## 2 16 17 10 13 15 11 72
## 3 13 14 13 10 18 15 135
## 5 14 17 13 12 10 11 72
## 12 13 14 12 17 12 11 30
## 13 10 14 15 17 12 16 84
## 14 16 14 14 13 10 12 36
# Training set structure
str(train_df)
## 'data.frame': 440 obs. of 7 variables:
## $ STR: int 13 8 15 13 8 14 9 11 15 17 ...
## $ DEX: int 11 13 13 10 10 12 16 18 14 18 ...
## $ CON: int 12 18 17 11 18 12 13 15 11 11 ...
## $ INT: int 18 18 10 11 12 13 14 16 18 11 ...
## $ WIS: int 12 18 10 10 18 18 11 10 18 16 ...
## $ CHA: int 13 15 12 17 17 11 18 11 11 10 ...
## $ HP : int 72 120 98 63 40 54 140 66 56 56 ...
# Validation set structure
str(valid_df)
## 'data.frame': 294 obs. of 7 variables:
## $ STR: int 16 13 14 13 10 16 11 11 9 11 ...
## $ DEX: int 17 14 17 14 14 14 16 13 11 13 ...
## $ CON: int 10 13 13 12 15 14 18 10 15 12 ...
## $ INT: int 13 10 12 17 17 13 14 13 12 16 ...
## $ WIS: int 15 18 10 12 12 10 15 10 14 15 ...
## $ CHA: int 11 15 11 11 16 12 14 10 13 14 ...
## $ HP : int 72 135 72 30 84 36 140 140 72 21 ...
4. Regression tree
# Predictor and outcome names available for the tree formula
names(train_df)
## [1] "STR" "DEX" "CON" "INT" "WIS" "CHA" "HP"
4.1 Large tree.
This is harder to read, and may not be very useful.
# 4.1 Large tree: anova (regression) tree, grown deep (maxdepth = 20)
regress_tr <- rpart(HP ~ STR + DEX + CON + INT + WIS + CHA,
                    data = train_df, method = "anova", maxdepth = 20)
rpart.plot(regress_tr, type = 4)

# Accuracy on the training set
predict_train <- predict(regress_tr, train_df)
accuracy(predict_train, train_df$HP)
## ME RMSE MAE MPE MAPE
## Test set -2.391647e-15 35.81544 29.87884 -66.19501 93.24161
# Accuracy on the validation set (the honest estimate of performance)
predict_valid <- predict(regress_tr, valid_df)
accuracy(predict_valid, valid_df$HP)
## ME RMSE MAE MPE MAPE
## Test set -8.012589 37.49316 30.60095 -82.68468 104.7522
4.2 Shallower tree.
A shallower tree may be more useful. But it may be overly simplistic.
# 4.2 Shallower tree: limited to depth 3, at least 2 records per leaf
regress_tr_shallow <- rpart(HP ~ STR + DEX + CON + INT + WIS + CHA,
                            data = train_df, method = "anova",
                            minbucket = 2, maxdepth = 3)
rpart.plot(regress_tr_shallow, type = 4)

# rpart's internal "where" slot: terminal-node row index for each
# training observation
head(regress_tr_shallow$where)
## 574 638 608 123 540 654
## 5 5 5 4 4 5
# Shallow tree: accuracy on the training set
predict_train_shallow <- predict(regress_tr_shallow, train_df)
accuracy(predict_train_shallow, train_df$HP)
## ME RMSE MAE MPE MAPE
## Test set -2.083927e-15 37.61598 32.04895 -75.22096 103.9582
# Shallow tree: accuracy on the validation set
predict_valid_shallow <- predict(regress_tr_shallow, valid_df)
accuracy(predict_valid_shallow, valid_df$HP)
## ME RMSE MAE MPE MAPE
## Test set -7.670169 34.73232 28.2203 -81.03383 99.89169
Use training-set values here, because the tree was trained on the training set and its node statistics must come from the same data.
library(treeClust)
## Warning: package 'treeClust' was built under R version 4.2.3
## Loading required package: cluster
# Terminal node ("where" index) each training record lands in
which_node_train <- rpart.predict.leaves(regress_tr_shallow,
                                         newdata = train_df,
                                         type = "where")
head(which_node_train)
## 574 638 608 123 540 654
## 5 5 5 4 4 5
sd of each terminal node
# Standard deviation of HP within each terminal node
sd_node <- aggregate(train_df$HP, list(which_node_train), FUN = sd)
names(sd_node) <- c("Node", "sd")
sd_node
## Node sd
## 1 3 27.41906
## 2 4 35.75392
## 3 5 38.45297
min of each terminal node
# Minimum HP within each terminal node
min_node <- aggregate(train_df$HP, list(which_node_train), FUN = min)
names(min_node) <- c("Node", "min")
min_node
## Node min
## 1 3 8
## 2 4 9
## 3 5 6
max of each terminal node
# Maximum HP within each terminal node
max_node <- aggregate(train_df$HP, list(which_node_train), FUN = max)
names(max_node) <- c("Node", "max")
max_node
## Node max
## 1 3 112
## 2 4 135
## 3 5 150
# Mean HP within each terminal node (this is the tree's prediction per node)
mean_node <- aggregate(train_df$HP, list(which_node_train), FUN = mean)
names(mean_node) <- c("Node", "mean")
mean_node
## Node mean
## 1 3 42.73077
## 2 4 68.75000
## 3 5 71.62176
terminal node stats
library(dplyr)
##
## Attaching package: 'dplyr'
## The following objects are masked from 'package:stats':
##
## filter, lag
## The following objects are masked from 'package:base':
##
## intersect, setdiff, setequal, union
# Combine per-node sd / min / max / mean into one summary table, keyed by Node
regress_tr_shallow_node_stats <- sd_node %>%
  inner_join(min_node, by = "Node") %>%
  inner_join(max_node, by = "Node") %>%
  inner_join(mean_node, by = "Node")
regress_tr_shallow_node_stats
## Node sd min max mean
## 1 3 27.41906 8 112 42.73077
## 2 4 35.75392 9 135 68.75000
## 3 5 38.45297 6 150 71.62176
# Terminal node each validation record lands in
which_node_valid <- rpart.predict.leaves(regress_tr_shallow,
                                         newdata = valid_df,
                                         type = "where")
head(which_node_valid, 20)
## 2 3 5 12 13 14 15 16 17 18 20 23 24 31 33 40 44 47 48 49
## 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 4 5 5 5 5
# Convert the named node vector into a one-column data frame
which_node_valid_df <- as.data.frame(which_node_valid)
names(which_node_valid_df)[1] <- "node"
head(which_node_valid_df)
## node
## 2 5
## 3 5
## 5 5
## 12 5
## 13 5
## 14 5
5. Predict new record
5.1 One new record
# Confirm the columns a new record must supply
names(hogwarts)
## [1] "STR" "DEX" "CON" "INT" "WIS" "CHA" "HP"
Using the large tree
# A single new record with the six ability scores
padawan_1 <- data.frame(STR = 11,
                        DEX = 17,
                        CON = 14,
                        INT = 17,
                        WIS = 17,
                        CHA = 13)

# Predict its HP with the large tree
regress_tr_pred <- predict(regress_tr, newdata = padawan_1)
regress_tr_pred
## 1
## 57.18182
Using the shallow tree.
# Predict the same record with the shallow tree
regress_tr_shallow_pred <- predict(regress_tr_shallow, newdata = padawan_1)
regress_tr_shallow_pred
## 1
## 71.62176
# Which terminal node of the shallow tree the new record falls into
padawan_1_node <- rpart.predict.leaves(regress_tr_shallow,
                                       newdata = padawan_1,
                                       type = "where")
padawan_1_node
## 1
## 5
# Combine node assignment and prediction as a data frame
padawan_1_pred <- data.frame(Node = padawan_1_node,
                             Prediction = regress_tr_shallow_pred)
padawan_1_pred
## Node Prediction
## 1 5 71.62176
Padawan_1 range of prediction
# Attach the node's min / max / sd (computed on the training set) to give a
# range around the point prediction
padawan_1_pred_range <- padawan_1_pred %>%
  inner_join(min_node, by = "Node") %>%
  inner_join(max_node, by = "Node") %>%
  inner_join(sd_node, by = "Node")
padawan_1_pred_range
## Node Prediction min max sd
## 1 5 71.62176 6 150 38.45297
5.2 Multiple new records
# Five new records read from file
padawan_2 <- read.csv("new_dnd_5.csv", header = TRUE)
padawan_2
## STR DEX CON INT WIS CHA
## 1 11 17 14 17 17 13
## 2 6 18 9 10 9 9
## 3 18 17 18 10 10 9
## 4 8 18 13 18 18 16
## 5 18 6 6 6 10 18
Using the large tree
# Predict all five records with the large tree
regress_tr_pred_2 <- predict(regress_tr, newdata = padawan_2)
regress_tr_pred_2
## 1 2 3 4 5
## 57.18182 65.65385 93.18750 57.18182 68.75000
Using the shallow tree.
# Predict all five records with the shallow tree
regress_tr_shallow_pred_2 <- predict(regress_tr_shallow, newdata = padawan_2)
regress_tr_shallow_pred_2
## 1 2 3 4 5
## 71.62176 71.62176 71.62176 71.62176 68.75000
# Terminal node of the shallow tree for each of the five records
padawan_2_node <- rpart.predict.leaves(regress_tr_shallow,
                                       newdata = padawan_2,
                                       type = "where")
padawan_2_node
## 1 2 3 4 5
## 5 5 5 5 4
As a data frame
# Combine node assignments and predictions as a data frame
padawan_2_pred <- data.frame(Node = padawan_2_node,
                             Prediction = regress_tr_shallow_pred_2)
padawan_2_pred
## Node Prediction
## 1 5 71.62176
## 2 5 71.62176
## 3 5 71.62176
## 4 5 71.62176
## 5 4 68.75000
# Attach per-node min / max / sd to give a range around each prediction
padawan_2_pred_range <- padawan_2_pred %>%
  inner_join(min_node, by = "Node") %>%
  inner_join(max_node, by = "Node") %>%
  inner_join(sd_node, by = "Node")
padawan_2_pred_range
## Node Prediction min max sd
## 1 5 71.62176 6 150 38.45297
## 2 5 71.62176 6 150 38.45297
## 3 5 71.62176 6 150 38.45297
## 4 5 71.62176 6 150 38.45297
## 5 4 68.75000 9 135 35.75392
6. Improved trees
library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:dplyr':
##
## combine
## The following object is masked from 'package:ggplot2':
##
## margin
# Random forest: 200 trees, minimum node size 5, track variable importance
rf <- randomForest(HP ~ STR + DEX + CON + INT + WIS + CHA,
                   data = train_df, ntree = 200,
                   nodesize = 5, importance = TRUE)

# Accuracy on the training set
rf_pred_train <- predict(rf, train_df)
accuracy(rf_pred_train, train_df$HP)
## ME RMSE MAE MPE MAPE
## Test set -0.1780225 20.62762 17.25481 -40.95754 56.22122
# Accuracy on the validation set
rf_pred_valid <- predict(rf, valid_df)
accuracy(rf_pred_valid, valid_df$HP)
## ME RMSE MAE MPE MAPE
## Test set -8.670002 36.74185 29.91955 -86.16483 106.0162
Prediction using the random forest
# Random-forest predictions for the five new records
rf_pred_2 <- predict(rf, newdata = padawan_2)
rf_pred_2
## 1 2 3 4 5
## 68.46083 71.59617 61.84683 79.77700 66.10200