pokemon <- read.csv("pokemon_2.csv", header = TRUE)
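If the special characters below come out garbled, the file's encoding may not match your locale. A minimal sketch of a workaround, assuming the CSV is UTF-8 encoded (pokemon_utf8 is just an illustrative name):
# re-read with an explicit encoding; assumes the file is UTF-8
pokemon_utf8 <- read.csv("pokemon_2.csv", header = TRUE,
                         fileEncoding = "UTF-8")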
Explore the data.
head(pokemon)
## abilities against_bug against_dark against_dragon
## 1 ['Overgrow', 'Chlorophyll'] 1.00 1 1
## 2 ['Overgrow', 'Chlorophyll'] 1.00 1 1
## 3 ['Overgrow', 'Chlorophyll'] 1.00 1 1
## 4 ['Blaze', 'Solar Power'] 0.50 1 1
## 5 ['Blaze', 'Solar Power'] 0.50 1 1
## 6 ['Blaze', 'Solar Power'] 0.25 1 1
## against_electric against_fairy against_fight against_fire against_flying
## 1 0.5 0.5 0.5 2.0 2
## 2 0.5 0.5 0.5 2.0 2
## 3 0.5 0.5 0.5 2.0 2
## 4 1.0 0.5 1.0 0.5 1
## 5 1.0 0.5 1.0 0.5 1
## 6 2.0 0.5 0.5 0.5 1
## against_ghost against_grass against_ground against_ice against_normal
## 1 1 0.25 1 2.0 1
## 2 1 0.25 1 2.0 1
## 3 1 0.25 1 2.0 1
## 4 1 0.50 2 0.5 1
## 5 1 0.50 2 0.5 1
## 6 1 0.25 0 1.0 1
## against_poison against_psychic against_rock against_steel against_water
## 1 1 2 1 1.0 0.5
## 2 1 2 1 1.0 0.5
## 3 1 2 1 1.0 0.5
## 4 1 1 2 0.5 2.0
## 5 1 1 2 0.5 2.0
## 6 1 1 4 0.5 2.0
## attack base_egg_steps base_happiness base_total capture_rate
## 1 49 5120 70 318 45
## 2 62 5120 70 405 45
## 3 100 5120 70 625 45
## 4 52 5120 70 309 45
## 5 64 5120 70 405 45
## 6 104 5120 70 634 45
## classfication defence experience_growth height_m hp
## 1 Seed Pokémon 49 1059860 0.7 45
## 2 Seed Pokémon 63 1059860 1.0 60
## 3 Seed Pokémon 123 1059860 2.0 80
## 4 Lizard Pokémon 43 1059860 0.6 39
## 5 Flame Pokémon 58 1059860 1.1 58
## 6 Flame Pokémon 78 1059860 1.7 78
## japanese_name name percentage_male
## 1 Fushigidaneフシギダネ Bulbasaur 88.1
## 2 Fushigisouフシギソウ Ivysaur 88.1
## 3 Fushigibanaフシギバナ Venusaur 88.1
## 4 Hitokageヒトカゲ Charmander 88.1
## 5 Lizardoリザード Charmeleon 88.1
## 6 Lizardonリザードン Charizard 88.1
## pokedex_number sp_attack sp_defence speed type1 type2 weight_kg generation
## 1 1 65 65 45 grass poison 6.9 1
## 2 2 80 80 60 grass poison 13.0 1
## 3 3 122 120 80 grass poison 100.0 1
## 4 4 60 50 65 fire 8.5 1
## 5 5 80 65 80 fire 19.0 1
## 6 6 159 115 100 fire flying 90.5 1
## is_legendary train
## 1 0 1
## 2 0 1
## 3 0 0
## 4 0 0
## 5 0 0
## 6 0 0
str(pokemon)
## 'data.frame': 801 obs. of 42 variables:
## $ abilities : chr "['Overgrow', 'Chlorophyll']" "['Overgrow', 'Chlorophyll']" "['Overgrow', 'Chlorophyll']" "['Blaze', 'Solar Power']" ...
## $ against_bug : num 1 1 1 0.5 0.5 0.25 1 1 1 1 ...
## $ against_dark : num 1 1 1 1 1 1 1 1 1 1 ...
## $ against_dragon : num 1 1 1 1 1 1 1 1 1 1 ...
## $ against_electric : num 0.5 0.5 0.5 1 1 2 2 2 2 1 ...
## $ against_fairy : num 0.5 0.5 0.5 0.5 0.5 0.5 1 1 1 1 ...
## $ against_fight : num 0.5 0.5 0.5 1 1 0.5 1 1 1 0.5 ...
## $ against_fire : num 2 2 2 0.5 0.5 0.5 0.5 0.5 0.5 2 ...
## $ against_flying : num 2 2 2 1 1 1 1 1 1 2 ...
## $ against_ghost : num 1 1 1 1 1 1 1 1 1 1 ...
## $ against_grass : num 0.25 0.25 0.25 0.5 0.5 0.25 2 2 2 0.5 ...
## $ against_ground : num 1 1 1 2 2 0 1 1 1 0.5 ...
## $ against_ice : num 2 2 2 0.5 0.5 1 0.5 0.5 0.5 1 ...
## $ against_normal : num 1 1 1 1 1 1 1 1 1 1 ...
## $ against_poison : num 1 1 1 1 1 1 1 1 1 1 ...
## $ against_psychic : num 2 2 2 1 1 1 1 1 1 1 ...
## $ against_rock : num 1 1 1 2 2 4 1 1 1 2 ...
## $ against_steel : num 1 1 1 0.5 0.5 0.5 0.5 0.5 0.5 1 ...
## $ against_water : num 0.5 0.5 0.5 2 2 2 0.5 0.5 0.5 1 ...
## $ attack : int 49 62 100 52 64 104 48 63 103 30 ...
## $ base_egg_steps : int 5120 5120 5120 5120 5120 5120 5120 5120 5120 3840 ...
## $ base_happiness : int 70 70 70 70 70 70 70 70 70 70 ...
## $ base_total : int 318 405 625 309 405 634 314 405 630 195 ...
## $ capture_rate : chr "45" "45" "45" "45" ...
## $ classfication : chr "Seed Pokémon" "Seed Pokémon" "Seed Pokémon" "Lizard Pokémon" ...
## $ defence : int 49 63 123 43 58 78 65 80 120 35 ...
## $ experience_growth: int 1059860 1059860 1059860 1059860 1059860 1059860 1059860 1059860 1059860 1000000 ...
## $ height_m : num 0.7 1 2 0.6 1.1 1.7 0.5 1 1.6 0.3 ...
## $ hp : int 45 60 80 39 58 78 44 59 79 45 ...
## $ japanese_name : chr "Fushigidaneフシギダネ" "Fushigisouフシギソウ" "Fushigibanaフシギバナ" "Hitokageヒトカゲ" ...
## $ name : chr "Bulbasaur" "Ivysaur" "Venusaur" "Charmander" ...
## $ percentage_male : num 88.1 88.1 88.1 88.1 88.1 88.1 88.1 88.1 88.1 50 ...
## $ pokedex_number : int 1 2 3 4 5 6 7 8 9 10 ...
## $ sp_attack : int 65 80 122 60 80 159 50 65 135 20 ...
## $ sp_defence : int 65 80 120 50 65 115 64 80 115 20 ...
## $ speed : int 45 60 80 65 80 100 43 58 78 45 ...
## $ type1 : chr "grass" "grass" "grass" "fire" ...
## $ type2 : chr "poison" "poison" "poison" "" ...
## $ weight_kg : num 6.9 13 100 8.5 19 90.5 9 22.5 85.5 2.9 ...
## $ generation : int 1 1 1 1 1 1 1 1 1 1 ...
## $ is_legendary : int 0 0 0 0 0 0 0 0 0 0 ...
## $ train : int 1 1 0 0 0 0 0 0 0 0 ...
summary(pokemon)
## abilities against_bug against_dark against_dragon
## Length:801 Min. :0.2500 Min. :0.250 Min. :0.0000
## Class :character 1st Qu.:0.5000 1st Qu.:1.000 1st Qu.:1.0000
## Mode :character Median :1.0000 Median :1.000 Median :1.0000
## Mean :0.9963 Mean :1.057 Mean :0.9688
## 3rd Qu.:1.0000 3rd Qu.:1.000 3rd Qu.:1.0000
## Max. :4.0000 Max. :4.000 Max. :2.0000
##
## against_electric against_fairy against_fight against_fire
## Min. :0.000 Min. :0.250 Min. :0.000 Min. :0.250
## 1st Qu.:0.500 1st Qu.:1.000 1st Qu.:0.500 1st Qu.:0.500
## Median :1.000 Median :1.000 Median :1.000 Median :1.000
## Mean :1.074 Mean :1.069 Mean :1.066 Mean :1.135
## 3rd Qu.:1.000 3rd Qu.:1.000 3rd Qu.:1.000 3rd Qu.:2.000
## Max. :4.000 Max. :4.000 Max. :4.000 Max. :4.000
##
## against_flying against_ghost against_grass against_ground
## Min. :0.250 Min. :0.000 Min. :0.250 Min. :0.000
## 1st Qu.:1.000 1st Qu.:1.000 1st Qu.:0.500 1st Qu.:1.000
## Median :1.000 Median :1.000 Median :1.000 Median :1.000
## Mean :1.193 Mean :0.985 Mean :1.034 Mean :1.098
## 3rd Qu.:1.000 3rd Qu.:1.000 3rd Qu.:1.000 3rd Qu.:1.000
## Max. :4.000 Max. :4.000 Max. :4.000 Max. :4.000
##
## against_ice against_normal against_poison against_psychic
## Min. :0.250 Min. :0.000 Min. :0.0000 Min. :0.000
## 1st Qu.:0.500 1st Qu.:1.000 1st Qu.:0.5000 1st Qu.:1.000
## Median :1.000 Median :1.000 Median :1.0000 Median :1.000
## Mean :1.208 Mean :0.887 Mean :0.9753 Mean :1.005
## 3rd Qu.:2.000 3rd Qu.:1.000 3rd Qu.:1.0000 3rd Qu.:1.000
## Max. :4.000 Max. :1.000 Max. :4.0000 Max. :4.000
##
## against_rock against_steel against_water attack
## Min. :0.25 Min. :0.2500 Min. :0.250 Min. : 5.00
## 1st Qu.:1.00 1st Qu.:0.5000 1st Qu.:0.500 1st Qu.: 55.00
## Median :1.00 Median :1.0000 Median :1.000 Median : 75.00
## Mean :1.25 Mean :0.9835 Mean :1.058 Mean : 77.86
## 3rd Qu.:2.00 3rd Qu.:1.0000 3rd Qu.:1.000 3rd Qu.:100.00
## Max. :4.00 Max. :4.0000 Max. :4.000 Max. :185.00
##
## base_egg_steps base_happiness base_total capture_rate
## Min. : 1280 Min. : 0.00 Min. :180.0 Length:801
## 1st Qu.: 5120 1st Qu.: 70.00 1st Qu.:320.0 Class :character
## Median : 5120 Median : 70.00 Median :435.0 Mode :character
## Mean : 7191 Mean : 65.36 Mean :428.4
## 3rd Qu.: 6400 3rd Qu.: 70.00 3rd Qu.:505.0
## Max. :30720 Max. :140.00 Max. :780.0
##
## classfication defence experience_growth height_m
## Length:801 Min. : 5.00 Min. : 600000 Min. : 0.100
## Class :character 1st Qu.: 50.00 1st Qu.:1000000 1st Qu.: 0.600
## Mode :character Median : 70.00 Median :1000000 Median : 1.000
## Mean : 73.01 Mean :1054996 Mean : 1.164
## 3rd Qu.: 90.00 3rd Qu.:1059860 3rd Qu.: 1.500
## Max. :230.00 Max. :1640000 Max. :14.500
## NA's :20
## hp japanese_name name percentage_male
## Min. : 1.00 Length:801 Length:801 Min. : 0.00
## 1st Qu.: 50.00 Class :character Class :character 1st Qu.: 50.00
## Median : 65.00 Mode :character Mode :character Median : 50.00
## Mean : 68.96 Mean : 55.16
## 3rd Qu.: 80.00 3rd Qu.: 50.00
## Max. :255.00 Max. :100.00
## NA's :98
## pokedex_number sp_attack sp_defence speed
## Min. : 1 Min. : 10.00 Min. : 20.00 Min. : 5.00
## 1st Qu.:201 1st Qu.: 45.00 1st Qu.: 50.00 1st Qu.: 45.00
## Median :401 Median : 65.00 Median : 66.00 Median : 65.00
## Mean :401 Mean : 71.31 Mean : 70.91 Mean : 66.33
## 3rd Qu.:601 3rd Qu.: 91.00 3rd Qu.: 90.00 3rd Qu.: 85.00
## Max. :801 Max. :194.00 Max. :230.00 Max. :180.00
##
## type1 type2 weight_kg generation
## Length:801 Length:801 Min. : 0.10 Min. :1.00
## Class :character Class :character 1st Qu.: 9.00 1st Qu.:2.00
## Mode :character Mode :character Median : 27.30 Median :4.00
## Mean : 61.38 Mean :3.69
## 3rd Qu.: 64.80 3rd Qu.:5.00
## Max. :999.90 Max. :7.00
## NA's :20
## is_legendary train
## Min. :0.00000 Min. :0.00000
## 1st Qu.:0.00000 1st Qu.:0.00000
## Median :0.00000 Median :0.00000
## Mean :0.08739 Mean :0.09988
## 3rd Qu.:0.00000 3rd Qu.:0.00000
## Max. :1.00000 Max. :1.00000
##
nrow(pokemon)
## [1] 801
table(pokemon$is_legendary)
##
## 0 1
## 731 70
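Only around 9% of the Pokémon are legendary, so the classes are quite imbalanced; this matters later when we weight the training data. A quick way to see the proportions:
# class shares; is_legendary = 1 is a small minority
prop.table(table(pokemon$is_legendary))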
names(pokemon)
## [1] "abilities" "against_bug" "against_dark"
## [4] "against_dragon" "against_electric" "against_fairy"
## [7] "against_fight" "against_fire" "against_flying"
## [10] "against_ghost" "against_grass" "against_ground"
## [13] "against_ice" "against_normal" "against_poison"
## [16] "against_psychic" "against_rock" "against_steel"
## [19] "against_water" "attack" "base_egg_steps"
## [22] "base_happiness" "base_total" "capture_rate"
## [25] "classfication" "defence" "experience_growth"
## [28] "height_m" "hp" "japanese_name"
## [31] "name" "percentage_male" "pokedex_number"
## [34] "sp_attack" "sp_defence" "speed"
## [37] "type1" "type2" "weight_kg"
## [40] "generation" "is_legendary" "train"
# keep sp_attack, sp_defence, speed, type1 and is_legendary
pokemon_df <- pokemon[, c(34:37, 41)]
head(pokemon_df)
## sp_attack sp_defence speed type1 is_legendary
## 1 65 65 45 grass 0
## 2 80 80 60 grass 0
## 3 122 120 80 grass 0
## 4 60 50 65 fire 0
## 5 80 65 80 fire 0
## 6 159 115 100 fire 0
nrow(pokemon_df)
## [1] 801
names(pokemon_df)
## [1] "sp_attack" "sp_defence" "speed" "type1" "is_legendary"
table(pokemon_df$is_legendary)
##
## 0 1
## 731 70
Our favourite seed :-)
set.seed(666)
Training-validation split.
Create the indices for the split. This samples the row indices used to split the data into training (60%) and validation (40%) sets.
train_index <- sample(1:nrow(pokemon_df), 0.6 * nrow(pokemon_df))
valid_index <- setdiff(1:nrow(pokemon_df), train_index)
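A minimal sanity check that the two index sets are disjoint and together cover every row:
# both checks should return TRUE
length(intersect(train_index, valid_index)) == 0
length(train_index) + length(valid_index) == nrow(pokemon_df)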
Using the indices, create the training and validation sets. This is similar in principle to subsetting a data frame by row.
train_df <- pokemon_df[train_index, ]
valid_df <- pokemon_df[valid_index, ]
It is a good habit to check after splitting.
nrow(train_df)
## [1] 480
nrow(valid_df)
## [1] 321
table(train_df$is_legendary)
##
## 0 1
## 442 38
names(train_df)
## [1] "sp_attack" "sp_defence" "speed" "type1" "is_legendary"
Build the tree on the training set. This is the model-training step; other classification algorithms would work here too.
library(rpart)
class_tr_1 <- rpart(is_legendary ~ sp_attack + sp_defence +
speed + type1,
data = train_df, method = "class",
maxdepth = 10)
Plot the tree. Try different settings to tweak the format.
library(rpart.plot)
rpart.plot(class_tr_1, type = 5)
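For example, the type and extra arguments control what is shown in each node. A sketch of one alternative:
# type = 4 labels all nodes; extra = 104 adds class probabilities and
# the percentage of observations falling into each node
rpart.plot(class_tr_1, type = 4, extra = 104)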
class_tr_1_train_predict <- predict(class_tr_1, train_df,
type = "class")
summary(class_tr_1_train_predict)
## 0 1
## 461 19
library(caret)
## Loading required package: ggplot2
## Loading required package: lattice
class_tr_1_train_predict <- as.factor(class_tr_1_train_predict)
train_df$is_legendary <- as.factor(train_df$is_legendary)
confusionMatrix(class_tr_1_train_predict, train_df$is_legendary, positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 439 22
## 1 3 16
##
## Accuracy : 0.9479
## 95% CI : (0.9241, 0.966)
## No Information Rate : 0.9208
## P-Value [Acc > NIR] : 0.0134929
##
## Kappa : 0.537
##
## Mcnemar's Test P-Value : 0.0003182
##
## Sensitivity : 0.42105
## Specificity : 0.99321
## Pos Pred Value : 0.84211
## Neg Pred Value : 0.95228
## Prevalence : 0.07917
## Detection Rate : 0.03333
## Detection Prevalence : 0.03958
## Balanced Accuracy : 0.70713
##
## 'Positive' Class : 1
##
class_tr_1_valid_predict <- predict(class_tr_1, valid_df,
type = "class")
summary(class_tr_1_valid_predict)
## 0 1
## 304 17
Convert to factor for the confusion matrix.
class_tr_1_valid_predict <- as.factor(class_tr_1_valid_predict)
valid_df$is_legendary <- as.factor(valid_df$is_legendary)
Generate the confusion matrix. It is important to specify the positive class; otherwise, sensitivity and specificity are reported for the wrong class (see the sketch after the matrix below).
confusionMatrix(class_tr_1_valid_predict, valid_df$is_legendary, positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 281 23
## 1 8 9
##
## Accuracy : 0.9034
## 95% CI : (0.8657, 0.9334)
## No Information Rate : 0.9003
## P-Value [Acc > NIR] : 0.47276
##
## Kappa : 0.3203
##
## Mcnemar's Test P-Value : 0.01192
##
## Sensitivity : 0.28125
## Specificity : 0.97232
## Pos Pred Value : 0.52941
## Neg Pred Value : 0.92434
## Prevalence : 0.09969
## Detection Rate : 0.02804
## Detection Prevalence : 0.05296
## Balanced Accuracy : 0.62678
##
## 'Positive' Class : 1
##
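To see why specifying the positive class matters, compare with positive = "0": the accuracy is unchanged, but sensitivity and specificity (and the predictive values) swap roles. A sketch:
# treating the majority class as "positive" reverses the class-specific metrics
confusionMatrix(class_tr_1_valid_predict, valid_df$is_legendary, positive = "0")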
Evaluate the model.
Plot the ROC curve.
ROSE::roc.curve(valid_df$is_legendary, class_tr_1_valid_predict)
## Area under the curve (AUC): 0.627
Performance evaluation.
library(modelplotr)
## Package modelplotr loaded! Happy model plotting!
scores_and_ntiles <- prepare_scores_and_ntiles(datasets = list("valid_df"),
dataset_labels = list("Validation data"),
models = list("class_tr_1"),
model_labels = list("Classification Tree"),
target_column = "is_legendary",
ntiles = 100)
## Warning: `select_()` was deprecated in dplyr 0.7.0.
## Please use `select()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
## ... scoring caret model "class_tr_1" on dataset "valid_df".
## Data preparation step 1 succeeded! Dataframe created.
plot_input <- plotting_scope(prepared_input = scores_and_ntiles)
## Warning: `group_by_()` was deprecated in dplyr 0.7.0.
## Please use `group_by()` instead.
## See vignette('programming') for more help
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
## Data preparation step 2 succeeded! Dataframe created.
## "prepared_input" aggregated...
## Data preparation step 3 succeeded! Dataframe created.
##
## No comparison specified, default values are used.
##
## Single evaluation line will be plotted: Target value "1" plotted for dataset "Validation data" and model "Classification Tree.
## "
## -> To compare models, specify: scope = "compare_models"
## -> To compare datasets, specify: scope = "compare_datasets"
## -> To compare target classes, specify: scope = "compare_targetclasses"
## -> To plot one line, do not specify scope or specify scope = "no_comparison".
Cumulative gains for the decision tree.
plot_cumgains(data = plot_input, highlight_ntile = 23,
custom_line_colors = "#1F8723")
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##
## Plot annotation for plot: Cumulative gains
## - When we select 23% with the highest probability according to model Classification Tree, this selection holds 34% of all 1 cases in Validation data.
##
##
Cumulative lift for the decision tree.
plot_cumlift(data = plot_input, highlight_ntile = 23,
custom_line_colors = "#1F3387")
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##
## Plot annotation for plot: Cumulative lift
## - When we select 23% with the highest probability according to model Classification Tree in Validation data, this selection for 1 cases is 1.5 times better than selecting without a model.
##
##
Response plot for the decision tree.
plot_response(data = plot_input)
Cumulative response plot for the decision tree.
plot_cumresponse(data = plot_input)
Now, try the same model using a weighted (balanced) sample. This can be done with the ROSE package.
library(ROSE)
## Loaded ROSE 0.0-4
Create the weighted training df.
Factorise categorical variables for ROSE to work.
names(train_df)
## [1] "sp_attack" "sp_defence" "speed" "type1" "is_legendary"
# factorising is needed for newer versions of the ROSE package
train_df$type1 <- as.factor(train_df$type1)
valid_df$type1 <- as.factor(valid_df$type1)
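One caveat worth guarding against: if the validation set contains a type1 level absent from training (or vice versa), prediction can fail. A minimal sketch to align the levels:
# give the validation factor the same level set as the training factor
valid_df$type1 <- factor(valid_df$type1, levels = levels(train_df$type1))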
train_df_rose <- ROSE(is_legendary ~ sp_attack + sp_defence +
speed + type1,
data = train_df, seed = 666)$data
table(train_df_rose$is_legendary)
##
## 0 1
## 231 249
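By default, ROSE generates a synthetic sample of the same size as the input with a roughly 50/50 class split. The balance is tunable via p; a sketch, assuming we wanted about a 30% minority share (train_df_rose_30 is an illustrative name):
# p sets the approximate share of the minority class in the synthetic sample
train_df_rose_30 <- ROSE(is_legendary ~ sp_attack + sp_defence +
                           speed + type1,
                         data = train_df, p = 0.3, seed = 666)$data
table(train_df_rose_30$is_legendary)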
This is the same decision tree, except it uses the weighted training data.
class_tr_2 <- rpart(is_legendary ~ sp_attack + sp_defence +
speed + type1,
data = train_df_rose, method = "class",
maxdepth = 10)
rpart.plot(class_tr_2, type = 5)
Compute the predictions using the model trained on the weighted data.
class_tr_2_train_predict <- predict(class_tr_2, train_df_rose,
type = "class")
summary(class_tr_2_train_predict)
## 0 1
## 236 244
Convert to factor for the confusion matrix. Then generate the confusion matrix.
class_tr_2_train_predict <- as.factor(class_tr_2_train_predict)
train_df_rose$is_legendary <- as.factor(train_df_rose$is_legendary)
confusionMatrix(class_tr_2_train_predict, train_df_rose$is_legendary, positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 212 24
## 1 19 225
##
## Accuracy : 0.9104
## 95% CI : (0.8812, 0.9344)
## No Information Rate : 0.5188
## P-Value [Acc > NIR] : <2e-16
##
## Kappa : 0.8207
##
## Mcnemar's Test P-Value : 0.5419
##
## Sensitivity : 0.9036
## Specificity : 0.9177
## Pos Pred Value : 0.9221
## Neg Pred Value : 0.8983
## Prevalence : 0.5188
## Detection Rate : 0.4688
## Detection Prevalence : 0.5083
## Balanced Accuracy : 0.9107
##
## 'Positive' Class : 1
##
Use the new model, trained on the weighted training data, to predict on the validation set.
class_tr_2_valid_predict <- predict(class_tr_2, valid_df,
type = "class")
summary(class_tr_2_valid_predict)
## 0 1
## 249 72
Convert to factor for the confusion matrix.
class_tr_2_valid_predict <- as.factor(class_tr_2_valid_predict)
valid_df$is_legendary <- as.factor(valid_df$is_legendary)
confusionMatrix(class_tr_2_valid_predict, valid_df$is_legendary, positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 238 11
## 1 51 21
##
## Accuracy : 0.8069
## 95% CI : (0.7594, 0.8486)
## No Information Rate : 0.9003
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.3084
##
## Mcnemar's Test P-Value : 7.308e-07
##
## Sensitivity : 0.65625
## Specificity : 0.82353
## Pos Pred Value : 0.29167
## Neg Pred Value : 0.95582
## Prevalence : 0.09969
## Detection Rate : 0.06542
## Detection Prevalence : 0.22430
## Balanced Accuracy : 0.73989
##
## 'Positive' Class : 1
##
Now, we can see whether the results from the weighted training set are better.
Plot the ROC curve for the decision tree using weighted training data. The AUC is better.
ROSE::roc.curve(valid_df$is_legendary, class_tr_2_valid_predict)
## Area under the curve (AUC): 0.740
Performance.
Prepare the data for evaluation.
library(modelplotr)
scores_and_ntiles <- prepare_scores_and_ntiles(datasets = list("valid_df"),
dataset_labels = list("Validation data"),
models = list("class_tr_2"),
model_labels = list("Classification Tree (Balanced)"),
target_column = "is_legendary",
ntiles = 100)
## ... scoring caret model "class_tr_2" on dataset "valid_df".
## Data preparation step 1 succeeded! Dataframe created.
plot_input <- plotting_scope(prepared_input = scores_and_ntiles)
## Data preparation step 2 succeeded! Dataframe created.
## "prepared_input" aggregated...
## Data preparation step 3 succeeded! Dataframe created.
##
## No comparison specified, default values are used.
##
## Single evaluation line will be plotted: Target value "1" plotted for dataset "Validation data" and model "Classification Tree (Balanced).
## "
## -> To compare models, specify: scope = "compare_models"
## -> To compare datasets, specify: scope = "compare_datasets"
## -> To compare target classes, specify: scope = "compare_targetclasses"
## -> To plot one line, do not specify scope or specify scope = "no_comparison".
Cumulative gains for the decision tree using weighted training data.
plot_cumgains(data = plot_input, highlight_ntile = 23,
custom_line_colors = "#1F8723")
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##
## Plot annotation for plot: Cumulative gains
## - When we select 23% with the highest probability according to model Classification Tree (Balanced), this selection holds 69% of all 1 cases in Validation data.
##
##
Cumulative lift for the decision tree using weighted training data.
plot_cumlift(data = plot_input, highlight_ntile = 23,
custom_line_colors = "#1F3387")
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##
## Plot annotation for plot: Cumulative lift
## - When we select 23% with the highest probability according to model Classification Tree (Balanced) in Validation data, this selection for 1 cases is 3.0 times better than selecting without a model.
##
##
library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
##
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
##
## margin
class_tr_rf <- randomForest(is_legendary ~ sp_attack +
sp_defence + speed + type1,
data = train_df_rose, ntree = 300,
nodesize = 11, importance = TRUE)
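Since importance = TRUE was set, we can inspect which predictors drive the forest:
# variable importance: mean decrease in accuracy and in Gini impurity
importance(class_tr_rf)
varImpPlot(class_tr_rf)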
The confusion matrix.
class_tr_rf_pred <- predict(class_tr_rf, train_df_rose)
confusionMatrix(class_tr_rf_pred, train_df_rose$is_legendary,
positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 212 4
## 1 19 245
##
## Accuracy : 0.9521
## 95% CI : (0.929, 0.9694)
## No Information Rate : 0.5188
## P-Value [Acc > NIR] : < 2.2e-16
##
## Kappa : 0.9038
##
## Mcnemar's Test P-Value : 0.003509
##
## Sensitivity : 0.9839
## Specificity : 0.9177
## Pos Pred Value : 0.9280
## Neg Pred Value : 0.9815
## Prevalence : 0.5188
## Detection Rate : 0.5104
## Detection Prevalence : 0.5500
## Balanced Accuracy : 0.9508
##
## 'Positive' Class : 1
##
class_tr_rf_pred <- predict(class_tr_rf, valid_df)
confusionMatrix(class_tr_rf_pred, valid_df$is_legendary,
positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 233 7
## 1 56 25
##
## Accuracy : 0.8037
## 95% CI : (0.756, 0.8458)
## No Information Rate : 0.9003
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.3495
##
## Mcnemar's Test P-Value : 1.472e-09
##
## Sensitivity : 0.78125
## Specificity : 0.80623
## Pos Pred Value : 0.30864
## Neg Pred Value : 0.97083
## Prevalence : 0.09969
## Detection Rate : 0.07788
## Detection Prevalence : 0.25234
## Balanced Accuracy : 0.79374
##
## 'Positive' Class : 1
##
ROC.
ROSE::roc.curve(valid_df$is_legendary, class_tr_rf_pred)
## Area under the curve (AUC): 0.794
Performance evaluation.
library(modelplotr)
scores_and_ntiles <- prepare_scores_and_ntiles(datasets = list("valid_df"),
dataset_labels = list("Validation data"),
models = list("class_tr_rf"),
model_labels = list("Random Forest"),
target_column = "is_legendary",
ntiles = 100)
## ... scoring caret model "class_tr_rf" on dataset "valid_df".
## Data preparation step 1 succeeded! Dataframe created.
plot_input <- plotting_scope(prepared_input = scores_and_ntiles, select_targetclass = "1")
## Data preparation step 2 succeeded! Dataframe created.
## "prepared_input" aggregated...
## Data preparation step 3 succeeded! Dataframe created.
##
## No comparison specified, default values are used.
##
## Single evaluation line will be plotted: Target value "1" plotted for dataset "Validation data" and model "Random Forest.
## "
## -> To compare models, specify: scope = "compare_models"
## -> To compare datasets, specify: scope = "compare_datasets"
## -> To compare target classes, specify: scope = "compare_targetclasses"
## -> To plot one line, do not specify scope or specify scope = "no_comparison".
Cumulative gains for the random forest using weighted training data.
plot_cumgains(data = plot_input, highlight_ntile = 23,
custom_line_colors = "#DA5A0C")
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##
## Plot annotation for plot: Cumulative gains
## - When we select 23% with the highest probability according to model Random Forest, this selection holds 75% of all 1 cases in Validation data.
##
##
Cumulative lift for the random forest using weighted training data.
plot_cumlift(data = plot_input, highlight_ntile = 23,
custom_line_colors = "#BEDA0C")
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##
## Plot annotation for plot: Cumulative lift
## - When we select 23% with the highest probability according to model Random Forest in Validation data, this selection for 1 cases is 3.3 times better than selecting without a model.
##
##
train_norm <- train_df
valid_norm <- valid_df
names(train_norm)
## [1] "sp_attack" "sp_defence" "speed" "type1" "is_legendary"
library(caret)
norm_values <- preProcess(train_df[, -c(4,5)],
method = c("center",
"scale"))
train_norm[, -c(4,5)] <- predict(norm_values,
train_df[, -c(4,5)])
head(train_norm)
## sp_attack sp_defence speed type1 is_legendary
## 574 -0.5003024 -0.176717360 -0.7288216 psychic 0
## 638 0.5765817 0.076597627 1.4176106 steel 1
## 608 0.7304222 -0.357656636 -0.3881181 ghost 0
## 123 -0.5003024 0.366100469 1.3153996 bug 0
## 540 -0.9618241 -0.357656636 -0.8310327 bug 0
## 654 0.5765817 0.004221916 0.2251483 fire 0
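The preProcess object stores the centering and scaling constants learned from the training data; they can be inspected directly:
# training means and standard deviations applied by predict()
norm_values$mean
norm_values$std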
valid_norm[, -c(4,5)] <- predict(norm_values,
valid_df[, -c(4,5)])
head(valid_norm)
## sp_attack sp_defence speed type1 is_legendary
## 2 0.2689005 0.3661005 -0.21776634 grass 0
## 3 1.5611613 1.8136147 0.46364073 grass 0
## 4 -0.3464618 -0.7195352 -0.04741458 fire 0
## 5 0.2689005 -0.1767174 0.46364073 fire 0
## 12 0.5765817 0.3661005 0.12293719 bug 0
## 13 -1.5771864 -1.8051708 -0.55846988 bug 0
knn_model <- caret::knn3(is_legendary ~ ., data = train_norm, k = 15)
knn_model
## 15-nearest neighbor model
## Training set outcome distribution:
##
## 0 1
## 442 38
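k = 15 is a guess. A rough sketch that scans a few values of k and scores each on the validation set:
# compare validation accuracy across candidate neighbourhood sizes
for (k in c(3, 5, 9, 15, 21)) {
  m <- caret::knn3(is_legendary ~ ., data = train_norm, k = k)
  p <- predict(m, newdata = valid_norm[, -c(5)], type = "class")
  cat("k =", k, "validation accuracy =",
      round(mean(p == valid_norm$is_legendary), 3), "\n")
}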
knn_pred_train <- predict(knn_model, newdata =
train_norm[,-c(5)],
type = "class")
head(knn_pred_train)
## [1] 0 0 0 0 0 0
## Levels: 0 1
confusionMatrix(knn_pred_train, as.factor(train_norm[, 5]),
positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 439 23
## 1 3 15
##
## Accuracy : 0.9458
## 95% CI : (0.9216, 0.9643)
## No Information Rate : 0.9208
## P-Value [Acc > NIR] : 0.0215752
##
## Kappa : 0.5108
##
## Mcnemar's Test P-Value : 0.0001944
##
## Sensitivity : 0.39474
## Specificity : 0.99321
## Pos Pred Value : 0.83333
## Neg Pred Value : 0.95022
## Prevalence : 0.07917
## Detection Rate : 0.03125
## Detection Prevalence : 0.03750
## Balanced Accuracy : 0.69397
##
## 'Positive' Class : 1
##
knn_pred_valid <- predict(knn_model,
newdata = valid_norm[, -c(5)],
type = "class")
head(knn_pred_valid)
## [1] 0 0 0 0 0 0
## Levels: 0 1
confusionMatrix(knn_pred_valid, as.factor(valid_norm[, 5]),
positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 284 24
## 1 5 8
##
## Accuracy : 0.9097
## 95% CI : (0.8728, 0.9387)
## No Information Rate : 0.9003
## P-Value [Acc > NIR] : 0.3278143
##
## Kappa : 0.3162
##
## Mcnemar's Test P-Value : 0.0008302
##
## Sensitivity : 0.25000
## Specificity : 0.98270
## Pos Pred Value : 0.61538
## Neg Pred Value : 0.92208
## Prevalence : 0.09969
## Detection Rate : 0.02492
## Detection Prevalence : 0.04050
## Balanced Accuracy : 0.61635
##
## 'Positive' Class : 1
##
ROC for kNN on the original (unweighted) training data.
ROSE::roc.curve(valid_norm$is_legendary, knn_pred_valid)
## Area under the curve (AUC): 0.616
Performance evaluation.
library(modelplotr)
scores_and_ntiles <- prepare_scores_and_ntiles(datasets = list("valid_norm"),
dataset_labels = list("Validation data"),
models = list("knn_model"),
model_labels = list("kNN"),
target_column = "is_legendary",
ntiles = 10)
## ... scoring caret model "knn_model" on dataset "valid_norm".
## Data preparation step 1 succeeded! Dataframe created.
plot_input <- plotting_scope(prepared_input = scores_and_ntiles, select_targetclass = "1")
## Data preparation step 2 succeeded! Dataframe created.
## "prepared_input" aggregated...
## Data preparation step 3 succeeded! Dataframe created.
##
## No comparison specified, default values are used.
##
## Single evaluation line will be plotted: Target value "1" plotted for dataset "Validation data" and model "kNN.
## "
## -> To compare models, specify: scope = "compare_models"
## -> To compare datasets, specify: scope = "compare_datasets"
## -> To compare target classes, specify: scope = "compare_targetclasses"
## -> To plot one line, do not specify scope or specify scope = "no_comparison".
Cumulative gains for kNN using the original training data.
plot_cumgains(data = plot_input, highlight_ntile = 2)
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##
## Plot annotation for plot: Cumulative gains
## - When we select 20% with the highest probability according to model kNN, this selection holds 78% of all 1 cases in Validation data.
##
##
Cumulative lift for kNN using the original training data.
plot_cumlift(data = plot_input, highlight_ntile = 2,
custom_line_colors = "#0022AA")
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##
## Plot annotation for plot: Cumulative lift
## - When we select 20% with the highest probability according to model kNN in Validation data, this selection for 1 cases is 3.9 times better than selecting without a model.
##
##
train_norm_2 <- train_df_rose
valid_norm_2 <- valid_df
names(train_norm_2)
## [1] "sp_attack" "sp_defence" "speed" "type1" "is_legendary"
library(caret)
norm_values_2 <- preProcess(train_df[, -c(4,5)],
method = c("center",
"scale"))
# apply the normalisation to the ROSE-balanced training data,
# not to the original train_df, so predictors stay paired with their labels
train_norm_2[, -c(4,5)] <- predict(norm_values_2,
 train_df_rose[, -c(4,5)])
head(train_norm_2)
## sp_attack sp_defence speed type1 is_legendary
## 1 -0.5003024 -0.176717360 -0.7288216 water 0
## 2 0.5765817 0.076597627 1.4176106 fighting 0
## 3 0.7304222 -0.357656636 -0.3881181 bug 0
## 4 -0.5003024 0.366100469 1.3153996 water 0
## 5 -0.9618241 -0.357656636 -0.8310327 grass 0
## 6 0.5765817 0.004221916 0.2251483 normal 0
valid_norm_2[, -c(4,5)] <- predict(norm_values_2,
valid_df[, -c(4,5)])
head(valid_norm_2)
## sp_attack sp_defence speed type1 is_legendary
## 2 0.2689005 0.3661005 -0.21776634 grass 0
## 3 1.5611613 1.8136147 0.46364073 grass 0
## 4 -0.3464618 -0.7195352 -0.04741458 fire 0
## 5 0.2689005 -0.1767174 0.46364073 fire 0
## 12 0.5765817 0.3661005 0.12293719 bug 0
## 13 -1.5771864 -1.8051708 -0.55846988 bug 0
knn_model_2 <- caret::knn3(is_legendary ~ ., data = train_norm_2, k = 15)
knn_model_2
## 15-nearest neighbor model
## Training set outcome distribution:
##
## 0 1
## 231 249
knn_pred_train_2 <- predict(knn_model_2, newdata =
train_norm_2[,-c(5)],
type = "class")
head(knn_pred_train_2)
## [1] 1 0 0 0 0 0
## Levels: 0 1
confusionMatrix(knn_pred_train_2, as.factor(train_norm_2[, 5]),
positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 184 100
## 1 47 149
##
## Accuracy : 0.6938
## 95% CI : (0.6504, 0.7347)
## No Information Rate : 0.5188
## P-Value [Acc > NIR] : 4.877e-15
##
## Kappa : 0.3917
##
## Mcnemar's Test P-Value : 1.796e-05
##
## Sensitivity : 0.5984
## Specificity : 0.7965
## Pos Pred Value : 0.7602
## Neg Pred Value : 0.6479
## Prevalence : 0.5188
## Detection Rate : 0.3104
## Detection Prevalence : 0.4083
## Balanced Accuracy : 0.6975
##
## 'Positive' Class : 1
##
knn_pred_valid_2 <- predict(knn_model_2,
newdata = valid_norm_2[, -c(5)],
type = "class")
head(knn_pred_valid_2)
## [1] 0 0 1 1 0 1
## Levels: 0 1
confusionMatrix(knn_pred_valid_2, as.factor(valid_norm_2[, 5]),
positive = "1")
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 205 13
## 1 84 19
##
## Accuracy : 0.6978
## 95% CI : (0.6444, 0.7476)
## No Information Rate : 0.9003
## P-Value [Acc > NIR] : 1
##
## Kappa : 0.1526
##
## Mcnemar's Test P-Value : 1.182e-12
##
## Sensitivity : 0.59375
## Specificity : 0.70934
## Pos Pred Value : 0.18447
## Neg Pred Value : 0.94037
## Prevalence : 0.09969
## Detection Rate : 0.05919
## Detection Prevalence : 0.32087
## Balanced Accuracy : 0.65155
##
## 'Positive' Class : 1
##
ROC for kNN on weighted data.
ROSE::roc.curve(valid_norm_2$is_legendary, knn_pred_valid_2)
## Area under the curve (AUC): 0.652
Performance evaluation.
library(modelplotr)
scores_and_ntiles <- prepare_scores_and_ntiles(datasets = list("valid_norm_2"),
dataset_labels = list("Validation data"),
models = list("knn_model_2"),
model_labels = list("kNN with balanced data"),
target_column = "is_legendary",
ntiles = 10)
## ... scoring caret model "knn_model_2" on dataset "valid_norm_2".
## Data preparation step 1 succeeded! Dataframe created.
plot_input <- plotting_scope(prepared_input = scores_and_ntiles, select_targetclass = "1")
## Data preparation step 2 succeeded! Dataframe created.
## "prepared_input" aggregated...
## Data preparation step 3 succeeded! Dataframe created.
##
## No comparison specified, default values are used.
##
## Single evaluation line will be plotted: Target value "1" plotted for dataset "Validation data" and model "kNN with balanced data.
## "
## -> To compare models, specify: scope = "compare_models"
## -> To compare datasets, specify: scope = "compare_datasets"
## -> To compare target classes, specify: scope = "compare_targetclasses"
## -> To plot one line, do not specify scope or specify scope = "no_comparison".
Cumulative gains for kNN using weighted training data.
plot_cumgains(data = plot_input, highlight_ntile = 2)
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##
## Plot annotation for plot: Cumulative gains
## - When we select 20% with the highest probability according to model kNN with balanced data, this selection holds 44% of all 1 cases in Validation data.
##
##
Cumulative lift for kNN using weighted training data.
plot_cumlift(data = plot_input, highlight_ntile = 2,
custom_line_colors = "#0022AA")
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##
## Plot annotation for plot: Cumulative lift
## - When we select 20% with the highest probability according to model kNN with balanced data in Validation data, this selection for 1 cases is 2.2 times better than selecting without a model.
##
##
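To wrap up, the validation AUCs of the five models can be collected side by side. A sketch (plotit = FALSE suppresses the ROC plots):
# gather the validation AUCs reported above into one table
data.frame(
  model = c("Tree", "Tree (ROSE)", "Random forest (ROSE)",
            "kNN", "kNN (ROSE)"),
  auc = c(ROSE::roc.curve(valid_df$is_legendary, class_tr_1_valid_predict,
                          plotit = FALSE)$auc,
          ROSE::roc.curve(valid_df$is_legendary, class_tr_2_valid_predict,
                          plotit = FALSE)$auc,
          ROSE::roc.curve(valid_df$is_legendary, class_tr_rf_pred,
                          plotit = FALSE)$auc,
          ROSE::roc.curve(valid_norm$is_legendary, knn_pred_valid,
                          plotit = FALSE)$auc,
          ROSE::roc.curve(valid_norm_2$is_legendary, knn_pred_valid_2,
                          plotit = FALSE)$auc))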