import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeRegressor
dnd_df = pd.read_csv("super_heroes_dnd_v3a.csv")
dnd_df.head()
| | ID | Name | Gender | Race | Height | Publisher | Alignment | Weight | STR | DEX | CON | INT | WIS | CHA | Level | HP |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | A001 | A-Bomb | Male | Human | 203.0 | Marvel Comics | good | 441.0 | 18 | 11 | 17 | 12 | 13 | 11 | 1 | 7 |
1 | A002 | Abe Sapien | Male | Icthyo Sapien | 191.0 | Dark Horse Comics | good | 65.0 | 16 | 17 | 10 | 13 | 15 | 11 | 8 | 72 |
2 | A004 | Abomination | Male | Human / Radiation | 203.0 | Marvel Comics | bad | 441.0 | 13 | 14 | 13 | 10 | 18 | 15 | 15 | 135 |
3 | A009 | Agent 13 | Female | NaN | 173.0 | Marvel Comics | good | 61.0 | 15 | 18 | 16 | 16 | 17 | 10 | 14 | 140 |
4 | A015 | Alex Mercer | Male | Human | NaN | Wildstorm | bad | NaN | 14 | 17 | 13 | 12 | 10 | 11 | 9 | 72 |
dnd_df.dtypes
ID            object
Name          object
Gender        object
Race          object
Height       float64
Publisher     object
Alignment     object
Weight       float64
STR            int64
DEX            int64
CON            int64
INT            int64
WIS            int64
CHA            int64
Level          int64
HP             int64
dtype: object
pd.DataFrame(dnd_df.columns.values, columns = ["variables"])
| | variables |
---|---|
0 | ID |
1 | Name |
2 | Gender |
3 | Race |
4 | Height |
5 | Publisher |
6 | Alignment |
7 | Weight |
8 | STR |
9 | DEX |
10 | CON |
11 | INT |
12 | WIS |
13 | CHA |
14 | Level |
15 | HP |
It's a good idea to get a sense of the distribution of the target variable first.
dnd_df["HP"].describe()
count    734.000000
mean      66.885559
std       36.653877
min        6.000000
25%       36.000000
50%       63.000000
75%       91.000000
max      150.000000
Name: HP, dtype: float64
dnd_df_2 = dnd_df.iloc[:, np.r_[8:14, 15]]
dnd_df_2
# Alternatively, use:
# dnd_df.iloc[:, list(range(8,14)) + [15]]
# Note: the stop of the range is exclusive, so 8:14 selects columns 8 through 13 (STR through CHA)
# Or just use:
# dnd_df.iloc[:, [8, 9, 10, 11, 12, 13, 15]]
# Or use a label-based slice (careful: .loc slices are inclusive on both ends,
# so "STR":"HP" also picks up Level, which then needs to be dropped)
# dnd_df.loc[:, "STR":"HP"].drop(columns = ["Level"])
# Or specify the variable names
# dnd_df.loc[:, ["STR", "DEX", "CON", "INT", "WIS", "CHA", "HP"]]
| | STR | DEX | CON | INT | WIS | CHA | HP |
---|---|---|---|---|---|---|---|
0 | 18 | 11 | 17 | 12 | 13 | 11 | 7 |
1 | 16 | 17 | 10 | 13 | 15 | 11 | 72 |
2 | 13 | 14 | 13 | 10 | 18 | 15 | 135 |
3 | 15 | 18 | 16 | 16 | 17 | 10 | 140 |
4 | 14 | 17 | 13 | 12 | 10 | 11 | 72 |
... | ... | ... | ... | ... | ... | ... | ... |
729 | 8 | 14 | 17 | 13 | 14 | 15 | 64 |
730 | 17 | 12 | 11 | 11 | 14 | 10 | 56 |
731 | 18 | 10 | 14 | 17 | 10 | 10 | 49 |
732 | 11 | 11 | 10 | 12 | 15 | 16 | 36 |
733 | 16 | 12 | 18 | 15 | 15 | 16 | 81 |
734 rows × 7 columns
import sklearn
from sklearn.model_selection import train_test_split
predictors = ["STR", "DEX", "CON", "INT", "WIS", "CHA"]
outcome = "HP"
X = dnd_df_2[predictors]
y = dnd_df_2[outcome]
train_X, valid_X, train_y, valid_y = train_test_split(X, y, test_size = 0.4, random_state = 666)
train_X.head()
| | STR | DEX | CON | INT | WIS | CHA |
---|---|---|---|---|---|---|
650 | 17 | 14 | 16 | 16 | 15 | 17 |
479 | 8 | 18 | 16 | 10 | 14 | 17 |
271 | 9 | 12 | 17 | 10 | 15 | 17 |
647 | 9 | 18 | 16 | 10 | 17 | 13 |
307 | 12 | 16 | 14 | 18 | 15 | 13 |
len(train_X)
440
train_y.head()
650    117
479    120
271     72
647    117
307    100
Name: HP, dtype: int64
len(train_y)
440
valid_X.head()
| | STR | DEX | CON | INT | WIS | CHA |
---|---|---|---|---|---|---|
389 | 10 | 16 | 15 | 13 | 11 | 10 |
131 | 18 | 10 | 12 | 10 | 16 | 18 |
657 | 10 | 11 | 12 | 11 | 18 | 14 |
421 | 16 | 13 | 11 | 16 | 13 | 11 |
160 | 12 | 16 | 17 | 18 | 11 | 15 |
len(valid_X)
294
valid_y.head()
389    45
131    42
657    63
421    64
160    54
Name: HP, dtype: int64
len(valid_y)
294
full_tree = DecisionTreeRegressor(random_state = 666)
full_tree
DecisionTreeRegressor(random_state=666)
full_tree_fit = full_tree.fit(train_X, train_y)
Plot the tree
from sklearn import tree
Export the top levels for illustration using max_depth; if max_depth is omitted, the whole tree is exported.
text_representation = tree.export_text(full_tree, max_depth = 5)
print(text_representation)
|--- feature_1 <= 10.50
|   |--- feature_3 <= 14.50
|   |   |--- feature_3 <= 10.50
|   |   |   |--- feature_5 <= 16.00
|   |   |   |   |--- feature_2 <= 13.50
|   |   |   |   |   |--- value: [18.00]
|   |   |   |   |--- feature_2 > 13.50
|   |   |   |   |   |--- feature_0 <= 11.50
|   |   |   |   |   |   |--- value: [48.00]
|   |   |   |   |   |--- feature_0 > 11.50
|   |   |   |   |   |   |--- value: [50.00]
|   |   |   |--- feature_5 > 16.00
|   |   |   |   |--- value: [9.00]
|   |   |--- feature_3 > 10.50
|   |   |   |--- feature_2 <= 13.50
|   |   |   |   |--- feature_0 <= 17.50
|   |   |   |   |   |--- feature_0 <= 12.00
|   |   |   |   |   |   |--- value: [40.00]
|   |   |   |   |   |--- feature_0 > 12.00
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |--- feature_0 > 17.50
|   |   |   |   |   |--- value: [90.00]
|   |   |   |--- feature_2 > 13.50
|   |   |   |   |--- feature_0 <= 10.50
|   |   |   |   |   |--- feature_3 <= 13.50
|   |   |   |   |   |   |--- value: [56.00]
|   |   |   |   |   |--- feature_3 > 13.50
|   |   |   |   |   |   |--- value: [50.00]
|   |   |   |   |--- feature_0 > 10.50
|   |   |   |   |   |--- feature_2 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_2 > 17.50
|   |   |   |   |   |   |--- value: [120.00]
|   |--- feature_3 > 14.50
|   |   |--- feature_5 <= 10.50
|   |   |   |--- feature_0 <= 17.50
|   |   |   |   |--- feature_4 <= 16.00
|   |   |   |   |   |--- value: [112.00]
|   |   |   |   |--- feature_4 > 16.00
|   |   |   |   |   |--- value: [84.00]
|   |   |   |--- feature_0 > 17.50
|   |   |   |   |--- value: [49.00]
|   |   |--- feature_5 > 10.50
|   |   |   |--- feature_4 <= 11.50
|   |   |   |   |--- feature_3 <= 16.50
|   |   |   |   |   |--- feature_4 <= 10.50
|   |   |   |   |   |   |--- value: [6.00]
|   |   |   |   |   |--- feature_4 > 10.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |--- feature_3 > 16.50
|   |   |   |   |   |--- feature_4 <= 10.50
|   |   |   |   |   |   |--- value: [54.00]
|   |   |   |   |   |--- feature_4 > 10.50
|   |   |   |   |   |   |--- value: [20.00]
|   |   |   |--- feature_4 > 11.50
|   |   |   |   |--- feature_4 <= 17.50
|   |   |   |   |   |--- feature_5 <= 12.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_5 > 12.50
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |--- feature_4 > 17.50
|   |   |   |   |   |--- feature_3 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_3 > 17.50
|   |   |   |   |   |   |--- value: [50.00]
|--- feature_1 > 10.50
|   |--- feature_4 <= 17.50
|   |   |--- feature_2 <= 17.50
|   |   |   |--- feature_5 <= 10.50
|   |   |   |   |--- feature_2 <= 12.50
|   |   |   |   |   |--- feature_4 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_4 > 11.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |--- feature_2 > 12.50
|   |   |   |   |   |--- feature_2 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |   |   |--- feature_2 > 16.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |--- feature_5 > 10.50
|   |   |   |   |--- feature_5 <= 17.50
|   |   |   |   |   |--- feature_3 <= 10.50
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |   |--- feature_3 > 10.50
|   |   |   |   |   |   |--- truncated branch of depth 13
|   |   |   |   |--- feature_5 > 17.50
|   |   |   |   |   |--- feature_2 <= 15.50
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |   |--- feature_2 > 15.50
|   |   |   |   |   |   |--- truncated branch of depth 8
|   |   |--- feature_2 > 17.50
|   |   |   |--- feature_1 <= 15.50
|   |   |   |   |--- feature_4 <= 12.50
|   |   |   |   |   |--- feature_0 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_0 > 16.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_4 > 12.50
|   |   |   |   |   |--- feature_0 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |   |--- feature_0 > 17.50
|   |   |   |   |   |   |--- value: [8.00]
|   |   |   |--- feature_1 > 15.50
|   |   |   |   |--- feature_4 <= 12.50
|   |   |   |   |   |--- feature_3 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_3 > 11.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_4 > 12.50
|   |   |   |   |   |--- feature_1 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_1 > 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |--- feature_4 > 17.50
|   |   |--- feature_0 <= 14.50
|   |   |   |--- feature_3 <= 16.50
|   |   |   |   |--- feature_3 <= 13.50
|   |   |   |   |   |--- feature_1 <= 14.50
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |   |--- feature_1 > 14.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_3 > 13.50
|   |   |   |   |   |--- feature_1 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_1 > 11.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |--- feature_3 > 16.50
|   |   |   |   |--- feature_2 <= 17.00
|   |   |   |   |   |--- feature_3 <= 17.50
|   |   |   |   |   |   |--- value: [72.00]
|   |   |   |   |   |--- feature_3 > 17.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_2 > 17.00
|   |   |   |   |   |--- value: [117.00]
|   |   |--- feature_0 > 14.50
|   |   |   |--- feature_0 <= 15.50
|   |   |   |   |--- feature_1 <= 14.50
|   |   |   |   |   |--- feature_5 <= 16.00
|   |   |   |   |   |   |--- value: [9.00]
|   |   |   |   |   |--- feature_5 > 16.00
|   |   |   |   |   |   |--- value: [6.00]
|   |   |   |   |--- feature_1 > 14.50
|   |   |   |   |   |--- value: [28.00]
|   |   |   |--- feature_0 > 15.50
|   |   |   |   |--- feature_5 <= 17.50
|   |   |   |   |   |--- feature_2 <= 12.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |   |--- feature_2 > 12.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_5 > 17.50
|   |   |   |   |   |--- value: [112.00]
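The rules above label splits as feature_0 through feature_5, following the column order of train_X. export_text also accepts feature_names, so the same rules can be printed with the actual ability-score names; a minimal sketch:
# Print the rules with column names instead of feature_0, feature_1, ...
print(tree.export_text(full_tree, max_depth = 5, feature_names = list(train_X.columns)))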
Plot the top 5 levels for illustration using max_depth; if max_depth is omitted, the whole tree is plotted.
tree.plot_tree(full_tree, feature_names = train_X.columns, max_depth = 5)
[plot_tree draws the top 5 levels of the full tree as a matplotlib figure, rooted at the split DEX <= 10.5 (440 samples, mean HP 65.552); the long list of matplotlib Text objects it returns is omitted here.]
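In a live notebook the figure appears inline, but at the default size the node boxes are cramped. A sketch of a more legible rendering, assuming matplotlib is available (it is required by plot_tree anyway):
import matplotlib.pyplot as plt
plt.figure(figsize = (20, 10))  # enlarge the canvas so the node boxes stay readable
tree.plot_tree(full_tree, feature_names = train_X.columns, max_depth = 5)
plt.show()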
Export the tree to a Graphviz .dot file, which can then be converted to a picture file.
from sklearn.tree import export_graphviz
dot_data = export_graphviz(full_tree, out_file='full_tree.dot', feature_names = train_X.columns)
# Note: when out_file is given, export_graphviz writes the file and returns None,
# so dot_data is None here; the tree itself lives in full_tree.dot on disk.
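The conversion step itself is not shown above. One way, assuming the Graphviz dot executable is installed and on the PATH, is to call it from Python; a sketch:
import subprocess
# Render the exported .dot file to a PNG with Graphviz
subprocess.run(["dot", "-Tpng", "full_tree.dot", "-o", "full_tree.png"], check = True)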
Not very useful: the full tree is far too large to read.
small_tree = DecisionTreeRegressor(random_state = 666, max_depth = 3, min_samples_split = 25)
small_tree
DecisionTreeRegressor(max_depth=3, min_samples_split=25, random_state=666)
small_tree_fit = small_tree.fit(train_X, train_y)
Plot the tree
# For illustration:
# from sklearn import tree
This time, export the whole tree; with only three levels, max_depth is not needed.
text_representation_2 = tree.export_text(small_tree)
print(text_representation_2)
|--- feature_1 <= 10.50
|   |--- feature_3 <= 14.50
|   |   |--- value: [64.48]
|   |--- feature_3 > 14.50
|   |   |--- value: [40.09]
|--- feature_1 > 10.50
|   |--- feature_4 <= 17.50
|   |   |--- feature_2 <= 17.50
|   |   |   |--- value: [69.37]
|   |   |--- feature_2 > 17.50
|   |   |   |--- value: [56.81]
|   |--- feature_4 > 17.50
|   |   |--- feature_0 <= 14.50
|   |   |   |--- value: [61.81]
|   |   |--- feature_0 > 14.50
|   |   |   |--- value: [45.67]
Plot the whole small tree; it is shallow enough that max_depth is not needed.
tree.plot_tree(small_tree, feature_names = train_X.columns)
[plot_tree draws the entire small tree: the root split DEX <= 10.5 leads down to six leaves with mean HP between 40.091 and 69.372; the returned list of matplotlib Text objects is omitted here.]
Export the small tree to a .dot file as well, to be converted to a picture the same way.
# For illustration
# from sklearn.tree import export_graphviz
dot_data_2 = export_graphviz(small_tree, out_file='small_tree.dot', feature_names = train_X.columns)
Much better.
On the training set
train_y_pred_full = full_tree.predict(train_X)
train_y_pred_full
array([117., 120., 72., 117., 100., 90., 49., 60., 42., 50., 20., 32., 36., 80., 36., 30., 117., 99., 90., 12., 42., 28., 40., 24., 50., 98., 32., 90., 50., 88., 130., 35., 18., 70., 99., 66., 20., 150., 56., 8., 98., 54., 81., 35., 6., 81., 56., 90., 88., 70., 60., 104., 24., 81., 54., 60., 30., 80., 84., 98., 12., 140., 135., 56., 135., 30., 117., 99., 81., 105., 42., 48., 100., 110., 77., 84., 84., 104., 16., 64., 48., 16., 84., 18., 48., 20., 24., 54., 9., 99., 56., 140., 72., 20., 112., 8., 110., 120., 35., 63., 21., 99., 36., 72., 16., 77., 150., 50., 90., 78., 60., 81., 104., 45., 56., 7., 10., 60., 56., 96., 72., 28., 40., 72., 78., 18., 54., 110., 8., 16., 84., 130., 88., 90., 54., 100., 110., 72., 90., 81., 8., 72., 30., 140., 126., 105., 36., 18., 140., 30., 32., 18., 66., 63., 24., 78., 21., 16., 32., 9., 28., 130., 42., 70., 105., 56., 135., 63., 45., 72., 72., 6., 104., 64., 96., 90., 20., 84., 7., 90., 63., 42., 60., 72., 49., 9., 6., 90., 130., 90., 90., 42., 35., 9., 45., 40., 108., 21., 108., 30., 84., 112., 135., 112., 112., 18., 84., 50., 40., 9., 99., 81., 72., 72., 30., 60., 96., 27., 140., 60., 90., 72., 42., 72., 81., 80., 117., 32., 135., 8., 36., 63., 80., 16., 120., 72., 100., 110., 48., 42., 64., 130., 48., 90., 84., 54., 48., 54., 18., 80., 49., 84., 150., 78., 126., 63., 9., 16., 50., 120., 8., 32., 56., 135., 16., 77., 24., 60., 48., 18., 8., 70., 63., 54., 91., 80., 112., 70., 120., 120., 120., 8., 56., 12., 88., 28., 18., 81., 48., 91., 117., 42., 49., 140., 28., 120., 56., 110., 130., 72., 18., 77., 126., 32., 42., 36., 16., 9., 88., 54., 72., 30., 126., 88., 84., 24., 60., 117., 104., 120., 77., 105., 42., 110., 88., 56., 35., 42., 80., 30., 50., 48., 24., 21., 56., 72., 9., 63., 98., 60., 48., 16., 117., 30., 70., 104., 49., 21., 130., 56., 117., 78., 8., 36., 48., 91., 84., 24., 36., 72., 10., 18., 36., 80., 90., 112., 63., 32., 96., 72., 108., 80., 9., 18., 98., 18., 88., 20., 18., 30., 12., 54., 36., 42., 120., 70., 32., 7., 40., 63., 77., 28., 24., 7., 36., 48., 54., 10., 56., 42., 135., 98., 10., 54., 84., 54., 60., 117., 135., 35., 117., 72., 130., 63., 110., 21., 81., 48., 110., 54., 60., 49., 91., 72., 48., 10., 77., 72., 112., 45., 150., 88., 150., 135., 140., 32., 70., 80., 72., 91.])
Get the RMSE for the training set
mse_full_tree_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred_full)
mse_full_tree_train
0.0
import math
rmse_full_tree_train = math.sqrt(mse_full_tree_train)
rmse_full_tree_train
0.0
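A training RMSE of exactly zero is the signature of an unpruned tree: it keeps splitting until every training record is memorized. As an aside, the square-root step can be folded into the metric call; a sketch, noting that the exact spelling depends on the scikit-learn version:
# Older versions: pass squared = False to get the RMSE directly
rmse_full_tree_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred_full, squared = False)
# scikit-learn >= 1.4 prefers:
# rmse_full_tree_train = sklearn.metrics.root_mean_squared_error(train_y, train_y_pred_full)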
# If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
# Then load the library
# import dmba
# from dmba import regressionSummary
import dmba
from dmba import regressionSummary
regressionSummary(train_y, train_y_pred_full)
Regression statistics

Mean Error (ME) : 0.0000
Root Mean Squared Error (RMSE) : 0.0000
Mean Absolute Error (MAE) : 0.0000
Mean Percentage Error (MPE) : 0.0000
Mean Absolute Percentage Error (MAPE) : 0.0000
On the validation set
valid_y_pred_full = full_tree.predict(valid_X)
valid_y_pred_full
array([110., 9., 70., 104., 84., 36., 150., 35., 90., 20., 32., 130., 117., 96., 120., 21., 70., 35., 18., 54., 64., 27., 32., 72., 126., 88., 30., 20., 40., 9., 126., 112., 21., 54., 110., 9., 21., 70., 140., 110., 72., 36., 117., 105., 72., 18., 8., 54., 81., 40., 36., 135., 90., 112., 72., 91., 110., 63., 100., 27., 110., 7., 30., 90., 40., 16., 32., 104., 48., 90., 12., 140., 36., 135., 16., 18., 60., 96., 84., 54., 30., 42., 80., 28., 30., 16., 130., 8., 90., 24., 40., 117., 49., 140., 20., 42., 72., 90., 108., 42., 60., 80., 135., 48., 42., 84., 91., 72., 84., 36., 50., 50., 56., 126., 28., 42., 80., 36., 96., 36., 8., 72., 10., 72., 77., 10., 135., 54., 140., 16., 140., 40., 9., 63., 8., 54., 150., 63., 56., 77., 84., 140., 135., 9., 84., 36., 8., 91., 16., 42., 8., 110., 56., 36., 20., 18., 63., 117., 81., 98., 81., 84., 120., 36., 8., 27., 130., 12., 63., 9., 30., 84., 130., 16., 56., 90., 21., 63., 32., 40., 40., 140., 60., 84., 18., 100., 18., 126., 36., 150., 80., 77., 16., 56., 88., 110., 18., 8., 18., 96., 42., 30., 21., 110., 54., 60., 42., 135., 126., 130., 72., 117., 72., 21., 90., 70., 48., 117., 24., 60., 45., 91., 8., 6., 130., 81., 50., 66., 96., 130., 63., 63., 8., 78., 70., 63., 9., 72., 18., 80., 54., 72., 56., 40., 84., 60., 70., 49., 64., 30., 54., 28., 72., 104., 78., 18., 56., 54., 6., 77., 54., 120., 56., 54., 45., 56., 126., 63., 72., 80., 50., 63., 117., 63., 112., 90., 117., 36., 42., 90., 110., 135., 56., 60., 20., 56., 81., 84., 150., 9., 120., 42., 72., 21.])
Get the RMSE for the validation set
mse_full_tree_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_pred_full)
mse_full_tree_valid
2643.622448979592
import math
rmse_full_tree_valid = math.sqrt(mse_full_tree_valid)
rmse_full_tree_valid
51.41616913947977
# If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
# Then load the library
# import dmba
# from dmba import regressionSummary
regressionSummary(valid_y, valid_y_pred_full)
Regression statistics

Mean Error (ME) : 3.7041
Root Mean Squared Error (RMSE) : 51.4162
Mean Absolute Error (MAE) : 41.2415
Mean Percentage Error (MPE) : -47.2748
Mean Absolute Percentage Error (MAPE) : 99.1131
On the training set
train_y_pred = small_tree.predict(train_X)
train_y_pred
array([69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 40.09090909, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 56.80555556, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 56.80555556, 61.80769231, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 64.47619048, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 69.37151703, 61.80769231, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 56.80555556, 69.37151703, 40.09090909, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 64.47619048, 56.80555556, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 40.09090909, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 40.09090909, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 61.80769231, 40.09090909, 69.37151703, 64.47619048, 64.47619048, 56.80555556, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 56.80555556, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 40.09090909, 64.47619048, 69.37151703, 61.80769231, 61.80769231, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 40.09090909, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 
40.09090909, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 45.66666667, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 61.80769231, 69.37151703, 61.80769231, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 56.80555556, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 64.47619048, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 56.80555556, 61.80769231, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 45.66666667, 69.37151703, 61.80769231, 56.80555556, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703])
Get the RMSE for the training set
mse_small_tree_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred)
mse_small_tree_train
1320.9654960245605
import math
rmse_small_tree_train = math.sqrt(mse_small_tree_train)
rmse_small_tree_train
36.34508902210256
# If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
# Then load the library
# import dmba
# from dmba import regressionSummary
import dmba
from dmba import regressionSummary
regressionSummary(train_y, train_y_pred)
Regression statistics

Mean Error (ME) : 0.0000
Root Mean Squared Error (RMSE) : 36.3451
Mean Absolute Error (MAE) : 30.4192
Mean Percentage Error (MPE) : -75.0378
Mean Absolute Percentage Error (MAPE) : 103.4166
On the validation set
valid_y_pred = small_tree.predict(valid_X)
valid_y_pred
array([69.37151703, 64.47619048, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 61.80769231, 45.66666667, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 56.80555556, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 64.47619048, 61.80769231, 56.80555556, 69.37151703, 56.80555556, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 61.80769231, 56.80555556, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 64.47619048, 64.47619048, 56.80555556, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 64.47619048, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 40.09090909, 69.37151703, 61.80769231, 69.37151703, 45.66666667, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 64.47619048, 69.37151703, 45.66666667, 56.80555556, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 40.09090909, 61.80769231, 69.37151703, 69.37151703, 61.80769231, 56.80555556, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 45.66666667, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 61.80769231, 45.66666667, 69.37151703, 69.37151703, 61.80769231, 45.66666667, 69.37151703, 69.37151703, 56.80555556, 40.09090909, 64.47619048, 69.37151703, 40.09090909, 69.37151703, 45.66666667, 69.37151703, 64.47619048, 64.47619048, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 
69.37151703, 40.09090909, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 40.09090909, 69.37151703, 45.66666667, 64.47619048, 69.37151703, 56.80555556, 69.37151703])
Get the RMSE for the validation set
mse_small_tree_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_pred)
mse_small_tree_valid
1353.8427521618253
import math
rmse_small_tree_valid = math.sqrt(mse_small_tree_valid)
rmse_small_tree_valid
36.79460221502368
# If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
# Then load the library
# import dmba
# from dmba import regressionSummary
regressionSummary(valid_y, valid_y_pred)
Regression statistics

Mean Error (ME) : 4.2030
Root Mean Squared Error (RMSE) : 36.7946
Mean Absolute Error (MAE) : 30.2375
Mean Percentage Error (MPE) : -48.2303
Mean Absolute Percentage Error (MAPE) : 80.0992
The pruned tree generalizes far better: validation RMSE drops from 51.42 (full tree) to 36.79. For context, compare these errors against the spread of HP in the training set.
train_y.describe()
count    440.000000
mean      65.552273
std       37.217785
min        6.000000
25%       35.000000
50%       63.000000
75%       90.000000
max      150.000000
Name: HP, dtype: float64
new_dnd_df = pd.read_csv("new_records_dnd.csv")
new_dnd_df
| | STR | DEX | CON | INT | WIS | CHA |
---|---|---|---|---|---|---|
0 | 9 | 17 | 8 | 13 | 16 | 15 |
1 | 17 | 9 | 17 | 18 | 11 | 7 |
Using the small tree
new_records_dnd_small_pred = small_tree.predict(new_dnd_df)
new_records_dnd_small_pred
array([69.37151703, 40.09090909])
import pandas as pd
dnd_small_tree_prediction_df = pd.DataFrame(new_records_dnd_small_pred,
columns = ["Prediction"])
dnd_small_tree_prediction_df
| | Prediction |
---|---|
0 | 69.371517 |
1 | 40.090909 |
Merge with new data
new_dnd_df_with_prediction = pd.concat((new_dnd_df, dnd_small_tree_prediction_df), axis = 1)
new_dnd_df_with_prediction
# to export
# new_dnd_df_with_prediction.to_csv("whatever_name.csv")
| | STR | DEX | CON | INT | WIS | CHA | Prediction |
---|---|---|---|---|---|---|---|
0 | 9 | 17 | 8 | 13 | 16 | 15 | 69.371517 |
1 | 17 | 9 | 17 | 18 | 11 | 7 | 40.090909 |
Get the leaf (terminal node) number that each new record lands in; apply returns the node id from the tree structure, so the numbers need not be consecutive.
leaf_number_for_new = small_tree.apply(new_dnd_df)
leaf_number_for_new
array([6, 3], dtype=int64)
leaf_number_for_new_df = pd.DataFrame(leaf_number_for_new, columns = ["leaf_number"])
leaf_number_for_new_df
| | leaf_number |
---|---|
0 | 6 |
1 | 3 |
new_dnd_df_with_prediction_small_tree_leaf_number = pd.concat((new_dnd_df_with_prediction,
leaf_number_for_new_df),
axis = 1)
new_dnd_df_with_prediction_small_tree_leaf_number
| | STR | DEX | CON | INT | WIS | CHA | Prediction | leaf_number |
---|---|---|---|---|---|---|---|---|
0 | 9 | 17 | 8 | 13 | 16 | 15 | 69.371517 | 6 |
1 | 17 | 9 | 17 | 18 | 11 | 7 | 40.090909 | 3 |
Get the leaf assignment of every training record.
leaf_number = pd.DataFrame(small_tree.apply(train_X), columns = ["leaf_number"], index = train_y.index)
leaf_number
| | leaf_number |
---|---|
650 | 6 |
479 | 6 |
271 | 6 |
647 | 6 |
307 | 6 |
... | ... |
445 | 6 |
414 | 6 |
70 | 6 |
429 | 6 |
236 | 6 |
440 rows × 1 columns
Get the HP of each record and corresponding leaf assignment
leaf_df = pd.concat([leaf_number, train_y], axis = 1)
leaf_df
| | leaf_number | HP |
---|---|---|
650 | 6 | 117 |
479 | 6 | 120 |
271 | 6 | 72 |
647 | 6 | 117 |
307 | 6 | 100 |
... | ... | ... |
445 | 6 | 32 |
414 | 6 | 70 |
70 | 6 | 80 |
429 | 6 | 72 |
236 | 6 | 91 |
440 rows × 2 columns
Various descriptive stats of each leaf
leaf_max_df = leaf_df.groupby(by = "leaf_number").max()
leaf_max_df
leaf_number | HP |
---|---|
2 | 120 |
3 | 112 |
6 | 150 |
7 | 140 |
9 | 140 |
10 | 112 |
leaf_max_df = leaf_max_df.rename(columns = {"HP": "Max_HP"})
leaf_max_df
leaf_number | Max_HP |
---|---|
2 | 120 |
3 | 112 |
6 | 150 |
7 | 140 |
9 | 140 |
10 | 112 |
leaf_min_df = leaf_df.groupby(by = "leaf_number").min()
leaf_min_df
leaf_number | HP |
---|---|
2 | 9 |
3 | 6 |
6 | 7 |
7 | 7 |
9 | 6 |
10 | 6 |
leaf_min_df = leaf_min_df.rename(columns = {"HP": "Min_HP"})
leaf_min_df
leaf_number | Min_HP |
---|---|
2 | 9 |
3 | 6 |
6 | 7 |
7 | 7 |
9 | 6 |
10 | 6 |
leaf_std_df = leaf_df.groupby(by = "leaf_number").std()
leaf_std_df
leaf_number | HP |
---|---|
2 | 27.924217 |
3 | 29.395350 |
6 | 37.600194 |
7 | 35.875037 |
9 | 38.203685 |
10 | 31.209944 |
leaf_std_df = leaf_std_df.rename(columns = {"HP": "std_HP"})
leaf_std_df
leaf_number | std_HP |
---|---|
2 | 27.924217 |
3 | 29.395350 |
6 | 37.600194 |
7 | 35.875037 |
9 | 38.203685 |
10 | 31.209944 |
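The three groupby passes above can be collapsed into a single named aggregation; a sketch that builds all three columns in one go:
# Max, min, and standard deviation of HP per leaf in one pass
leaf_stats_df = leaf_df.groupby("leaf_number")["HP"].agg(Max_HP = "max", Min_HP = "min", std_HP = "std")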
Put them all together
new_dnd_df_with_prediction_small_tree_leaf_number_range = (
    new_dnd_df_with_prediction_small_tree_leaf_number
    .merge(leaf_max_df, how = "inner", on = "leaf_number")
    .merge(leaf_min_df, how = "inner", on = "leaf_number")
    .merge(leaf_std_df, how = "inner", on = "leaf_number")
)
new_dnd_df_with_prediction_small_tree_leaf_number_range
| | STR | DEX | CON | INT | WIS | CHA | Prediction | leaf_number | Max_HP | Min_HP | std_HP |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 9 | 17 | 8 | 13 | 16 | 15 | 69.371517 | 6 | 150 | 7 | 37.600194 |
1 | 17 | 9 | 17 | 18 | 11 | 7 | 40.090909 | 3 | 112 | 6 | 29.395350 |
from sklearn.ensemble import RandomForestRegressor
rf = RandomForestRegressor(max_depth = 10, random_state = 666)
rf.fit(train_X, train_y)
RandomForestRegressor(max_depth=10, random_state=666)
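Before predicting, it can be worth a quick look at which ability scores the forest leans on. A sketch using the fitted model's impurity-based importances:
# Impurity-based feature importances, largest first
importances = pd.Series(rf.feature_importances_, index = train_X.columns)
print(importances.sort_values(ascending = False))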
train_y_pred_rf = rf.predict(train_X)
train_y_pred_rf
array([ 88.93415507, 103.4440601 , 69.47537521, 102.25756247, 90.6015338 , 72.34565388, 55.08351732, 61.08039394, 56.41217774, 55.12831922, 44.96499255, 50.06883454, 49.50479846, 71.23535818, 62.30711203, 38.00342379, 96.4850522 , 88.06 , 74.44247935, 39.28533981, 65.46959216, 35.64007071, 45.80873105, 45.26427884, 49.58661111, 81.29242571, 49.98355509, 85.12770696, 56.05009987, 79.14899567, 82.05170197, 48.303423 , 43.96298786, 63.78496467, 82.41620807, 70.45749862, 47.52008668, 115.36753846, 47.96596143, 32.5633898 , 75.86841179, 59.54143651, 77.72521177, 51.08989052, 21.19936315, 77.51343015, 62.72026195, 78.9389704 , 72.71332005, 77.33110815, 56.42808742, 89.85577218, 46.64963591, 85.13207896, 54.96862511, 56.87054614, 43.91997655, 81.48412106, 72.74496673, 86.26800778, 41.57996127, 99.21362527, 108.10668571, 57.29779173, 111.28059477, 53.60430871, 82.44322777, 82.40620381, 77.08339146, 91.56549427, 63.02749817, 54.35300364, 82.2662034 , 97.93201389, 73.0482379 , 75.82030556, 77.8802475 , 72.55873069, 37.73159091, 71.44797238, 54.50111558, 28.51136153, 81.70698526, 41.82350859, 54.03108768, 45.3973026 , 45.78339763, 61.36807 , 30.11053968, 82.13027778, 61.15893505, 101.57288501, 67.48741823, 26.93542019, 88.02408402, 37.45897826, 85.53040787, 84.95445402, 47.64463051, 67.75539365, 37.25539184, 91.29050568, 48.96801172, 78.61864337, 33.39780771, 77.58149604, 118.14853662, 55.76287724, 68.37750732, 69.79108469, 63.83683632, 66.15682478, 87.99372067, 54.63183916, 60.71442133, 33.09455779, 30.96605128, 60.4653631 , 65.59826143, 81.64426887, 60.19963427, 45.11525176, 49.20249957, 74.16530098, 73.19454255, 39.59228974, 54.49690542, 95.66990974, 35.73416306, 39.08250008, 79.12853771, 112.28103497, 68.98634541, 80.03228299, 67.99279785, 82.48577546, 85.0130968 , 70.083967 , 74.94647463, 72.72633574, 43.42748338, 82.81276538, 45.62634505, 106.61943651, 102.58836559, 80.38613257, 57.37117338, 33.88748496, 108.94866082, 53.35486688, 56.18110066, 34.60665282, 78.43217183, 64.39763654, 45.69437607, 68.45122386, 52.6406046 , 32.45174725, 48.9233243 , 29.83394958, 35.42764936, 109.65492857, 51.11937884, 69.64900305, 89.00004645, 58.59010944, 107.6301434 , 63.54540565, 52.22194841, 71.08996857, 67.97638999, 27.7934586 , 80.31229295, 71.33814287, 77.0330639 , 69.61092666, 48.67007949, 85.91442843, 41.74883263, 75.18138508, 77.07730257, 53.18972691, 66.60794683, 65.06615015, 62.96954681, 32.4638915 , 25.73599206, 69.36473102, 80.72062339, 69.3157973 , 79.51272729, 50.00133222, 53.77316321, 52.01618708, 51.0263894 , 49.63979731, 79.75040143, 42.25055932, 87.69675137, 64.09286381, 72.61794547, 92.84779201, 98.12629072, 86.48315375, 90.06800595, 57.20494978, 81.81610457, 51.62287151, 50.15992208, 32.62498061, 92.94263059, 68.53155923, 66.82723705, 70.01569548, 51.41607791, 62.33188215, 71.4641899 , 43.85667954, 105.44633669, 57.67718836, 77.16578731, 74.30349661, 53.70944274, 71.63464022, 75.0198326 , 69.05771916, 88.68029785, 44.97268484, 107.37958222, 23.36269481, 49.00191044, 58.6598738 , 78.08707725, 35.00298126, 84.92576119, 76.44292787, 94.72199013, 94.11666398, 59.28519292, 60.45311085, 65.7338112 , 95.74016892, 58.27571061, 77.21380488, 67.20433333, 49.15215437, 54.62181655, 55.11792308, 34.05189765, 64.52346018, 56.13229827, 81.13058491, 113.66492836, 76.42701323, 88.65096755, 57.28039899, 38.63262642, 48.68805405, 59.04782569, 102.37085767, 31.38396037, 52.23307442, 53.99473035, 80.48299547, 52.97368403, 68.96615525, 45.42954444, 61.45521545, 46.90742883, 53.06130937, 47.32722966, 76.63714672, 
53.59161405, 63.17177048, 87.5244256 , 71.14728393, 89.9188027 , 68.04064219, 81.02650193, 108.56730092, 95.20197455, 40.77743978, 60.0310126 , 33.02766208, 75.63861591, 49.40175669, 33.68347467, 67.28065404, 52.59024541, 90.46681485, 96.96412321, 55.5304823 , 67.5809228 , 101.79697451, 43.97335836, 99.45552434, 69.62235006, 90.65557188, 109.08199052, 63.41604082, 38.49531318, 72.52384339, 90.26417945, 42.18824789, 48.05441438, 47.9575858 , 57.17206151, 30.2311224 , 71.52939374, 62.88909123, 68.43390846, 43.77445167, 95.76515043, 74.0386496 , 70.54859477, 55.06728316, 59.3325523 , 97.05169676, 80.23437332, 99.06046429, 66.32816194, 89.5258736 , 61.41168678, 98.17453733, 76.69305752, 60.18673817, 54.82025641, 57.32873945, 76.85930604, 45.96900252, 58.76426003, 55.51669719, 48.82278961, 43.78012896, 58.54019432, 72.18893748, 32.63400242, 61.5212612 , 72.75890845, 57.61661024, 65.39866234, 47.41540043, 77.30090306, 50.89748999, 61.50746927, 79.49715321, 41.20356734, 56.03800925, 94.15462125, 63.20338365, 95.28639722, 75.81232017, 32.09039495, 53.32924083, 59.63109694, 87.66349323, 71.51329672, 42.87189277, 48.26400166, 60.3417937 , 42.20664573, 43.00066797, 55.18067154, 76.27390707, 78.82137727, 85.451 , 69.72150089, 55.34246911, 80.58014985, 86.81124055, 83.98661409, 74.85431944, 42.61049311, 38.15734865, 85.9956654 , 43.26669312, 78.97732552, 36.26190476, 41.72199711, 53.07704693, 34.18176444, 53.47318892, 48.90472817, 58.57496796, 92.79668924, 78.4475137 , 53.60581926, 32.80670145, 49.42504221, 57.3471786 , 75.856811 , 35.04339654, 52.5227308 , 34.48859646, 59.79505366, 64.33654553, 54.4064929 , 24.49359101, 57.92275495, 57.89097697, 100.67209492, 88.89994674, 39.90751776, 58.73859727, 76.14315942, 62.97781913, 71.07756578, 97.56604108, 99.52729537, 44.67429365, 88.03345238, 68.59522164, 98.84517573, 60.62938004, 87.14872698, 47.58679546, 72.82187091, 65.94254085, 89.09710914, 61.96470938, 69.36748378, 60.7044839 , 81.60740097, 71.04032612, 54.72223377, 32.06326537, 67.85120587, 68.54966934, 90.59010516, 53.08508131, 114.62275726, 81.56617735, 116.21760762, 100.6281484 , 107.36454736, 54.93407498, 61.55746122, 74.83490502, 67.25198763, 75.9885209 ])
mse_rf_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred_rf)
mse_rf_train
356.848877962145
# import math
rmse_rf_train = math.sqrt(mse_rf_train)
rmse_rf_train
18.89044409118391
# If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
# import dmba
# from dmba import regressionSummary
regressionSummary(train_y, train_y_pred_rf)
Regression statistics

Mean Error (ME) : -0.3851
Root Mean Squared Error (RMSE) : 18.8904
Mean Absolute Error (MAE) : 15.6621
Mean Percentage Error (MPE) : -38.5278
Mean Absolute Percentage Error (MAPE) : 52.8393
valid_y_pred_rf = rf.predict(valid_X)
valid_y_pred_rf
array([85.99613623, 72.81197421, 57.14756602, 65.48415094, 64.23185644, 73.88297094, 75.63706658, 65.88536111, 77.38755932, 53.3611746 , 84.30540558, 81.38338597, 69.54557467, 76.4453801 , 77.30100974, 60.7457147 , 60.28698095, 54.54960764, 58.39220608, 79.51837662, 57.45235255, 44.49966493, 80.43081232, 80.9672802 , 79.12526316, 61.88250468, 68.48486885, 55.05741665, 54.63041239, 66.02297654, 74.05684091, 61.26967588, 63.94936452, 68.44916792, 79.89559722, 47.99857592, 77.75612193, 51.24598032, 91.96242819, 95.47827337, 72.99748216, 61.48166667, 74.24409524, 71.74566751, 65.74643356, 72.12616881, 59.34396605, 55.31680735, 75.01714982, 67.35775287, 62.03485455, 69.73391667, 60.42239718, 49.83764828, 78.65800622, 72.81000229, 84.1797256 , 73.57869858, 69.42034164, 54.17249117, 80.91616771, 31.0329006 , 58.56498167, 86.61381342, 53.14061718, 46.71116392, 51.25840051, 57.74422334, 70.76289968, 68.26478858, 53.12914791, 67.50639943, 58.15775315, 86.5268843 , 54.6770392 , 67.69249941, 59.79266884, 74.1011205 , 80.20981113, 47.86735296, 57.1792636 , 65.4681474 , 71.61567756, 64.21365135, 52.96507108, 49.00585282, 60.93468812, 66.42856805, 72.62375 , 61.39575291, 52.3028228 , 70.51631313, 61.68713413, 82.59035772, 62.97275458, 65.08313092, 62.98707668, 54.98747072, 74.79348167, 77.76127609, 53.28477675, 65.1869692 , 68.7451453 , 70.07741324, 62.38078571, 63.3365805 , 73.22574008, 64.41808607, 59.03357159, 56.65995647, 48.07595775, 66.59146202, 66.1300717 , 64.77757169, 58.36584383, 63.15126245, 53.69026718, 61.60511499, 68.70015011, 68.68720783, 61.64060036, 70.09213449, 73.00846138, 63.50531364, 54.47829147, 51.04458733, 75.76759166, 50.68129187, 85.91485798, 45.009602 , 74.03758542, 60.60588345, 66.72822863, 50.40984565, 50.91525236, 79.34774785, 74.27093714, 82.57503454, 84.23115898, 58.34441943, 73.50657518, 76.91086606, 94.43031653, 51.995 , 66.51659984, 59.95371429, 48.8428732 , 74.67547684, 54.12539596, 62.59740704, 38.82594481, 66.68127848, 59.69954936, 62.24238866, 51.24587194, 67.46654537, 50.08880019, 78.89875988, 70.45242092, 76.13153563, 47.54842012, 71.55815842, 77.32311666, 51.02301605, 45.95923124, 54.17618857, 72.52685606, 59.74554544, 61.51544039, 61.22026958, 74.06011706, 85.25158929, 66.50520371, 71.80522197, 46.78741758, 75.15755218, 64.5374446 , 56.70808703, 83.22767766, 59.33948239, 48.14300962, 80.07204362, 65.39221338, 61.34071397, 50.61837513, 76.18111707, 48.45899522, 82.61167027, 68.64237587, 69.85818255, 66.21509905, 82.97757268, 45.1973547 , 61.34981061, 69.38518179, 91.70802321, 70.23158458, 27.83719221, 59.99675521, 69.57862161, 70.99210766, 60.44286601, 65.59296299, 70.98130797, 65.47639993, 68.08674206, 68.73752076, 67.61220989, 74.31800031, 64.41745114, 67.19080289, 78.22315147, 76.38173478, 56.18366362, 52.71594479, 56.61547222, 73.43090887, 91.90809606, 59.87043638, 57.65063381, 51.46242433, 78.90791805, 60.54316234, 51.72708547, 75.8761541 , 71.22114286, 43.23641919, 65.66313809, 94.86817947, 67.21091462, 52.21235575, 72.71997829, 67.16296753, 60.14175456, 68.69688344, 49.47745978, 61.98980765, 61.37528073, 55.09736371, 66.38018356, 56.80601079, 64.55983748, 79.77329875, 49.94612653, 56.46429733, 70.81355007, 73.30237874, 69.29287879, 51.04525876, 57.27526319, 55.77392348, 70.01296176, 74.8949364 , 83.63937886, 62.96207777, 53.7513566 , 62.29764406, 69.50477559, 43.34733164, 80.23455717, 61.96857608, 69.37674323, 37.29 , 65.15223851, 72.95181574, 50.16673751, 75.97699246, 70.30796224, 64.95026118, 68.25145774, 61.41429168, 65.10816719, 64.88187943, 
68.41595238, 57.63969008, 65.51860532, 88.77396387, 70.2144709 , 60.30475757, 55.3897939 , 61.63742355, 94.91610892, 72.07392157, 60.4653631 , 55.745187 , 73.56503645, 42.8781292 , 69.60228571, 84.44706084, 48.06643892, 71.47878175, 57.75287869, 66.18710983, 83.19506918])
mse_rf_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_pred_rf)
mse_rf_valid
1426.6006551155326
# import math
rmse_rf_valid = math.sqrt(mse_rf_valid)
rmse_rf_valid
37.7703674209761
# If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
# import dmba
# from dmba import regressionSummary
regressionSummary(valid_y, valid_y_pred_rf)
Regression statistics

Mean Error (ME) : 3.4807
Root Mean Squared Error (RMSE) : 37.7704
Mean Absolute Error (MAE) : 31.1829
Mean Percentage Error (MPE) : -48.5668
Mean Absolute Percentage Error (MAPE) : 81.0687
new_records_dnd_rf_pred = rf.predict(new_dnd_df)
new_records_dnd_rf_pred
array([64.32568489, 70.02311429])
dnd_rf_prediction_df = pd.DataFrame(new_records_dnd_rf_pred,
columns = ["Prediction"])
dnd_rf_prediction_df
| | Prediction |
---|---|
0 | 64.325685 |
1 | 70.023114 |
Combine with the new data set
new_dnd_df_with_prediction_rf = pd.concat((new_dnd_df, dnd_rf_prediction_df), axis = 1)
new_dnd_df_with_prediction_rf
# to export
# new_dnd_df_with_prediction_rf.to_csv("whatever_name.csv")
| | STR | DEX | CON | INT | WIS | CHA | Prediction |
---|---|---|---|---|---|---|---|
0 | 9 | 17 | 8 | 13 | 16 | 15 | 64.325685 |
1 | 17 | 9 | 17 | 18 | 11 | 7 | 70.023114 |