import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeRegressor
dnd_df = pd.read_csv("super_heroes_dnd_v3a.csv")
dnd_df.head()
 | ID | Name | Gender | Race | Height | Publisher | Alignment | Weight | STR | DEX | CON | INT | WIS | CHA | Level | HP |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | A001 | A-Bomb | Male | Human | 203.0 | Marvel Comics | good | 441.0 | 18 | 11 | 17 | 12 | 13 | 11 | 1 | 7 |
1 | A002 | Abe Sapien | Male | Icthyo Sapien | 191.0 | Dark Horse Comics | good | 65.0 | 16 | 17 | 10 | 13 | 15 | 11 | 8 | 72 |
2 | A004 | Abomination | Male | Human / Radiation | 203.0 | Marvel Comics | bad | 441.0 | 13 | 14 | 13 | 10 | 18 | 15 | 15 | 135 |
3 | A009 | Agent 13 | Female | NaN | 173.0 | Marvel Comics | good | 61.0 | 15 | 18 | 16 | 16 | 17 | 10 | 14 | 140 |
4 | A015 | Alex Mercer | Male | Human | NaN | Wildstorm | bad | NaN | 14 | 17 | 13 | 12 | 10 | 11 | 9 | 72 |
dnd_df.dtypes
ID            object
Name          object
Gender        object
Race          object
Height       float64
Publisher     object
Alignment     object
Weight       float64
STR            int64
DEX            int64
CON            int64
INT            int64
WIS            int64
CHA            int64
Level          int64
HP             int64
dtype: object
It's a good idea to get a sense of the target variable, HP.
dnd_df["HP"].describe()
count    734.000000
mean      66.885559
std       36.653877
min        6.000000
25%       36.000000
50%       63.000000
75%       91.000000
max      150.000000
Name: HP, dtype: float64
pd.DataFrame(dnd_df.columns.values, columns = ["variables"])
 | variables |
---|---|
0 | ID |
1 | Name |
2 | Gender |
3 | Race |
4 | Height |
5 | Publisher |
6 | Alignment |
7 | Weight |
8 | STR |
9 | DEX |
10 | CON |
11 | INT |
12 | WIS |
13 | CHA |
14 | Level |
15 | HP |
dnd_df_2 = dnd_df.iloc[:, np.r_[8:14, 15]]
dnd_df_2
# Alternatively, use:
# dnd_df.iloc[:, list(range(8,14)) + [15]]
# Note that the end of the range (14) is exclusive
# Or just use:
# dnd_df.iloc[:, [8, 9, 10, 11, 12, 13, 15]]
# Or use a label range; note that .loc slicing is inclusive at both ends,
# so "STR":"HP" also picks up Level, which then has to be dropped:
# dnd_df.loc[:, "STR":"HP"].drop(columns = ["Level"])
# Or specify the variable names
# dnd_df.loc[:, ["STR", "DEX", "CON", "INT", "WIS", "CHA", "HP"]]
 | STR | DEX | CON | INT | WIS | CHA | HP |
---|---|---|---|---|---|---|---|
0 | 18 | 11 | 17 | 12 | 13 | 11 | 7 |
1 | 16 | 17 | 10 | 13 | 15 | 11 | 72 |
2 | 13 | 14 | 13 | 10 | 18 | 15 | 135 |
3 | 15 | 18 | 16 | 16 | 17 | 10 | 140 |
4 | 14 | 17 | 13 | 12 | 10 | 11 | 72 |
... | ... | ... | ... | ... | ... | ... | ... |
729 | 8 | 14 | 17 | 13 | 14 | 15 | 64 |
730 | 17 | 12 | 11 | 11 | 14 | 10 | 56 |
731 | 18 | 10 | 14 | 17 | 10 | 10 | 49 |
732 | 11 | 11 | 10 | 12 | 15 | 16 | 36 |
733 | 16 | 12 | 18 | 15 | 15 | 16 | 81 |
734 rows × 7 columns
import sklearn
from sklearn.model_selection import train_test_split
predictors = ["STR", "DEX", "CON", "INT", "WIS", "CHA"]
outcome = "HP"
X = dnd_df_2[predictors]
y = dnd_df_2[outcome]
train_X, valid_X, train_y, valid_y = train_test_split(X, y, test_size = 0.4, random_state = 666)
train_X.head()
 | STR | DEX | CON | INT | WIS | CHA |
---|---|---|---|---|---|---|
650 | 17 | 14 | 16 | 16 | 15 | 17 |
479 | 8 | 18 | 16 | 10 | 14 | 17 |
271 | 9 | 12 | 17 | 10 | 15 | 17 |
647 | 9 | 18 | 16 | 10 | 17 | 13 |
307 | 12 | 16 | 14 | 18 | 15 | 13 |
len(train_X)
440
train_y.head()
650    117
479    120
271     72
647    117
307    100
Name: HP, dtype: int64
len(train_y)
440
valid_X.head()
 | STR | DEX | CON | INT | WIS | CHA |
---|---|---|---|---|---|---|
389 | 10 | 16 | 15 | 13 | 11 | 10 |
131 | 18 | 10 | 12 | 10 | 16 | 18 |
657 | 10 | 11 | 12 | 11 | 18 | 14 |
421 | 16 | 13 | 11 | 16 | 13 | 11 |
160 | 12 | 16 | 17 | 18 | 11 | 15 |
len(valid_X)
294
valid_y.head()
389    45
131    42
657    63
421    64
160    54
Name: HP, dtype: int64
len(valid_y)
294
full_tree = DecisionTreeRegressor(random_state = 666)
full_tree
DecisionTreeRegressor(random_state=666)
full_tree_fit = full_tree.fit(train_X, train_y)
Plot the tree
from sklearn import tree
Export just the top levels for illustration using max_depth; the whole tree is exported if max_depth is omitted.
text_representation = tree.export_text(full_tree, max_depth = 5)
print(text_representation)
|--- feature_1 <= 10.50
|   |--- feature_3 <= 14.50
|   |   |--- feature_3 <= 10.50
|   |   |   |--- feature_5 <= 16.00
|   |   |   |   |--- feature_2 <= 13.50
|   |   |   |   |   |--- value: [18.00]
|   |   |   |   |--- feature_2 > 13.50
|   |   |   |   |   |--- feature_0 <= 11.50
|   |   |   |   |   |   |--- value: [48.00]
|   |   |   |   |   |--- feature_0 > 11.50
|   |   |   |   |   |   |--- value: [50.00]
|   |   |   |--- feature_5 > 16.00
|   |   |   |   |--- value: [9.00]
|   |   |--- feature_3 > 10.50
|   |   |   |--- feature_2 <= 13.50
|   |   |   |   |--- feature_0 <= 17.50
|   |   |   |   |   |--- feature_0 <= 12.00
|   |   |   |   |   |   |--- value: [40.00]
|   |   |   |   |   |--- feature_0 > 12.00
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |--- feature_0 > 17.50
|   |   |   |   |   |--- value: [90.00]
|   |   |   |--- feature_2 > 13.50
|   |   |   |   |--- feature_0 <= 10.50
|   |   |   |   |   |--- feature_3 <= 13.50
|   |   |   |   |   |   |--- value: [56.00]
|   |   |   |   |   |--- feature_3 > 13.50
|   |   |   |   |   |   |--- value: [50.00]
|   |   |   |   |--- feature_0 > 10.50
|   |   |   |   |   |--- feature_2 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_2 > 17.50
|   |   |   |   |   |   |--- value: [120.00]
|   |--- feature_3 > 14.50
|   |   |--- feature_5 <= 10.50
|   |   |   |--- feature_0 <= 17.50
|   |   |   |   |--- feature_4 <= 16.00
|   |   |   |   |   |--- value: [112.00]
|   |   |   |   |--- feature_4 > 16.00
|   |   |   |   |   |--- value: [84.00]
|   |   |   |--- feature_0 > 17.50
|   |   |   |   |--- value: [49.00]
|   |   |--- feature_5 > 10.50
|   |   |   |--- feature_4 <= 11.50
|   |   |   |   |--- feature_3 <= 16.50
|   |   |   |   |   |--- feature_4 <= 10.50
|   |   |   |   |   |   |--- value: [6.00]
|   |   |   |   |   |--- feature_4 > 10.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |--- feature_3 > 16.50
|   |   |   |   |   |--- feature_4 <= 10.50
|   |   |   |   |   |   |--- value: [54.00]
|   |   |   |   |   |--- feature_4 > 10.50
|   |   |   |   |   |   |--- value: [20.00]
|   |   |   |--- feature_4 > 11.50
|   |   |   |   |--- feature_4 <= 17.50
|   |   |   |   |   |--- feature_5 <= 12.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_5 > 12.50
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |--- feature_4 > 17.50
|   |   |   |   |   |--- feature_3 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_3 > 17.50
|   |   |   |   |   |   |--- value: [50.00]
|--- feature_1 > 10.50
|   |--- feature_4 <= 17.50
|   |   |--- feature_2 <= 17.50
|   |   |   |--- feature_5 <= 10.50
|   |   |   |   |--- feature_2 <= 12.50
|   |   |   |   |   |--- feature_4 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_4 > 11.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |--- feature_2 > 12.50
|   |   |   |   |   |--- feature_2 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 7
|   |   |   |   |   |--- feature_2 > 16.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |--- feature_5 > 10.50
|   |   |   |   |--- feature_5 <= 17.50
|   |   |   |   |   |--- feature_3 <= 10.50
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |   |--- feature_3 > 10.50
|   |   |   |   |   |   |--- truncated branch of depth 13
|   |   |   |   |--- feature_5 > 17.50
|   |   |   |   |   |--- feature_2 <= 15.50
|   |   |   |   |   |   |--- truncated branch of depth 10
|   |   |   |   |   |--- feature_2 > 15.50
|   |   |   |   |   |   |--- truncated branch of depth 8
|   |   |--- feature_2 > 17.50
|   |   |   |--- feature_1 <= 15.50
|   |   |   |   |--- feature_4 <= 12.50
|   |   |   |   |   |--- feature_0 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_0 > 16.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_4 > 12.50
|   |   |   |   |   |--- feature_0 <= 17.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |   |   |--- feature_0 > 17.50
|   |   |   |   |   |   |--- value: [8.00]
|   |   |   |--- feature_1 > 15.50
|   |   |   |   |--- feature_4 <= 12.50
|   |   |   |   |   |--- feature_3 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_3 > 11.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_4 > 12.50
|   |   |   |   |   |--- feature_1 <= 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |   |   |   |   |--- feature_1 > 16.50
|   |   |   |   |   |   |--- truncated branch of depth 4
|   |--- feature_4 > 17.50
|   |   |--- feature_0 <= 14.50
|   |   |   |--- feature_3 <= 16.50
|   |   |   |   |--- feature_3 <= 13.50
|   |   |   |   |   |--- feature_1 <= 14.50
|   |   |   |   |   |   |--- truncated branch of depth 6
|   |   |   |   |   |--- feature_1 > 14.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_3 > 13.50
|   |   |   |   |   |--- feature_1 <= 11.50
|   |   |   |   |   |   |--- truncated branch of depth 2
|   |   |   |   |   |--- feature_1 > 11.50
|   |   |   |   |   |   |--- truncated branch of depth 5
|   |   |   |--- feature_3 > 16.50
|   |   |   |   |--- feature_2 <= 17.00
|   |   |   |   |   |--- feature_3 <= 17.50
|   |   |   |   |   |   |--- value: [72.00]
|   |   |   |   |   |--- feature_3 > 17.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_2 > 17.00
|   |   |   |   |   |--- value: [117.00]
|   |   |--- feature_0 > 14.50
|   |   |   |--- feature_0 <= 15.50
|   |   |   |   |--- feature_1 <= 14.50
|   |   |   |   |   |--- feature_5 <= 16.00
|   |   |   |   |   |   |--- value: [9.00]
|   |   |   |   |   |--- feature_5 > 16.00
|   |   |   |   |   |   |--- value: [6.00]
|   |   |   |   |--- feature_1 > 14.50
|   |   |   |   |   |--- value: [28.00]
|   |   |   |--- feature_0 > 15.50
|   |   |   |   |--- feature_5 <= 17.50
|   |   |   |   |   |--- feature_2 <= 12.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |   |--- feature_2 > 12.50
|   |   |   |   |   |   |--- truncated branch of depth 3
|   |   |   |   |--- feature_5 > 17.50
|   |   |   |   |   |--- value: [112.00]
Plot the top five levels for illustration using max_depth; the whole tree is plotted if max_depth is omitted.
tree.plot_tree(full_tree, feature_names = train_X.columns, max_depth = 5)
[Tree plot: matplotlib rendering of the top five levels of full_tree. The root splits on DEX <= 10.5 (squared_error = 1382.015, samples = 440, value = 65.552); the cell also echoes plot_tree's return value, a long list of matplotlib Text annotations, omitted here.]
Export tree and convert to a picture file.
from sklearn.tree import export_graphviz
dot_data = export_graphviz(full_tree, out_file='full_tree.dot', feature_names = train_X.columns)
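Note that export_graphviz returns None when out_file is given, so dot_data above is None; the tree lives in full_tree.dot. Rendering that .dot source as an image is a separate Graphviz step. A minimal sketch, assuming the graphviz Python package and the Graphviz binaries are installed (on the command line, dot -Tpng full_tree.dot -o full_tree.png does the same):
# Render the exported .dot file as a PNG (assumes Graphviz is installed).
import graphviz
with open("full_tree.dot") as f:
    graph = graphviz.Source(f.read())
graph.render("full_tree", format = "png", cleanup = True)  # writes full_tree.png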
Not very useful: the full-grown tree is far too large to read.
small_tree = DecisionTreeRegressor(random_state = 666, max_depth = 3, min_samples_split = 25)
small_tree
DecisionTreeRegressor(max_depth=3, min_samples_split=25, random_state=666)
small_tree_fit = small_tree.fit(train_X, train_y)
Plot the tree
# For illustration:
# from sklearn import tree
Export the text representation again; max_depth is omitted this time, so the whole (three-level) tree is shown.
text_representation_2 = tree.export_text(small_tree)
print(text_representation_2)
|--- feature_1 <= 10.50
|   |--- feature_3 <= 14.50
|   |   |--- value: [64.48]
|   |--- feature_3 > 14.50
|   |   |--- value: [40.09]
|--- feature_1 > 10.50
|   |--- feature_4 <= 17.50
|   |   |--- feature_2 <= 17.50
|   |   |   |--- value: [69.37]
|   |   |--- feature_2 > 17.50
|   |   |   |--- value: [56.81]
|   |--- feature_4 > 17.50
|   |   |--- feature_0 <= 14.50
|   |   |   |--- value: [61.81]
|   |   |--- feature_0 > 14.50
|   |   |   |--- value: [45.67]
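The feature_N labels appear because export_text was not given the column names; passing feature_names makes the rules readable. A small sketch using the fitted small_tree from above:
# Same rules, printed with STR/DEX/... instead of feature_0/feature_1/...
text_with_names = tree.export_text(small_tree, feature_names = list(train_X.columns))
print(text_with_names)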
Plot the whole tree; with only three levels, max_depth is not needed.
tree.plot_tree(small_tree, feature_names = train_X.columns)
[Tree plot: matplotlib rendering of the whole small_tree. The root splits on DEX <= 10.5 (squared_error = 1382.015, samples = 440, value = 65.552); the echoed list of matplotlib Text annotations is omitted here.]
Export tree and convert to a picture file.
# For illustration
# from sklearn.tree import export_graphviz
dot_data_2 = export_graphviz(small_tree, out_file='small_tree.dot', feature_names = train_X.columns)
Much better.
train_y_pred_full = full_tree.predict(train_X)
train_y_pred_full
array([117., 120., 72., 117., 100., 90., 49., 60., 42., 50., 20., 32., 36., 80., 36., 30., 117., 99., 90., 12., 42., 28., 40., 24., 50., 98., 32., 90., 50., 88., 130., 35., 18., 70., 99., 66., 20., 150., 56., 8., 98., 54., 81., 35., 6., 81., 56., 90., 88., 70., 60., 104., 24., 81., 54., 60., 30., 80., 84., 98., 12., 140., 135., 56., 135., 30., 117., 99., 81., 105., 42., 48., 100., 110., 77., 84., 84., 104., 16., 64., 48., 16., 84., 18., 48., 20., 24., 54., 9., 99., 56., 140., 72., 20., 112., 8., 110., 120., 35., 63., 21., 99., 36., 72., 16., 77., 150., 50., 90., 78., 60., 81., 104., 45., 56., 7., 10., 60., 56., 96., 72., 28., 40., 72., 78., 18., 54., 110., 8., 16., 84., 130., 88., 90., 54., 100., 110., 72., 90., 81., 8., 72., 30., 140., 126., 105., 36., 18., 140., 30., 32., 18., 66., 63., 24., 78., 21., 16., 32., 9., 28., 130., 42., 70., 105., 56., 135., 63., 45., 72., 72., 6., 104., 64., 96., 90., 20., 84., 7., 90., 63., 42., 60., 72., 49., 9., 6., 90., 130., 90., 90., 42., 35., 9., 45., 40., 108., 21., 108., 30., 84., 112., 135., 112., 112., 18., 84., 50., 40., 9., 99., 81., 72., 72., 30., 60., 96., 27., 140., 60., 90., 72., 42., 72., 81., 80., 117., 32., 135., 8., 36., 63., 80., 16., 120., 72., 100., 110., 48., 42., 64., 130., 48., 90., 84., 54., 48., 54., 18., 80., 49., 84., 150., 78., 126., 63., 9., 16., 50., 120., 8., 32., 56., 135., 16., 77., 24., 60., 48., 18., 8., 70., 63., 54., 91., 80., 112., 70., 120., 120., 120., 8., 56., 12., 88., 28., 18., 81., 48., 91., 117., 42., 49., 140., 28., 120., 56., 110., 130., 72., 18., 77., 126., 32., 42., 36., 16., 9., 88., 54., 72., 30., 126., 88., 84., 24., 60., 117., 104., 120., 77., 105., 42., 110., 88., 56., 35., 42., 80., 30., 50., 48., 24., 21., 56., 72., 9., 63., 98., 60., 48., 16., 117., 30., 70., 104., 49., 21., 130., 56., 117., 78., 8., 36., 48., 91., 84., 24., 36., 72., 10., 18., 36., 80., 90., 112., 63., 32., 96., 72., 108., 80., 9., 18., 98., 18., 88., 20., 18., 30., 12., 54., 36., 42., 120., 70., 32., 7., 40., 63., 77., 28., 24., 7., 36., 48., 54., 10., 56., 42., 135., 98., 10., 54., 84., 54., 60., 117., 135., 35., 117., 72., 130., 63., 110., 21., 81., 48., 110., 54., 60., 49., 91., 72., 48., 10., 77., 72., 112., 45., 150., 88., 150., 135., 140., 32., 70., 80., 72., 91.])
Get the RMSE for the training set
mse_full_tree_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred_full)
mse_full_tree_train
0.0
import math
rmse_full_tree_train = math.sqrt(mse_full_tree_train)
rmse_full_tree_train
0.0
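A training RMSE of exactly 0 confirms that the unpruned tree has memorized the training set, a textbook sign of overfitting; the validation results below show how well it actually generalizes. As an aside, scikit-learn can return the RMSE directly, a one-call alternative to mean_squared_error plus math.sqrt:
# squared = False is available since scikit-learn 0.22; very recent releases
# also offer sklearn.metrics.root_mean_squared_error.
rmse_full_tree_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred_full, squared = False)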
If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
import dmba
from dmba import regressionSummary
regressionSummary(train_y, train_y_pred_full)
Regression statistics

                      Mean Error (ME) : 0.0000
       Root Mean Squared Error (RMSE) : 0.0000
            Mean Absolute Error (MAE) : 0.0000
          Mean Percentage Error (MPE) : 0.0000
Mean Absolute Percentage Error (MAPE) : 0.0000
On the validation set
valid_y_pred_full = full_tree.predict(valid_X)
valid_y_pred_full
array([110., 9., 70., 104., 84., 36., 150., 35., 90., 20., 32., 130., 117., 96., 120., 21., 70., 35., 18., 54., 64., 27., 32., 72., 126., 88., 30., 20., 40., 9., 126., 112., 21., 54., 110., 9., 21., 70., 140., 110., 72., 36., 117., 105., 72., 18., 8., 54., 81., 40., 36., 135., 90., 112., 72., 91., 110., 63., 100., 27., 110., 7., 30., 90., 40., 16., 32., 104., 48., 90., 12., 140., 36., 135., 16., 18., 60., 96., 84., 54., 30., 42., 80., 28., 30., 16., 130., 8., 90., 24., 40., 117., 49., 140., 20., 42., 72., 90., 108., 42., 60., 80., 135., 48., 42., 84., 91., 72., 84., 36., 50., 50., 56., 126., 28., 42., 80., 36., 96., 36., 8., 72., 10., 72., 77., 10., 135., 54., 140., 16., 140., 40., 9., 63., 8., 54., 150., 63., 56., 77., 84., 140., 135., 9., 84., 36., 8., 91., 16., 42., 8., 110., 56., 36., 20., 18., 63., 117., 81., 98., 81., 84., 120., 36., 8., 27., 130., 12., 63., 9., 30., 84., 130., 16., 56., 90., 21., 63., 32., 40., 40., 140., 60., 84., 18., 100., 18., 126., 36., 150., 80., 77., 16., 56., 88., 110., 18., 8., 18., 96., 42., 30., 21., 110., 54., 60., 42., 135., 126., 130., 72., 117., 72., 21., 90., 70., 48., 117., 24., 60., 45., 91., 8., 6., 130., 81., 50., 66., 96., 130., 63., 63., 8., 78., 70., 63., 9., 72., 18., 80., 54., 72., 56., 40., 84., 60., 70., 49., 64., 30., 54., 28., 72., 104., 78., 18., 56., 54., 6., 77., 54., 120., 56., 54., 45., 56., 126., 63., 72., 80., 50., 63., 117., 63., 112., 90., 117., 36., 42., 90., 110., 135., 56., 60., 20., 56., 81., 84., 150., 9., 120., 42., 72., 21.])
Get the RMSE for the validation set
mse_full_tree_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_pred_full)
mse_full_tree_valid
2643.622448979592
import math
rmse_full_tree_valid = math.sqrt(mse_full_tree_valid)
rmse_full_tree_valid
51.41616913947977
If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
# import dmba
# from dmba import regressionSummary
regressionSummary(valid_y, valid_y_pred_full)
Regression statistics

                      Mean Error (ME) : 3.7041
       Root Mean Squared Error (RMSE) : 51.4162
            Mean Absolute Error (MAE) : 41.2415
          Mean Percentage Error (MPE) : -47.2748
Mean Absolute Percentage Error (MAPE) : 99.1131
On the training set
train_y_pred = small_tree.predict(train_X)
train_y_pred
array([69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 40.09090909, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 56.80555556, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 56.80555556, 61.80769231, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 64.47619048, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 69.37151703, 61.80769231, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 56.80555556, 69.37151703, 40.09090909, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 64.47619048, 56.80555556, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 40.09090909, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 40.09090909, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 61.80769231, 40.09090909, 69.37151703, 64.47619048, 64.47619048, 56.80555556, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 56.80555556, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 40.09090909, 64.47619048, 69.37151703, 61.80769231, 61.80769231, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 40.09090909, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 
40.09090909, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 45.66666667, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 61.80769231, 69.37151703, 61.80769231, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 56.80555556, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 64.47619048, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 56.80555556, 61.80769231, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 45.66666667, 69.37151703, 61.80769231, 56.80555556, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703])
Get the RMSE for the training set
mse_small_tree_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred)
mse_small_tree_train
1320.9654960245605
import math
rmse_small_tree_train = math.sqrt(mse_small_tree_train)
rmse_small_tree_train
36.34508902210256
If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
import dmba
from dmba import regressionSummary
regressionSummary(train_y, train_y_pred)
Regression statistics

                      Mean Error (ME) : 0.0000
       Root Mean Squared Error (RMSE) : 36.3451
            Mean Absolute Error (MAE) : 30.4192
          Mean Percentage Error (MPE) : -75.0378
Mean Absolute Percentage Error (MAPE) : 103.4166
On the validation set
valid_y_pred = small_tree.predict(valid_X)
valid_y_pred
array([69.37151703, 64.47619048, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 61.80769231, 45.66666667, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 56.80555556, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 64.47619048, 61.80769231, 56.80555556, 69.37151703, 56.80555556, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 61.80769231, 69.37151703, 56.80555556, 61.80769231, 56.80555556, 69.37151703, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 64.47619048, 64.47619048, 56.80555556, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 64.47619048, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 40.09090909, 69.37151703, 61.80769231, 69.37151703, 45.66666667, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 64.47619048, 64.47619048, 69.37151703, 45.66666667, 56.80555556, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 61.80769231, 40.09090909, 61.80769231, 69.37151703, 69.37151703, 61.80769231, 56.80555556, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 45.66666667, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 69.37151703, 40.09090909, 69.37151703, 61.80769231, 69.37151703, 69.37151703, 64.47619048, 69.37151703, 69.37151703, 61.80769231, 45.66666667, 69.37151703, 69.37151703, 61.80769231, 45.66666667, 69.37151703, 69.37151703, 56.80555556, 40.09090909, 64.47619048, 69.37151703, 40.09090909, 69.37151703, 45.66666667, 69.37151703, 64.47619048, 64.47619048, 61.80769231, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 56.80555556, 69.37151703, 
69.37151703, 40.09090909, 40.09090909, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 69.37151703, 69.37151703, 69.37151703, 45.66666667, 40.09090909, 69.37151703, 45.66666667, 64.47619048, 69.37151703, 56.80555556, 69.37151703])
Get the RMSE for the validation set
mse_small_tree_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_pred)
mse_small_tree_valid
1353.8427521618253
import math
rmse_small_tree_valid = math.sqrt(mse_small_tree_valid)
rmse_small_tree_valid
36.79460221502368
If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
# import dmba
# from dmba import regressionSummary
regressionSummary(valid_y, valid_y_pred)
Regression statistics

                      Mean Error (ME) : 4.2030
       Root Mean Squared Error (RMSE) : 36.7946
            Mean Absolute Error (MAE) : 30.2375
          Mean Percentage Error (MPE) : -48.2303
Mean Absolute Percentage Error (MAPE) : 80.0992
from sklearn.model_selection import GridSearchCV
param_grid = {"max_depth": [2, 3, 4, 5],
"min_samples_split": [10, 20, 30],
"min_impurity_decrease": [0, 0.001, 0.002]}
grid_search = GridSearchCV(DecisionTreeRegressor(random_state = 666), param_grid, cv = 10)
grid_search.fit(train_X, train_y)
GridSearchCV(cv=10, estimator=DecisionTreeRegressor(random_state=666), param_grid={'max_depth': [2, 3, 4, 5], 'min_impurity_decrease': [0, 0.001, 0.002], 'min_samples_split': [10, 20, 30]})
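With 4 × 3 × 3 = 36 parameter combinations evaluated over 10 folds, this grid search fits 360 trees, then (by default) refits the best combination on the whole training set.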
print("Initial parameters:", grid_search.best_params_)
Initial parameters: {'max_depth': 2, 'min_impurity_decrease': 0, 'min_samples_split': 10}
grid_search.best_score_
-0.0933578084568057
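GridSearchCV scores a regressor with its default metric, R². A best_score_ of about -0.09 means that, in cross-validation, even the best tree in this grid does slightly worse than simply predicting the mean HP.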
grid_search.best_params_
{'max_depth': 2, 'min_impurity_decrease': 0, 'min_samples_split': 10}
best_tree = grid_search.best_estimator_
best_tree
DecisionTreeRegressor(max_depth=2, min_impurity_decrease=0, min_samples_split=10, random_state=666)
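To tune on RMSE instead of R², pass a scoring string. A sketch reusing the param_grid above; scikit-learn negates error metrics so that higher is still better:
# Hypothetical variant: rank the grid by cross-validated RMSE.
grid_search_rmse = GridSearchCV(DecisionTreeRegressor(random_state = 666), param_grid, cv = 10,
                                scoring = "neg_root_mean_squared_error")
grid_search_rmse.fit(train_X, train_y)
print(-grid_search_rmse.best_score_)  # cross-validated RMSE of the best tree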
dot_data_3 = export_graphviz(best_tree, out_file='best_tree.dot', feature_names = train_X.columns)
On the training set
train_y_best_pred = best_tree.predict(train_X)
train_y_best_pred
array([68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 40.09090909, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 64.47619048, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 40.09090909, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 56.71052632, 40.09090909, 68.11142061, 64.47619048, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 40.09090909, 64.47619048, 68.11142061, 56.71052632, 56.71052632, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 
40.09090909, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 56.71052632, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 56.71052632, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 64.47619048, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 56.71052632, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061])
Get the RMSE for the training set
mse_best_tree_train = sklearn.metrics.mean_squared_error(train_y, train_y_best_pred)
mse_best_tree_train
1337.4509437318577
import math
rmse_best_tree_train = math.sqrt(mse_best_tree_train)
rmse_best_tree_train
36.57117640617892
# If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
# import dmba
# from dmba import regressionSummary
regressionSummary(train_y, train_y_best_pred)
Regression statistics

                      Mean Error (ME) : -0.0000
       Root Mean Squared Error (RMSE) : 36.5712
            Mean Absolute Error (MAE) : 30.6315
          Mean Percentage Error (MPE) : -76.2539
Mean Absolute Percentage Error (MAPE) : 104.6671
On the validation set
valid_y_best_pred = best_tree.predict(valid_X)
valid_y_best_pred
array([68.11142061, 64.47619048, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 56.71052632, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 64.47619048, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 56.71052632, 68.11142061, 56.71052632, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 64.47619048, 64.47619048, 68.11142061, 56.71052632, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 40.09090909, 56.71052632, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 56.71052632, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 64.47619048, 68.11142061, 68.11142061, 56.71052632, 56.71052632, 68.11142061, 68.11142061, 56.71052632, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 40.09090909, 64.47619048, 68.11142061, 40.09090909, 68.11142061, 56.71052632, 68.11142061, 64.47619048, 64.47619048, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 
68.11142061, 40.09090909, 40.09090909, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 68.11142061, 68.11142061, 68.11142061, 56.71052632, 40.09090909, 68.11142061, 56.71052632, 64.47619048, 68.11142061, 68.11142061, 68.11142061])
Get the RMSE for the validation set
mse_best_tree_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_best_pred)
mse_best_tree_valid
1320.469997098233
import math
rmse_best_tree_valid = math.sqrt(mse_best_tree_valid)
rmse_best_tree_valid
36.33827179570092
# If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
# import dmba
# from dmba import regressionSummary
regressionSummary(valid_y, valid_y_best_pred)
Regression statistics

                      Mean Error (ME) : 3.8074
       Root Mean Squared Error (RMSE) : 36.3383
            Mean Absolute Error (MAE) : 29.9632
          Mean Percentage Error (MPE) : -49.3930
Mean Absolute Percentage Error (MAPE) : 80.5791
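Putting the validation RMSEs side by side: 51.42 for the full tree, 36.79 for the hand-pruned small tree, and 36.34 for the grid-search tree. Pruning helps enormously, and the tuned tree edges out the hand-picked settings.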
new_dnd_df = pd.read_csv("new_records_dnd.csv")
new_dnd_df
 | STR | DEX | CON | INT | WIS | CHA |
---|---|---|---|---|---|---|
0 | 9 | 17 | 8 | 13 | 16 | 15 |
1 | 17 | 9 | 17 | 18 | 11 | 7 |
new_records_dnd_small_pred = small_tree.predict(new_dnd_df)
new_records_dnd_small_pred
array([69.37151703, 40.09090909])
import pandas as pd
dnd_small_tree_prediction_df = pd.DataFrame(new_records_dnd_small_pred,
columns = ["Prediction"])
dnd_small_tree_prediction_df
 | Prediction |
---|---|
0 | 69.371517 |
1 | 40.090909 |
Merge with new data
new_dnd_df_with_prediction_small_tree = pd.concat((new_dnd_df, dnd_small_tree_prediction_df), axis = 1)
new_dnd_df_with_prediction_small_tree
# to export
# new_dnd_df_with_prediction_small_tree.to_csv("whatever_name.csv")
 | STR | DEX | CON | INT | WIS | CHA | Prediction |
---|---|---|---|---|---|---|---|
0 | 9 | 17 | 8 | 13 | 16 | 15 | 69.371517 |
1 | 17 | 9 | 17 | 18 | 11 | 7 | 40.090909 |
new_records_dnd_best_pred = best_tree.predict(new_dnd_df)
new_records_dnd_best_pred
array([68.11142061, 40.09090909])
dnd_best_tree_prediction_df = pd.DataFrame(new_records_dnd_best_pred,
columns = ["Prediction"])
dnd_best_tree_prediction_df
# to export
# dnd_best_tree_prediction_df.to_csv("whatever_name.csv")
 | Prediction |
---|---|
0 | 68.111421 |
1 | 40.090909 |
Merge with new data
new_dnd_df_with_prediction_best_tree = pd.concat((new_dnd_df, dnd_best_tree_prediction_df), axis = 1)
new_dnd_df_with_prediction_best_tree
 | STR | DEX | CON | INT | WIS | CHA | Prediction |
---|---|---|---|---|---|---|---|
0 | 9 | 17 | 8 | 13 | 16 | 15 | 68.111421 |
1 | 17 | 9 | 17 | 18 | 11 | 7 | 40.090909 |
leaf_number_for_new = best_tree.apply(new_dnd_df)
leaf_number_for_new
array([5, 3], dtype=int64)
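apply() returns the index of the leaf each record falls into, using the tree's internal node numbering. Internal nodes consume ids too (here 0, 1, and 4), which is why the leaf numbers 2, 3, 5, and 6 are not consecutive.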
leaf_number_for_new_df = pd.DataFrame(leaf_number_for_new, columns = ["leaf_number"])
leaf_number_for_new_df
 | leaf_number |
---|---|
0 | 5 |
1 | 3 |
new_dnd_df_with_prediction_best_tree_leaf_number = pd.concat((new_dnd_df_with_prediction_best_tree,
leaf_number_for_new_df), axis = 1)
new_dnd_df_with_prediction_best_tree_leaf_number
 | STR | DEX | CON | INT | WIS | CHA | Prediction | leaf_number |
---|---|---|---|---|---|---|---|---|
0 | 9 | 17 | 8 | 13 | 16 | 15 | 68.111421 | 5 |
1 | 17 | 9 | 17 | 18 | 11 | 7 | 40.090909 | 3 |
leaf_number = pd.DataFrame(best_tree.apply(train_X), columns=["leaf_number"], index = train_y.index)
leaf_number
 | leaf_number |
---|---|
650 | 5 |
479 | 5 |
271 | 5 |
647 | 5 |
307 | 5 |
... | ... |
445 | 5 |
414 | 5 |
70 | 5 |
429 | 5 |
236 | 5 |
440 rows × 1 columns
leaf_df = pd.concat([leaf_number, train_y], axis = 1)
leaf_df
 | leaf_number | HP |
---|---|---|
650 | 5 | 117 |
479 | 5 | 120 |
271 | 5 | 72 |
647 | 5 | 117 |
307 | 5 | 100 |
... | ... | ... |
445 | 5 | 32 |
414 | 5 | 70 |
70 | 5 | 80 |
429 | 5 | 72 |
236 | 5 | 91 |
440 rows × 2 columns
leaf_max_df = leaf_df.groupby(by = "leaf_number").max()
leaf_max_df
 | HP |
---|---|
leaf_number | |
2 | 120 |
3 | 112 |
5 | 150 |
6 | 140 |
leaf_max_df = leaf_max_df.rename(columns = {"HP": "Max_HP"})
leaf_max_df
 | Max_HP |
---|---|
leaf_number | |
2 | 120 |
3 | 112 |
5 | 150 |
6 | 140 |
leaf_min_df = leaf_df.groupby(by = "leaf_number").min()
leaf_min_df
 | HP |
---|---|
leaf_number | |
2 | 9 |
3 | 6 |
5 | 7 |
6 | 6 |
leaf_min_df = leaf_min_df.rename(columns = {"HP": "Min_HP"})
leaf_min_df
 | Min_HP |
---|---|
leaf_number | |
2 | 9 |
3 | 6 |
5 | 7 |
6 | 6 |
leaf_std_df = leaf_df.groupby(by = "leaf_number").std()
leaf_std_df
 | HP |
---|---|
leaf_number | |
2 | 27.924217 |
3 | 29.395350 |
5 | 37.572854 |
6 | 36.517976 |
leaf_std_df = leaf_std_df.rename(columns = {"HP": "std_HP"})
leaf_std_df
 | std_HP |
---|---|
leaf_number | |
2 | 27.924217 |
3 | 29.395350 |
5 | 37.572854 |
6 | 36.517976 |
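The three groupby calls above can be collapsed into one named aggregation; a sketch that produces the same three columns in a single pass:
# Hypothetical one-pass alternative: max, min, and std of HP per leaf.
leaf_stats_df = leaf_df.groupby("leaf_number")["HP"].agg(Max_HP = "max", Min_HP = "min", std_HP = "std")
leaf_stats_df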
Merge to get the range (min, max, and standard deviation) of training HP values in each record's leaf.
new_dnd_df_with_prediction_best_tree_leaf_number_range = pd.merge(
pd.merge(
pd.merge(new_dnd_df_with_prediction_best_tree_leaf_number,leaf_max_df, how = "inner", on = "leaf_number"),
leaf_min_df, how = "inner", on = "leaf_number"),
leaf_std_df, how = "inner", on = "leaf_number")
new_dnd_df_with_prediction_best_tree_leaf_number_range
 | STR | DEX | CON | INT | WIS | CHA | Prediction | leaf_number | Max_HP | Min_HP | std_HP |
---|---|---|---|---|---|---|---|---|---|---|---|
0 | 9 | 17 | 8 | 13 | 16 | 15 | 68.111421 | 5 | 150 | 7 | 37.572854 |
1 | 17 | 9 | 17 | 18 | 11 | 7 | 40.090909 | 3 | 112 | 6 | 29.395350 |
from sklearn.ensemble import RandomForestRegressor
rf = RandomForestRegressor(max_depth = 10, random_state = 666)
rf.fit(train_X, train_y)
RandomForestRegressor(max_depth=10, random_state=666)
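Before predicting, it is worth a quick look at which ability scores the forest leans on. A sketch using the fitted model's feature_importances_:
# Impurity-based importances, one per predictor; they sum to 1.
pd.Series(rf.feature_importances_, index = train_X.columns).sort_values(ascending = False)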
train_y_pred_rf = rf.predict(train_X)
train_y_pred_rf
array([ 88.93415507, 103.4440601 , 69.47537521, 102.25756247, 90.6015338 , 72.34565388, 55.08351732, 61.08039394, 56.41217774, 55.12831922, 44.96499255, 50.06883454, 49.50479846, 71.23535818, 62.30711203, 38.00342379, 96.4850522 , 88.06 , 74.44247935, 39.28533981, 65.46959216, 35.64007071, 45.80873105, 45.26427884, 49.58661111, 81.29242571, 49.98355509, 85.12770696, 56.05009987, 79.14899567, 82.05170197, 48.303423 , 43.96298786, 63.78496467, 82.41620807, 70.45749862, 47.52008668, 115.36753846, 47.96596143, 32.5633898 , 75.86841179, 59.54143651, 77.72521177, 51.08989052, 21.19936315, 77.51343015, 62.72026195, 78.9389704 , 72.71332005, 77.33110815, 56.42808742, 89.85577218, 46.64963591, 85.13207896, 54.96862511, 56.87054614, 43.91997655, 81.48412106, 72.74496673, 86.26800778, 41.57996127, 99.21362527, 108.10668571, 57.29779173, 111.28059477, 53.60430871, 82.44322777, 82.40620381, 77.08339146, 91.56549427, 63.02749817, 54.35300364, 82.2662034 , 97.93201389, 73.0482379 , 75.82030556, 77.8802475 , 72.55873069, 37.73159091, 71.44797238, 54.50111558, 28.51136153, 81.70698526, 41.82350859, 54.03108768, 45.3973026 , 45.78339763, 61.36807 , 30.11053968, 82.13027778, 61.15893505, 101.57288501, 67.48741823, 26.93542019, 88.02408402, 37.45897826, 85.53040787, 84.95445402, 47.64463051, 67.75539365, 37.25539184, 91.29050568, 48.96801172, 78.61864337, 33.39780771, 77.58149604, 118.14853662, 55.76287724, 68.37750732, 69.79108469, 63.83683632, 66.15682478, 87.99372067, 54.63183916, 60.71442133, 33.09455779, 30.96605128, 60.4653631 , 65.59826143, 81.64426887, 60.19963427, 45.11525176, 49.20249957, 74.16530098, 73.19454255, 39.59228974, 54.49690542, 95.66990974, 35.73416306, 39.08250008, 79.12853771, 112.28103497, 68.98634541, 80.03228299, 67.99279785, 82.48577546, 85.0130968 , 70.083967 , 74.94647463, 72.72633574, 43.42748338, 82.81276538, 45.62634505, 106.61943651, 102.58836559, 80.38613257, 57.37117338, 33.88748496, 108.94866082, 53.35486688, 56.18110066, 34.60665282, 78.43217183, 64.39763654, 45.69437607, 68.45122386, 52.6406046 , 32.45174725, 48.9233243 , 29.83394958, 35.42764936, 109.65492857, 51.11937884, 69.64900305, 89.00004645, 58.59010944, 107.6301434 , 63.54540565, 52.22194841, 71.08996857, 67.97638999, 27.7934586 , 80.31229295, 71.33814287, 77.0330639 , 69.61092666, 48.67007949, 85.91442843, 41.74883263, 75.18138508, 77.07730257, 53.18972691, 66.60794683, 65.06615015, 62.96954681, 32.4638915 , 25.73599206, 69.36473102, 80.72062339, 69.3157973 , 79.51272729, 50.00133222, 53.77316321, 52.01618708, 51.0263894 , 49.63979731, 79.75040143, 42.25055932, 87.69675137, 64.09286381, 72.61794547, 92.84779201, 98.12629072, 86.48315375, 90.06800595, 57.20494978, 81.81610457, 51.62287151, 50.15992208, 32.62498061, 92.94263059, 68.53155923, 66.82723705, 70.01569548, 51.41607791, 62.33188215, 71.4641899 , 43.85667954, 105.44633669, 57.67718836, 77.16578731, 74.30349661, 53.70944274, 71.63464022, 75.0198326 , 69.05771916, 88.68029785, 44.97268484, 107.37958222, 23.36269481, 49.00191044, 58.6598738 , 78.08707725, 35.00298126, 84.92576119, 76.44292787, 94.72199013, 94.11666398, 59.28519292, 60.45311085, 65.7338112 , 95.74016892, 58.27571061, 77.21380488, 67.20433333, 49.15215437, 54.62181655, 55.11792308, 34.05189765, 64.52346018, 56.13229827, 81.13058491, 113.66492836, 76.42701323, 88.65096755, 57.28039899, 38.63262642, 48.68805405, 59.04782569, 102.37085767, 31.38396037, 52.23307442, 53.99473035, 80.48299547, 52.97368403, 68.96615525, 45.42954444, 61.45521545, 46.90742883, 53.06130937, 47.32722966, 76.63714672, 
53.59161405, 63.17177048, 87.5244256 , 71.14728393, 89.9188027 , 68.04064219, 81.02650193, 108.56730092, 95.20197455, 40.77743978, 60.0310126 , 33.02766208, 75.63861591, 49.40175669, 33.68347467, 67.28065404, 52.59024541, 90.46681485, 96.96412321, 55.5304823 , 67.5809228 , 101.79697451, 43.97335836, 99.45552434, 69.62235006, 90.65557188, 109.08199052, 63.41604082, 38.49531318, 72.52384339, 90.26417945, 42.18824789, 48.05441438, 47.9575858 , 57.17206151, 30.2311224 , 71.52939374, 62.88909123, 68.43390846, 43.77445167, 95.76515043, 74.0386496 , 70.54859477, 55.06728316, 59.3325523 , 97.05169676, 80.23437332, 99.06046429, 66.32816194, 89.5258736 , 61.41168678, 98.17453733, 76.69305752, 60.18673817, 54.82025641, 57.32873945, 76.85930604, 45.96900252, 58.76426003, 55.51669719, 48.82278961, 43.78012896, 58.54019432, 72.18893748, 32.63400242, 61.5212612 , 72.75890845, 57.61661024, 65.39866234, 47.41540043, 77.30090306, 50.89748999, 61.50746927, 79.49715321, 41.20356734, 56.03800925, 94.15462125, 63.20338365, 95.28639722, 75.81232017, 32.09039495, 53.32924083, 59.63109694, 87.66349323, 71.51329672, 42.87189277, 48.26400166, 60.3417937 , 42.20664573, 43.00066797, 55.18067154, 76.27390707, 78.82137727, 85.451 , 69.72150089, 55.34246911, 80.58014985, 86.81124055, 83.98661409, 74.85431944, 42.61049311, 38.15734865, 85.9956654 , 43.26669312, 78.97732552, 36.26190476, 41.72199711, 53.07704693, 34.18176444, 53.47318892, 48.90472817, 58.57496796, 92.79668924, 78.4475137 , 53.60581926, 32.80670145, 49.42504221, 57.3471786 , 75.856811 , 35.04339654, 52.5227308 , 34.48859646, 59.79505366, 64.33654553, 54.4064929 , 24.49359101, 57.92275495, 57.89097697, 100.67209492, 88.89994674, 39.90751776, 58.73859727, 76.14315942, 62.97781913, 71.07756578, 97.56604108, 99.52729537, 44.67429365, 88.03345238, 68.59522164, 98.84517573, 60.62938004, 87.14872698, 47.58679546, 72.82187091, 65.94254085, 89.09710914, 61.96470938, 69.36748378, 60.7044839 , 81.60740097, 71.04032612, 54.72223377, 32.06326537, 67.85120587, 68.54966934, 90.59010516, 53.08508131, 114.62275726, 81.56617735, 116.21760762, 100.6281484 , 107.36454736, 54.93407498, 61.55746122, 74.83490502, 67.25198763, 75.9885209 ])
import sklearn.metrics  # "import sklearn" alone does not reliably expose the metrics submodule
mse_rf_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred_rf)
mse_rf_train
356.848877962145
import math  # needed for math.sqrt below
rmse_rf_train = math.sqrt(mse_rf_train)
rmse_rf_train
18.89044409118391
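Depending on your scikit-learn version, the RMSE can also be computed directly instead of going through math.sqrt; both alternatives below are version-dependent, so they are left commented out:
# Version-dependent alternatives to math.sqrt(mse):
# scikit-learn >= 1.4:
# from sklearn.metrics import root_mean_squared_error
# rmse_rf_train = root_mean_squared_error(train_y, train_y_pred_rf)
# older scikit-learn (0.22 to 1.5):
# rmse_rf_train = sklearn.metrics.mean_squared_error(train_y, train_y_pred_rf, squared = False)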
# If using the dmba package, install it first:
# pip install dmba
# or
# conda install -c conda-forge dmba
from dmba import regressionSummary
regressionSummary(train_y, train_y_pred_rf)
Regression statistics

                      Mean Error (ME) : -0.3851
       Root Mean Squared Error (RMSE) : 18.8904
            Mean Absolute Error (MAE) : 15.6621
          Mean Percentage Error (MPE) : -38.5278
Mean Absolute Percentage Error (MAPE) : 52.8393
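If the dmba package is not available, the same statistics can be reproduced with plain numpy. The helper below is a minimal stand-in under that assumption; the name regression_summary and its formatting are ours, not dmba's:
# Minimal stand-in for dmba's regressionSummary (our own helper name and
# formatting); assumes y_true contains no zeros, since MPE/MAPE divide by it.
def regression_summary(y_true, y_pred):
    y_true = np.asarray(y_true, dtype = float)
    y_pred = np.asarray(y_pred, dtype = float)
    err = y_true - y_pred
    print("Mean Error (ME)                       : {:.4f}".format(err.mean()))
    print("Root Mean Squared Error (RMSE)        : {:.4f}".format(np.sqrt((err ** 2).mean())))
    print("Mean Absolute Error (MAE)             : {:.4f}".format(np.abs(err).mean()))
    print("Mean Percentage Error (MPE)           : {:.4f}".format(100 * (err / y_true).mean()))
    print("Mean Absolute Percentage Error (MAPE) : {:.4f}".format(100 * np.abs(err / y_true).mean()))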
valid_y_pred_rf = rf.predict(valid_X)
valid_y_pred_rf
array([85.99613623, 72.81197421, 57.14756602, 65.48415094, 64.23185644, 73.88297094, 75.63706658, 65.88536111, 77.38755932, 53.3611746 , 84.30540558, 81.38338597, 69.54557467, 76.4453801 , 77.30100974, 60.7457147 , 60.28698095, 54.54960764, 58.39220608, 79.51837662, 57.45235255, 44.49966493, 80.43081232, 80.9672802 , 79.12526316, 61.88250468, 68.48486885, 55.05741665, 54.63041239, 66.02297654, 74.05684091, 61.26967588, 63.94936452, 68.44916792, 79.89559722, 47.99857592, 77.75612193, 51.24598032, 91.96242819, 95.47827337, 72.99748216, 61.48166667, 74.24409524, 71.74566751, 65.74643356, 72.12616881, 59.34396605, 55.31680735, 75.01714982, 67.35775287, 62.03485455, 69.73391667, 60.42239718, 49.83764828, 78.65800622, 72.81000229, 84.1797256 , 73.57869858, 69.42034164, 54.17249117, 80.91616771, 31.0329006 , 58.56498167, 86.61381342, 53.14061718, 46.71116392, 51.25840051, 57.74422334, 70.76289968, 68.26478858, 53.12914791, 67.50639943, 58.15775315, 86.5268843 , 54.6770392 , 67.69249941, 59.79266884, 74.1011205 , 80.20981113, 47.86735296, 57.1792636 , 65.4681474 , 71.61567756, 64.21365135, 52.96507108, 49.00585282, 60.93468812, 66.42856805, 72.62375 , 61.39575291, 52.3028228 , 70.51631313, 61.68713413, 82.59035772, 62.97275458, 65.08313092, 62.98707668, 54.98747072, 74.79348167, 77.76127609, 53.28477675, 65.1869692 , 68.7451453 , 70.07741324, 62.38078571, 63.3365805 , 73.22574008, 64.41808607, 59.03357159, 56.65995647, 48.07595775, 66.59146202, 66.1300717 , 64.77757169, 58.36584383, 63.15126245, 53.69026718, 61.60511499, 68.70015011, 68.68720783, 61.64060036, 70.09213449, 73.00846138, 63.50531364, 54.47829147, 51.04458733, 75.76759166, 50.68129187, 85.91485798, 45.009602 , 74.03758542, 60.60588345, 66.72822863, 50.40984565, 50.91525236, 79.34774785, 74.27093714, 82.57503454, 84.23115898, 58.34441943, 73.50657518, 76.91086606, 94.43031653, 51.995 , 66.51659984, 59.95371429, 48.8428732 , 74.67547684, 54.12539596, 62.59740704, 38.82594481, 66.68127848, 59.69954936, 62.24238866, 51.24587194, 67.46654537, 50.08880019, 78.89875988, 70.45242092, 76.13153563, 47.54842012, 71.55815842, 77.32311666, 51.02301605, 45.95923124, 54.17618857, 72.52685606, 59.74554544, 61.51544039, 61.22026958, 74.06011706, 85.25158929, 66.50520371, 71.80522197, 46.78741758, 75.15755218, 64.5374446 , 56.70808703, 83.22767766, 59.33948239, 48.14300962, 80.07204362, 65.39221338, 61.34071397, 50.61837513, 76.18111707, 48.45899522, 82.61167027, 68.64237587, 69.85818255, 66.21509905, 82.97757268, 45.1973547 , 61.34981061, 69.38518179, 91.70802321, 70.23158458, 27.83719221, 59.99675521, 69.57862161, 70.99210766, 60.44286601, 65.59296299, 70.98130797, 65.47639993, 68.08674206, 68.73752076, 67.61220989, 74.31800031, 64.41745114, 67.19080289, 78.22315147, 76.38173478, 56.18366362, 52.71594479, 56.61547222, 73.43090887, 91.90809606, 59.87043638, 57.65063381, 51.46242433, 78.90791805, 60.54316234, 51.72708547, 75.8761541 , 71.22114286, 43.23641919, 65.66313809, 94.86817947, 67.21091462, 52.21235575, 72.71997829, 67.16296753, 60.14175456, 68.69688344, 49.47745978, 61.98980765, 61.37528073, 55.09736371, 66.38018356, 56.80601079, 64.55983748, 79.77329875, 49.94612653, 56.46429733, 70.81355007, 73.30237874, 69.29287879, 51.04525876, 57.27526319, 55.77392348, 70.01296176, 74.8949364 , 83.63937886, 62.96207777, 53.7513566 , 62.29764406, 69.50477559, 43.34733164, 80.23455717, 61.96857608, 69.37674323, 37.29 , 65.15223851, 72.95181574, 50.16673751, 75.97699246, 70.30796224, 64.95026118, 68.25145774, 61.41429168, 65.10816719, 64.88187943, 
68.41595238, 57.63969008, 65.51860532, 88.77396387, 70.2144709 , 60.30475757, 55.3897939 , 61.63742355, 94.91610892, 72.07392157, 60.4653631 , 55.745187 , 73.56503645, 42.8781292 , 69.60228571, 84.44706084, 48.06643892, 71.47878175, 57.75287869, 66.18710983, 83.19506918])
mse_rf_valid = sklearn.metrics.mean_squared_error(valid_y, valid_y_pred_rf)
mse_rf_valid
1426.6006551155326
rmse_rf_valid = math.sqrt(mse_rf_valid)
rmse_rf_valid
37.7703674209761
# regressionSummary was already imported from dmba above
regressionSummary(valid_y, valid_y_pred_rf)
Regression statistics

                      Mean Error (ME) : 3.4807
       Root Mean Squared Error (RMSE) : 37.7704
            Mean Absolute Error (MAE) : 31.1829
          Mean Percentage Error (MPE) : -48.5668
Mean Absolute Percentage Error (MAPE) : 81.0687
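The gap between the training RMSE (about 18.9) and the validation RMSE (about 37.8) suggests the forest is overfitting the training data. A quick cross-validated check gives a less optimistic error estimate; the cv and scoring settings below are our own choices, offered only as a sketch:
# 5-fold cross-validated RMSE on the training data as a sanity check;
# cv = 5 and the scoring string are our choices, not part of the original flow.
from sklearn.model_selection import cross_val_score
cv_rmse = -cross_val_score(rf, train_X, train_y,
                           scoring = "neg_root_mean_squared_error", cv = 5)
cv_rmse.mean()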
Variable importance
var_importance = rf.feature_importances_
var_importance
array([0.17943126, 0.1634644 , 0.16296725, 0.16134494, 0.16488891, 0.16790324])
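These impurity-based importances are normalized across features, so they sum to 1; the near-uniform values here indicate that no single ability score dominates:
# Impurity-based importances are normalized, so this should return 1.0
var_importance.sum()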
std = np.std([tree.feature_importances_ for tree in rf.estimators_], axis = 0)
std
array([0.04659177, 0.03802737, 0.0363087 , 0.03966035, 0.03344671, 0.03432209])
var_importance_df = pd.DataFrame({"variable": train_X.columns, "importance": var_importance, "std": std})
var_importance_df
variable | importance | std | |
---|---|---|---|
0 | STR | 0.179431 | 0.046592 |
1 | DEX | 0.163464 | 0.038027 |
2 | CON | 0.162967 | 0.036309 |
3 | INT | 0.161345 | 0.039660 |
4 | WIS | 0.164889 | 0.033447 |
5 | CHA | 0.167903 | 0.034322 |
var_importance_df_sorted = var_importance_df.sort_values("importance")
var_importance_df_sorted
variable | importance | std | |
---|---|---|---|
3 | INT | 0.161345 | 0.039660 |
2 | CON | 0.162967 | 0.036309 |
1 | DEX | 0.163464 | 0.038027 |
4 | WIS | 0.164889 | 0.033447 |
5 | CHA | 0.167903 | 0.034322 |
0 | STR | 0.179431 | 0.046592 |
import matplotlib.pyplot as plt
var_importance_plot = var_importance_df_sorted.plot(kind = "barh", xerr = "std", x = "variable", legend = False)
var_importance_plot.set_ylabel("")
var_importance_plot.set_xlabel("Importance")
plt.show()
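Impurity-based importances can be biased, so a common cross-check is permutation importance on the validation split. This is an optional aside; n_repeats and the seed below are our own choices:
# Permutation importance on the validation split as a cross-check on the
# impurity-based numbers above; n_repeats = 10 is an arbitrary choice.
from sklearn.inspection import permutation_importance
perm = permutation_importance(rf, valid_X, valid_y,
                              n_repeats = 10, random_state = 666)
pd.Series(perm.importances_mean, index = valid_X.columns)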
new_records_dnd_rf_pred = rf.predict(new_dnd_df)
new_records_dnd_rf_pred
array([64.32568489, 70.02311429])
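Since a random forest averages many trees, the spread of the per-tree predictions gives a rough, informal sense of uncertainty for each new record (not a calibrated prediction interval):
# Spread of per-tree predictions for each new record; an informal gauge
# of model uncertainty, not a calibrated prediction interval.
per_tree_pred = np.stack([tree.predict(new_dnd_df.values) for tree in rf.estimators_])
per_tree_pred.std(axis = 0)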
dnd_rf_prediction_df = pd.DataFrame(new_records_dnd_rf_pred,
columns = ["Prediction"])
dnd_rf_prediction_df
Prediction | |
---|---|
0 | 64.325685 |
1 | 70.023114 |
Merge the predictions with the new records
new_dnd_df_with_prediction_rf = pd.concat((new_dnd_df, dnd_rf_prediction_df), axis = 1)
new_dnd_df_with_prediction_rf
# to export
# new_dnd_df_with_prediction_rf.to_csv("whatever_name.csv")
STR | DEX | CON | INT | WIS | CHA | Prediction | |
---|---|---|---|---|---|---|---|
0 | 9 | 17 | 8 | 13 | 16 | 15 | 64.325685 |
1 | 17 | 9 | 17 | 18 | 11 | 7 | 70.023114 |
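When exporting, note that pandas writes the row index as an extra column by default; pass index = False to keep only the data columns. The file name below is just a placeholder:
# Export without the pandas index column; the file name is a placeholder.
new_dnd_df_with_prediction_rf.to_csv("dnd_rf_predictions.csv", index = False)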