Directions

Every rose has its thorn

Just like every night has its dawn

And every cowboy sings a sad, sad song

Every rose has its thorn

Data for demo

Back to the spellbook

1. Load data

# Read the raw Pokemon table; the CSV's first row holds the column names.
pokemon <- read.csv(file = "pokemon_2.csv", header = TRUE)

Explore the data.

# First six rows of the raw data.
head(x = pokemon, n = 6L)
##                     abilities against_bug against_dark against_dragon
## 1 ['Overgrow', 'Chlorophyll']        1.00            1              1
## 2 ['Overgrow', 'Chlorophyll']        1.00            1              1
## 3 ['Overgrow', 'Chlorophyll']        1.00            1              1
## 4    ['Blaze', 'Solar Power']        0.50            1              1
## 5    ['Blaze', 'Solar Power']        0.50            1              1
## 6    ['Blaze', 'Solar Power']        0.25            1              1
##   against_electric against_fairy against_fight against_fire against_flying
## 1              0.5           0.5           0.5          2.0              2
## 2              0.5           0.5           0.5          2.0              2
## 3              0.5           0.5           0.5          2.0              2
## 4              1.0           0.5           1.0          0.5              1
## 5              1.0           0.5           1.0          0.5              1
## 6              2.0           0.5           0.5          0.5              1
##   against_ghost against_grass against_ground against_ice against_normal
## 1             1          0.25              1         2.0              1
## 2             1          0.25              1         2.0              1
## 3             1          0.25              1         2.0              1
## 4             1          0.50              2         0.5              1
## 5             1          0.50              2         0.5              1
## 6             1          0.25              0         1.0              1
##   against_poison against_psychic against_rock against_steel against_water
## 1              1               2            1           1.0           0.5
## 2              1               2            1           1.0           0.5
## 3              1               2            1           1.0           0.5
## 4              1               1            2           0.5           2.0
## 5              1               1            2           0.5           2.0
## 6              1               1            4           0.5           2.0
##   attack base_egg_steps base_happiness base_total capture_rate
## 1     49           5120             70        318           45
## 2     62           5120             70        405           45
## 3    100           5120             70        625           45
## 4     52           5120             70        309           45
## 5     64           5120             70        405           45
## 6    104           5120             70        634           45
##           classfication defence experience_growth height_m hp
## 1   Seed Pok\xed\xa9mon      49           1059860      0.7 45
## 2   Seed Pok\xed\xa9mon      63           1059860      1.0 60
## 3   Seed Pok\xed\xa9mon     123           1059860      2.0 80
## 4 Lizard Pok\xed\xa9mon      43           1059860      0.6 39
## 5  Flame Pok\xed\xa9mon      58           1059860      1.1 58
## 6  Flame Pok\xed\xa9mon      78           1059860      1.7 78
##                          japanese_name       name percentage_male
## 1 Fushigidane܀\xb4܉\x87܉\xac܀\xf3܀\x8d  Bulbasaur            88.1
## 2        Fushigisou܀\xb4܉\x87܉\xac܉_܉_    Ivysaur            88.1
## 3    Fushigibana܀\xb4܉\x87܉\xac܀\x90܀_   Venusaur            88.1
## 4            Hitokage܀\xcd܀\x9a܉\x82܉_ Charmander            88.1
## 5                Lizardo܀\xc8܉_܀_܀\x8a Charmeleon            88.1
## 6             Lizardon܀\xc8܉_܀_܀\x8a܀_  Charizard            88.1
##   pokedex_number sp_attack sp_defence speed type1  type2 weight_kg generation
## 1              1        65         65    45 grass poison       6.9          1
## 2              2        80         80    60 grass poison      13.0          1
## 3              3       122        120    80 grass poison     100.0          1
## 4              4        60         50    65  fire              8.5          1
## 5              5        80         65    80  fire             19.0          1
## 6              6       159        115   100  fire flying      90.5          1
##   is_legendary train
## 1            0     1
## 2            0     1
## 3            0     0
## 4            0     0
## 5            0     0
## 6            0     0
# Column types plus a preview of each column.
str(object = pokemon)
## 'data.frame':    801 obs. of  42 variables:
##  $ abilities        : chr  "['Overgrow', 'Chlorophyll']" "['Overgrow', 'Chlorophyll']" "['Overgrow', 'Chlorophyll']" "['Blaze', 'Solar Power']" ...
##  $ against_bug      : num  1 1 1 0.5 0.5 0.25 1 1 1 1 ...
##  $ against_dark     : num  1 1 1 1 1 1 1 1 1 1 ...
##  $ against_dragon   : num  1 1 1 1 1 1 1 1 1 1 ...
##  $ against_electric : num  0.5 0.5 0.5 1 1 2 2 2 2 1 ...
##  $ against_fairy    : num  0.5 0.5 0.5 0.5 0.5 0.5 1 1 1 1 ...
##  $ against_fight    : num  0.5 0.5 0.5 1 1 0.5 1 1 1 0.5 ...
##  $ against_fire     : num  2 2 2 0.5 0.5 0.5 0.5 0.5 0.5 2 ...
##  $ against_flying   : num  2 2 2 1 1 1 1 1 1 2 ...
##  $ against_ghost    : num  1 1 1 1 1 1 1 1 1 1 ...
##  $ against_grass    : num  0.25 0.25 0.25 0.5 0.5 0.25 2 2 2 0.5 ...
##  $ against_ground   : num  1 1 1 2 2 0 1 1 1 0.5 ...
##  $ against_ice      : num  2 2 2 0.5 0.5 1 0.5 0.5 0.5 1 ...
##  $ against_normal   : num  1 1 1 1 1 1 1 1 1 1 ...
##  $ against_poison   : num  1 1 1 1 1 1 1 1 1 1 ...
##  $ against_psychic  : num  2 2 2 1 1 1 1 1 1 1 ...
##  $ against_rock     : num  1 1 1 2 2 4 1 1 1 2 ...
##  $ against_steel    : num  1 1 1 0.5 0.5 0.5 0.5 0.5 0.5 1 ...
##  $ against_water    : num  0.5 0.5 0.5 2 2 2 0.5 0.5 0.5 1 ...
##  $ attack           : int  49 62 100 52 64 104 48 63 103 30 ...
##  $ base_egg_steps   : int  5120 5120 5120 5120 5120 5120 5120 5120 5120 3840 ...
##  $ base_happiness   : int  70 70 70 70 70 70 70 70 70 70 ...
##  $ base_total       : int  318 405 625 309 405 634 314 405 630 195 ...
##  $ capture_rate     : chr  "45" "45" "45" "45" ...
##  $ classfication    : chr  "Seed Pok\xed\xa9mon" "Seed Pok\xed\xa9mon" "Seed Pok\xed\xa9mon" "Lizard Pok\xed\xa9mon" ...
##  $ defence          : int  49 63 123 43 58 78 65 80 120 35 ...
##  $ experience_growth: int  1059860 1059860 1059860 1059860 1059860 1059860 1059860 1059860 1059860 1000000 ...
##  $ height_m         : num  0.7 1 2 0.6 1.1 1.7 0.5 1 1.6 0.3 ...
##  $ hp               : int  45 60 80 39 58 78 44 59 79 45 ...
##  $ japanese_name    : chr  "Fushigidane܀\xb4܉\x87܉\xac܀\xf3܀\x8d" "Fushigisou܀\xb4܉\x87܉\xac܉_܉_" "Fushigibana܀\xb4܉\x87܉\xac܀\x90܀_" "Hitokage܀\xcd܀\x9a܉\x82܉_" ...
##  $ name             : chr  "Bulbasaur" "Ivysaur" "Venusaur" "Charmander" ...
##  $ percentage_male  : num  88.1 88.1 88.1 88.1 88.1 88.1 88.1 88.1 88.1 50 ...
##  $ pokedex_number   : int  1 2 3 4 5 6 7 8 9 10 ...
##  $ sp_attack        : int  65 80 122 60 80 159 50 65 135 20 ...
##  $ sp_defence       : int  65 80 120 50 65 115 64 80 115 20 ...
##  $ speed            : int  45 60 80 65 80 100 43 58 78 45 ...
##  $ type1            : chr  "grass" "grass" "grass" "fire" ...
##  $ type2            : chr  "poison" "poison" "poison" "" ...
##  $ weight_kg        : num  6.9 13 100 8.5 19 90.5 9 22.5 85.5 2.9 ...
##  $ generation       : int  1 1 1 1 1 1 1 1 1 1 ...
##  $ is_legendary     : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ train            : int  1 1 0 0 0 0 0 0 0 0 ...
# Per-column summary statistics (including NA counts).
summary(object = pokemon)
##   abilities          against_bug      against_dark   against_dragon  
##  Length:801         Min.   :0.2500   Min.   :0.250   Min.   :0.0000  
##  Class :character   1st Qu.:0.5000   1st Qu.:1.000   1st Qu.:1.0000  
##  Mode  :character   Median :1.0000   Median :1.000   Median :1.0000  
##                     Mean   :0.9963   Mean   :1.057   Mean   :0.9688  
##                     3rd Qu.:1.0000   3rd Qu.:1.000   3rd Qu.:1.0000  
##                     Max.   :4.0000   Max.   :4.000   Max.   :2.0000  
##                                                                      
##  against_electric against_fairy   against_fight    against_fire  
##  Min.   :0.000    Min.   :0.250   Min.   :0.000   Min.   :0.250  
##  1st Qu.:0.500    1st Qu.:1.000   1st Qu.:0.500   1st Qu.:0.500  
##  Median :1.000    Median :1.000   Median :1.000   Median :1.000  
##  Mean   :1.074    Mean   :1.069   Mean   :1.066   Mean   :1.135  
##  3rd Qu.:1.000    3rd Qu.:1.000   3rd Qu.:1.000   3rd Qu.:2.000  
##  Max.   :4.000    Max.   :4.000   Max.   :4.000   Max.   :4.000  
##                                                                  
##  against_flying  against_ghost   against_grass   against_ground 
##  Min.   :0.250   Min.   :0.000   Min.   :0.250   Min.   :0.000  
##  1st Qu.:1.000   1st Qu.:1.000   1st Qu.:0.500   1st Qu.:1.000  
##  Median :1.000   Median :1.000   Median :1.000   Median :1.000  
##  Mean   :1.193   Mean   :0.985   Mean   :1.034   Mean   :1.098  
##  3rd Qu.:1.000   3rd Qu.:1.000   3rd Qu.:1.000   3rd Qu.:1.000  
##  Max.   :4.000   Max.   :4.000   Max.   :4.000   Max.   :4.000  
##                                                                 
##   against_ice    against_normal  against_poison   against_psychic
##  Min.   :0.250   Min.   :0.000   Min.   :0.0000   Min.   :0.000  
##  1st Qu.:0.500   1st Qu.:1.000   1st Qu.:0.5000   1st Qu.:1.000  
##  Median :1.000   Median :1.000   Median :1.0000   Median :1.000  
##  Mean   :1.208   Mean   :0.887   Mean   :0.9753   Mean   :1.005  
##  3rd Qu.:2.000   3rd Qu.:1.000   3rd Qu.:1.0000   3rd Qu.:1.000  
##  Max.   :4.000   Max.   :1.000   Max.   :4.0000   Max.   :4.000  
##                                                                  
##   against_rock  against_steel    against_water       attack      
##  Min.   :0.25   Min.   :0.2500   Min.   :0.250   Min.   :  5.00  
##  1st Qu.:1.00   1st Qu.:0.5000   1st Qu.:0.500   1st Qu.: 55.00  
##  Median :1.00   Median :1.0000   Median :1.000   Median : 75.00  
##  Mean   :1.25   Mean   :0.9835   Mean   :1.058   Mean   : 77.86  
##  3rd Qu.:2.00   3rd Qu.:1.0000   3rd Qu.:1.000   3rd Qu.:100.00  
##  Max.   :4.00   Max.   :4.0000   Max.   :4.000   Max.   :185.00  
##                                                                  
##  base_egg_steps  base_happiness     base_total    capture_rate      
##  Min.   : 1280   Min.   :  0.00   Min.   :180.0   Length:801        
##  1st Qu.: 5120   1st Qu.: 70.00   1st Qu.:320.0   Class :character  
##  Median : 5120   Median : 70.00   Median :435.0   Mode  :character  
##  Mean   : 7191   Mean   : 65.36   Mean   :428.4                     
##  3rd Qu.: 6400   3rd Qu.: 70.00   3rd Qu.:505.0                     
##  Max.   :30720   Max.   :140.00   Max.   :780.0                     
##                                                                     
##  classfication         defence       experience_growth    height_m     
##  Length:801         Min.   :  5.00   Min.   : 600000   Min.   : 0.100  
##  Class :character   1st Qu.: 50.00   1st Qu.:1000000   1st Qu.: 0.600  
##  Mode  :character   Median : 70.00   Median :1000000   Median : 1.000  
##                     Mean   : 73.01   Mean   :1054996   Mean   : 1.164  
##                     3rd Qu.: 90.00   3rd Qu.:1059860   3rd Qu.: 1.500  
##                     Max.   :230.00   Max.   :1640000   Max.   :14.500  
##                                                        NA's   :20      
##        hp         japanese_name          name           percentage_male 
##  Min.   :  1.00   Length:801         Length:801         Min.   :  0.00  
##  1st Qu.: 50.00   Class :character   Class :character   1st Qu.: 50.00  
##  Median : 65.00   Mode  :character   Mode  :character   Median : 50.00  
##  Mean   : 68.96                                         Mean   : 55.16  
##  3rd Qu.: 80.00                                         3rd Qu.: 50.00  
##  Max.   :255.00                                         Max.   :100.00  
##                                                         NA's   :98      
##  pokedex_number   sp_attack        sp_defence         speed       
##  Min.   :  1    Min.   : 10.00   Min.   : 20.00   Min.   :  5.00  
##  1st Qu.:201    1st Qu.: 45.00   1st Qu.: 50.00   1st Qu.: 45.00  
##  Median :401    Median : 65.00   Median : 66.00   Median : 65.00  
##  Mean   :401    Mean   : 71.31   Mean   : 70.91   Mean   : 66.33  
##  3rd Qu.:601    3rd Qu.: 91.00   3rd Qu.: 90.00   3rd Qu.: 85.00  
##  Max.   :801    Max.   :194.00   Max.   :230.00   Max.   :180.00  
##                                                                   
##     type1              type2             weight_kg        generation  
##  Length:801         Length:801         Min.   :  0.10   Min.   :1.00  
##  Class :character   Class :character   1st Qu.:  9.00   1st Qu.:2.00  
##  Mode  :character   Mode  :character   Median : 27.30   Median :4.00  
##                                        Mean   : 61.38   Mean   :3.69  
##                                        3rd Qu.: 64.80   3rd Qu.:5.00  
##                                        Max.   :999.90   Max.   :7.00  
##                                        NA's   :20                     
##   is_legendary         train        
##  Min.   :0.00000   Min.   :0.00000  
##  1st Qu.:0.00000   1st Qu.:0.00000  
##  Median :0.00000   Median :0.00000  
##  Mean   :0.08739   Mean   :0.09988  
##  3rd Qu.:0.00000   3rd Qu.:0.00000  
##  Max.   :1.00000   Max.   :1.00000  
## 
# Number of rows in the raw data.
dim(pokemon)[1L]
## [1] 801
# Class balance of the target.
table(pokemon[["is_legendary"]])
## 
##   0   1 
## 731  70

1.1 Keep only the selected variables

# All column names, with their positions.
colnames(pokemon)
##  [1] "abilities"         "against_bug"       "against_dark"     
##  [4] "against_dragon"    "against_electric"  "against_fairy"    
##  [7] "against_fight"     "against_fire"      "against_flying"   
## [10] "against_ghost"     "against_grass"     "against_ground"   
## [13] "against_ice"       "against_normal"    "against_poison"   
## [16] "against_psychic"   "against_rock"      "against_steel"    
## [19] "against_water"     "attack"            "base_egg_steps"   
## [22] "base_happiness"    "base_total"        "capture_rate"     
## [25] "classfication"     "defence"           "experience_growth"
## [28] "height_m"          "hp"                "japanese_name"    
## [31] "name"              "percentage_male"   "pokedex_number"   
## [34] "sp_attack"         "sp_defence"        "speed"            
## [37] "type1"             "type2"             "weight_kg"        
## [40] "generation"        "is_legendary"      "train"
# Keep the predictors used by the models below plus the target. Columns are
# selected by name rather than by position (previously c(34:37, 41)), so the
# selection cannot silently change if the CSV gains, loses or reorders
# columns.
pokemon_df <- pokemon[, c("sp_attack", "sp_defence", "speed", "type1",
                          "is_legendary")]
head(pokemon_df)
##   sp_attack sp_defence speed type1 is_legendary
## 1        65         65    45 grass            0
## 2        80         80    60 grass            0
## 3       122        120    80 grass            0
## 4        60         50    65  fire            0
## 5        80         65    80  fire            0
## 6       159        115   100  fire            0
# Row count must be unchanged by the column selection.
dim(pokemon_df)[1L]
## [1] 801
# Confirm which columns were kept.
colnames(pokemon_df)
## [1] "sp_attack"    "sp_defence"   "speed"        "type1"        "is_legendary"
# Class balance after the column selection.
table(pokemon_df[["is_legendary"]])
## 
##   0   1 
## 731  70

2. Training validation split

Our favourite seed :-)

# Fix the RNG state so the split below is reproducible.
set.seed(seed = 666)

Training-validation split.

Create the indices for the split. This randomly samples 60% of the row indices for training; the remaining rows form the validation set.

# Sample 60% of the row indices (rounded down) for training; every remaining
# row is used for validation. seq_len() is used instead of 1:n so that an
# empty data frame yields empty index sets rather than c(1, 0).
n_rows <- nrow(pokemon_df)
train_index <- sample(seq_len(n_rows), floor(0.6 * n_rows))
valid_index <- setdiff(seq_len(n_rows), train_index)

Using the indices, create the training and validation sets. This is simply subsetting the data frame by row.

# Row-subset the working data with the sampled indices (all columns kept).
valid_df <- pokemon_df[valid_index, , drop = FALSE]
train_df <- pokemon_df[train_index, , drop = FALSE]

It is a good habit to check after splitting.

dim(train_df)[1L]
## [1] 480
dim(valid_df)[1L]
## [1] 321
# Class balance in the training set.
table(train_df[["is_legendary"]])
## 
##   0   1 
## 442  38
colnames(train_df)
## [1] "sp_attack"    "sp_defence"   "speed"        "type1"        "is_legendary"

3. Classification tree

Build the tree on the training set. This step trains the model. Other classification algorithms would work too.

library(rpart)
# Classification tree on the original (imbalanced) training set; the tree
# depth is capped at 10 via rpart's control arguments.
class_tr_1 <- rpart(
  formula = is_legendary ~ sp_attack + sp_defence + speed + type1,
  data = train_df,
  method = "class",
  maxdepth = 10
)

Plot the tree. Try different settings to tweak the format.

library(rpart.plot)
rpart.plot(x = class_tr_1, type = 5)

3.1 Predict training set

# Hard class labels for the training rows.
class_tr_1_train_predict <- predict(object = class_tr_1, newdata = train_df,
                                    type = "class")

summary(object = class_tr_1_train_predict)
##   0   1 
## 461  19
library(caret)
## Loading required package: ggplot2
## Loading required package: lattice
# confusionMatrix() compares two factors; "1" is declared the positive class.
train_df$is_legendary <- as.factor(train_df$is_legendary)
class_tr_1_train_predict <- as.factor(class_tr_1_train_predict)
confusionMatrix(data = class_tr_1_train_predict,
                reference = train_df$is_legendary,
                positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 439  22
##          1   3  16
##                                          
##                Accuracy : 0.9479         
##                  95% CI : (0.9241, 0.966)
##     No Information Rate : 0.9208         
##     P-Value [Acc > NIR] : 0.0134929      
##                                          
##                   Kappa : 0.537          
##                                          
##  Mcnemar's Test P-Value : 0.0003182      
##                                          
##             Sensitivity : 0.42105        
##             Specificity : 0.99321        
##          Pos Pred Value : 0.84211        
##          Neg Pred Value : 0.95228        
##              Prevalence : 0.07917        
##          Detection Rate : 0.03333        
##    Detection Prevalence : 0.03958        
##       Balanced Accuracy : 0.70713        
##                                          
##        'Positive' Class : 1              
## 

3.2 Predict validation set

# Hard class labels for the validation rows.
class_tr_1_valid_predict <- predict(object = class_tr_1, newdata = valid_df,
                                    type = "class")

summary(object = class_tr_1_valid_predict)
##   0   1 
## 304  17

Convert to factor for the confusion matrix.

# Both arguments of confusionMatrix() must be factors.
valid_df$is_legendary <- as.factor(valid_df$is_legendary)
class_tr_1_valid_predict <- as.factor(class_tr_1_valid_predict)

Do the confusion matrix.

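It is important to specify the positive class (`positive = "1"`, i.e. legendary). Otherwise caret treats the first factor level (`"0"`) as positive, and sensitivity, specificity and the predictive values are reported for the wrong class.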

confusionMatrix(data = class_tr_1_valid_predict,
                reference = valid_df$is_legendary,
                positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 281  23
##          1   8   9
##                                           
##                Accuracy : 0.9034          
##                  95% CI : (0.8657, 0.9334)
##     No Information Rate : 0.9003          
##     P-Value [Acc > NIR] : 0.47276         
##                                           
##                   Kappa : 0.3203          
##                                           
##  Mcnemar's Test P-Value : 0.01192         
##                                           
##             Sensitivity : 0.28125         
##             Specificity : 0.97232         
##          Pos Pred Value : 0.52941         
##          Neg Pred Value : 0.92434         
##              Prevalence : 0.09969         
##          Detection Rate : 0.02804         
##    Detection Prevalence : 0.05296         
##       Balanced Accuracy : 0.62678         
##                                           
##        'Positive' Class : 1               
## 

3.3 Model Evaluation

Evaluate the model.

Plot the ROC curve.

# Score the ROC curve with the predicted probability of class "1", not the
# hard class label. A 0/1 "score" yields a single operating point, so its AUC
# collapses to the balanced accuracy. The 0.627 printed below came from the
# earlier hard-label call and matches the Balanced Accuracy of the confusion
# matrix above; re-knit to refresh it.
ROSE::roc.curve(valid_df$is_legendary,
                predict(class_tr_1, valid_df, type = "prob")[, "1"])

## Area under the curve (AUC): 0.627

Performance evaluation.

library(modelplotr)
## Package modelplotr loaded! Happy model plotting!
# Score the validation set with class_tr_1 and bin it into 100 ntiles for
# the modelplotr evaluation plots.
scores_and_ntiles <- prepare_scores_and_ntiles(
  datasets = list("valid_df"),
  dataset_labels = list("Validation data"),
  models = list("class_tr_1"),
  model_labels = list("Classification Tree"),
  target_column = "is_legendary",
  ntiles = 100
)
## Warning: `select_()` was deprecated in dplyr 0.7.0.
## Please use `select()` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
## ... scoring caret model "class_tr_1" on dataset "valid_df".
## Data preparation step 1 succeeded! Dataframe created.
# Pin the evaluated target class to "1" explicitly (as the random-forest
# section does) instead of relying on plotting_scope()'s default choice,
# which happened to select "1" here.
plot_input <- plotting_scope(prepared_input = scores_and_ntiles,
                             select_targetclass = "1")
## Warning: `group_by_()` was deprecated in dplyr 0.7.0.
## Please use `group_by()` instead.
## See vignette('programming') for more help
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was generated.
## Data preparation step 2 succeeded! Dataframe created.
## "prepared_input" aggregated...
## Data preparation step 3 succeeded! Dataframe created.
## 
## No comparison specified, default values are used. 
## 
## Single evaluation line will be plotted: Target value "1" plotted for dataset "Validation data" and model "Classification Tree.
## "
## -> To compare models, specify: scope = "compare_models"
## -> To compare datasets, specify: scope = "compare_datasets"
## -> To compare target classes, specify: scope = "compare_targetclasses"
## -> To plot one line, do not specify scope or specify scope = "no_comparison".

Cumulative gains for decision tree.

plot_cumgains(data = plot_input,
              custom_line_colors = "#1F8723",
              highlight_ntile = 23)
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##  
## Plot annotation for plot: Cumulative gains
## - When we select 23% with the highest probability according to model Classification Tree, this selection holds 34% of all 1 cases in Validation data. 
##  
## 

Cumulative lift for decision tree.

plot_cumlift(data = plot_input,
             custom_line_colors = "#1F3387",
             highlight_ntile = 23)
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##  
## Plot annotation for plot: Cumulative lift
## - When we select 23% with the highest probability according to model Classification Tree in Validation data, this selection for 1 cases is 1.5 times better than selecting without a model. 
##  
## 

Response plot for decision tree.

plot_response(plot_input)

Cumulative response plot for decision tree.

plot_cumresponse(plot_input)

4. Weighted sampling

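Now, try the same model on a class-balanced training sample. The ROSE package (Random Over-Sampling Examples) generates synthetic observations so that both classes are roughly equally represented in the training data.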

library(ROSE)
## Loaded ROSE 0.0-4

Create the weighted training df.

Factorise categorical variables for ROSE to work.

colnames(train_df)
## [1] "sp_attack"    "sp_defence"   "speed"        "type1"        "is_legendary"
# ROSE needs categorical predictors stored as factors, not character vectors.
# The validation column reuses the training levels, so both sets share one
# factor encoding when the models are applied to valid_df with predict().
# A validation type absent from the training data would become NA rather
# than an unseen level.
train_df$type1 <- as.factor(train_df$type1)
valid_df$type1 <- factor(valid_df$type1, levels = levels(train_df$type1))
train_df_rose <- ROSE(is_legendary ~ sp_attack + sp_defence +
                        speed + type1,
                      data = train_df, seed = 666)$data

# Class balance after ROSE resampling.
table(train_df_rose[["is_legendary"]])
## 
##   0   1 
## 231 249

5. Weighted data decision tree

This is the same decision tree, except it is trained on the ROSE-balanced training data.

# Same tree specification as class_tr_1, fitted to the ROSE-balanced data.
class_tr_2 <- rpart(
  formula = is_legendary ~ sp_attack + sp_defence + speed + type1,
  data = train_df_rose,
  method = "class",
  maxdepth = 10
)

rpart.plot(x = class_tr_2, type = 5)

5.1 Predict training set

Compute the predictions using the model trained on the balanced training data.

# Hard class labels for the balanced training rows.
class_tr_2_train_predict <- predict(object = class_tr_2,
                                    newdata = train_df_rose,
                                    type = "class")

summary(object = class_tr_2_train_predict)
##   0   1 
## 236 244

Convert to factor for the confusion matrix. Then generate the confusion matrix.

# Factor conversion, then the confusion matrix with "1" as positive class.
train_df_rose$is_legendary <- as.factor(train_df_rose$is_legendary)
class_tr_2_train_predict <- as.factor(class_tr_2_train_predict)
confusionMatrix(data = class_tr_2_train_predict,
                reference = train_df_rose$is_legendary,
                positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 212  24
##          1  19 225
##                                           
##                Accuracy : 0.9104          
##                  95% CI : (0.8812, 0.9344)
##     No Information Rate : 0.5188          
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.8207          
##                                           
##  Mcnemar's Test P-Value : 0.5419          
##                                           
##             Sensitivity : 0.9036          
##             Specificity : 0.9177          
##          Pos Pred Value : 0.9221          
##          Neg Pred Value : 0.8983          
##              Prevalence : 0.5188          
##          Detection Rate : 0.4688          
##    Detection Prevalence : 0.5083          
##       Balanced Accuracy : 0.9107          
##                                           
##        'Positive' Class : 1               
## 

5.2 Predict validation set

Use the new model, trained on the balanced training data, to predict the (unchanged) validation set.

# Hard class labels for the (untouched) validation rows.
class_tr_2_valid_predict <- predict(object = class_tr_2, newdata = valid_df,
                                    type = "class")

summary(object = class_tr_2_valid_predict)
##   0   1 
## 249  72

Convert to factor for the confusion matrix.

# as.factor() leaves an existing factor unchanged, so repeating the
# conversion of valid_df$is_legendary is harmless.
valid_df$is_legendary <- as.factor(valid_df$is_legendary)
class_tr_2_valid_predict <- as.factor(class_tr_2_valid_predict)
confusionMatrix(data = class_tr_2_valid_predict,
                reference = valid_df$is_legendary,
                positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 238  11
##          1  51  21
##                                           
##                Accuracy : 0.8069          
##                  95% CI : (0.7594, 0.8486)
##     No Information Rate : 0.9003          
##     P-Value [Acc > NIR] : 1               
##                                           
##                   Kappa : 0.3084          
##                                           
##  Mcnemar's Test P-Value : 7.308e-07       
##                                           
##             Sensitivity : 0.65625         
##             Specificity : 0.82353         
##          Pos Pred Value : 0.29167         
##          Neg Pred Value : 0.95582         
##              Prevalence : 0.09969         
##          Detection Rate : 0.06542         
##    Detection Prevalence : 0.22430         
##       Balanced Accuracy : 0.73989         
##                                           
##        'Positive' Class : 1               
## 

5.3 Model evaluation

Now, we can check whether the model trained on the balanced training set performs better.

Plot the ROC curve for the decision tree trained on the balanced training data. The AUC is higher than for the first tree.

# Score the ROC curve with the predicted probability of class "1", not the
# hard class label. A 0/1 "score" yields a single operating point, so its AUC
# collapses to the balanced accuracy. The 0.740 printed below came from the
# earlier hard-label call and matches the Balanced Accuracy of the confusion
# matrix above; re-knit to refresh it.
ROSE::roc.curve(valid_df$is_legendary,
                predict(class_tr_2, valid_df, type = "prob")[, "1"])

## Area under the curve (AUC): 0.740

Performance.

Prepare the data for evaluation

library(modelplotr)


# Score the validation set with the balanced-data tree and bin it into 100
# ntiles for modelplotr.
scores_and_ntiles <- prepare_scores_and_ntiles(
  datasets = list("valid_df"),
  dataset_labels = list("Validation data"),
  models = list("class_tr_2"),
  model_labels = list("Classification Tree (Balanced)"),
  target_column = "is_legendary",
  ntiles = 100
)
## ... scoring caret model "class_tr_2" on dataset "valid_df".
## Data preparation step 1 succeeded! Dataframe created.
# Pin the evaluated target class to "1" explicitly (as the random-forest
# section does) instead of relying on plotting_scope()'s default choice,
# which happened to select "1" here.
plot_input <- plotting_scope(prepared_input = scores_and_ntiles,
                             select_targetclass = "1")
## Data preparation step 2 succeeded! Dataframe created.
## "prepared_input" aggregated...
## Data preparation step 3 succeeded! Dataframe created.
## 
## No comparison specified, default values are used. 
## 
## Single evaluation line will be plotted: Target value "1" plotted for dataset "Validation data" and model "Classification Tree (Balanced).
## "
## -> To compare models, specify: scope = "compare_models"
## -> To compare datasets, specify: scope = "compare_datasets"
## -> To compare target classes, specify: scope = "compare_targetclasses"
## -> To plot one line, do not specify scope or specify scope = "no_comparison".

Cumulative gains for decision tree using weighted training data.

plot_cumgains(data = plot_input,
              custom_line_colors = "#1F8723",
              highlight_ntile = 23)
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##  
## Plot annotation for plot: Cumulative gains
## - When we select 23% with the highest probability according to model Classification Tree (Balanced), this selection holds 69% of all 1 cases in Validation data. 
##  
## 

Cumulative lift for decision tree using weighted training data.

plot_cumlift(data = plot_input,
             custom_line_colors = "#1F3387",
             highlight_ntile = 23)
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##  
## Plot annotation for plot: Cumulative lift
## - When we select 23% with the highest probability according to model Classification Tree (Balanced) in Validation data, this selection for 1 cases is 3.0 times better than selecting without a model. 
##  
## 

6. Random Forest

6.1 Improved model: an ensemble of trees

library(randomForest)
## randomForest 4.7-1.1
## Type rfNews() to see new features/changes/bug fixes.
## 
## Attaching package: 'randomForest'
## The following object is masked from 'package:ggplot2':
## 
##     margin
# Random forest of 300 trees on the ROSE-balanced training data, with
# variable importance recorded.
# NOTE(review): no set.seed() directly precedes this call, so the fitted
# forest depends on the RNG state left by the earlier chunks.
class_tr_rf <- randomForest(
  is_legendary ~ sp_attack + sp_defence + speed + type1,
  data = train_df_rose,
  ntree = 300,
  nodesize = 11,
  importance = TRUE
)

6.2 Model evaluation

The confusion matrix.

class_tr_rf_pred <- predict(object = class_tr_rf, newdata = train_df_rose)

confusionMatrix(data = class_tr_rf_pred,
                reference = train_df_rose$is_legendary,
                positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 212   4
##          1  19 245
##                                          
##                Accuracy : 0.9521         
##                  95% CI : (0.929, 0.9694)
##     No Information Rate : 0.5188         
##     P-Value [Acc > NIR] : < 2.2e-16      
##                                          
##                   Kappa : 0.9038         
##                                          
##  Mcnemar's Test P-Value : 0.003509       
##                                          
##             Sensitivity : 0.9839         
##             Specificity : 0.9177         
##          Pos Pred Value : 0.9280         
##          Neg Pred Value : 0.9815         
##              Prevalence : 0.5188         
##          Detection Rate : 0.5104         
##    Detection Prevalence : 0.5500         
##       Balanced Accuracy : 0.9508         
##                                          
##        'Positive' Class : 1              
## 
class_tr_rf_pred <- predict(object = class_tr_rf, newdata = valid_df)

confusionMatrix(data = class_tr_rf_pred,
                reference = valid_df$is_legendary,
                positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 233   7
##          1  56  25
##                                          
##                Accuracy : 0.8037         
##                  95% CI : (0.756, 0.8458)
##     No Information Rate : 0.9003         
##     P-Value [Acc > NIR] : 1              
##                                          
##                   Kappa : 0.3495         
##                                          
##  Mcnemar's Test P-Value : 1.472e-09      
##                                          
##             Sensitivity : 0.78125        
##             Specificity : 0.80623        
##          Pos Pred Value : 0.30864        
##          Neg Pred Value : 0.97083        
##              Prevalence : 0.09969        
##          Detection Rate : 0.07788        
##    Detection Prevalence : 0.25234        
##       Balanced Accuracy : 0.79374        
##                                          
##        'Positive' Class : 1              
## 

ROC curve for the random forest on the validation set.

# A ROC curve needs a continuous score, so pass the forest's predicted
# probability of class "1". Hard class labels collapse the curve to a
# single operating point, whose "AUC" is just the balanced accuracy.
# Re-render to refresh the AUC printed below.
ROSE::roc.curve(
  valid_df$is_legendary,
  predict(class_tr_rf, newdata = valid_df, type = "prob")[, "1"]
)

## Area under the curve (AUC): 0.794

Performance evaluation.

library(modelplotr)

# Score the validation set with the random forest and bin the scores
# into 100 ntiles for the gains/lift plots.
scores_and_ntiles <- prepare_scores_and_ntiles(
  datasets       = list("valid_df"),
  dataset_labels = list("Validation data"),
  models         = list("class_tr_rf"),
  model_labels   = list("Random Forest"),
  target_column  = "is_legendary",
  ntiles         = 100
)
## ... scoring caret model "class_tr_rf" on dataset "valid_df".
## Data preparation step 1 succeeded! Dataframe created.
# Keep only the positive target class ("1") for plotting.
plot_input <- plotting_scope(
  prepared_input = scores_and_ntiles,
  select_targetclass = "1"
)
## Data preparation step 2 succeeded! Dataframe created.
## "prepared_input" aggregated...
## Data preparation step 3 succeeded! Dataframe created.
## 
## No comparison specified, default values are used. 
## 
## Single evaluation line will be plotted: Target value "1" plotted for dataset "Validation data" and model "Random Forest.
## "
## -> To compare models, specify: scope = "compare_models"
## -> To compare datasets, specify: scope = "compare_datasets"
## -> To compare target classes, specify: scope = "compare_targetclasses"
## -> To plot one line, do not specify scope or specify scope = "no_comparison".

Cumulative gains for the random forest trained on the ROSE-balanced data.

# Cumulative gains, annotated at ntile 23 of 100.
plot_cumgains(
  data = plot_input,
  highlight_ntile = 23,
  custom_line_colors = "#DA5A0C"
)
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##  
## Plot annotation for plot: Cumulative gains
## - When we select 23% with the highest probability according to model Random Forest, this selection holds 75% of all 1 cases in Validation data. 
##  
## 

Cumulative lift for the random forest trained on the ROSE-balanced data.

# Cumulative lift, annotated at ntile 23 of 100.
plot_cumlift(
  data = plot_input,
  highlight_ntile = 23,
  custom_line_colors = "#BEDA0C"
)
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##  
## Plot annotation for plot: Cumulative lift
## - When we select 23% with the highest probability according to model Random Forest in Validation data, this selection for 1 cases is 3.3 times better than selecting without a model. 
##  
## 

7. kNN

# Working copies whose numeric predictors are standardized below.
valid_norm <- valid_df
train_norm <- train_df
colnames(train_norm)
## [1] "sp_attack"    "sp_defence"   "speed"        "type1"        "is_legendary"
library(caret)
# Centring/scaling parameters are learned from the training predictors
# only; columns 4 (type1) and 5 (is_legendary) are excluded and untouched.
norm_values <- preProcess(
  train_df[, -c(4, 5)],
  method = c("center", "scale")
)
train_norm[, -c(4, 5)] <- predict(norm_values, newdata = train_df[, -c(4, 5)])

# Peek at the first rows of the standardized training data.
head(train_norm, n = 6L)
##      sp_attack   sp_defence      speed   type1 is_legendary
## 574 -0.5003024 -0.176717360 -0.7288216 psychic            0
## 638  0.5765817  0.076597627  1.4176106   steel            1
## 608  0.7304222 -0.357656636 -0.3881181   ghost            0
## 123 -0.5003024  0.366100469  1.3153996     bug            0
## 540 -0.9618241 -0.357656636 -0.8310327     bug            0
## 654  0.5765817  0.004221916  0.2251483    fire            0
# Reuse the training-set centring/scaling on the validation predictors.
valid_norm[, -c(4, 5)] <- predict(norm_values, newdata = valid_df[, -c(4, 5)])

# Peek at the first rows of the standardized validation data.
head(valid_norm, n = 6L)
##     sp_attack sp_defence       speed type1 is_legendary
## 2   0.2689005  0.3661005 -0.21776634 grass            0
## 3   1.5611613  1.8136147  0.46364073 grass            0
## 4  -0.3464618 -0.7195352 -0.04741458  fire            0
## 5   0.2689005 -0.1767174  0.46364073  fire            0
## 12  0.5765817  0.3661005  0.12293719   bug            0
## 13 -1.5771864 -1.8051708 -0.55846988   bug            0
# 15-nearest-neighbour classifier; `.` uses every non-response column.
knn_model <- caret::knn3(is_legendary ~ ., data = train_norm, k = 15)
print(knn_model)
## 15-nearest neighbor model
## Training set outcome distribution:
## 
##   0   1 
## 442  38

7.1 Predict training set

# Class predictions for the training rows (response column 5 dropped).
knn_pred_train <- predict(
  knn_model,
  newdata = train_norm[, -5],
  type = "class"
)
head(knn_pred_train)
## [1] 0 0 0 0 0 0
## Levels: 0 1
confusionMatrix(
  data = knn_pred_train,
  reference = as.factor(train_norm[[5]]),
  positive = "1"
)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 439  23
##          1   3  15
##                                           
##                Accuracy : 0.9458          
##                  95% CI : (0.9216, 0.9643)
##     No Information Rate : 0.9208          
##     P-Value [Acc > NIR] : 0.0215752       
##                                           
##                   Kappa : 0.5108          
##                                           
##  Mcnemar's Test P-Value : 0.0001944       
##                                           
##             Sensitivity : 0.39474         
##             Specificity : 0.99321         
##          Pos Pred Value : 0.83333         
##          Neg Pred Value : 0.95022         
##              Prevalence : 0.07917         
##          Detection Rate : 0.03125         
##    Detection Prevalence : 0.03750         
##       Balanced Accuracy : 0.69397         
##                                           
##        'Positive' Class : 1               
## 

7.2 Predict validation set

# Class predictions for the validation rows (response column 5 dropped).
knn_pred_valid <- predict(
  knn_model,
  newdata = valid_norm[, -5],
  type = "class"
)
head(knn_pred_valid)
## [1] 0 0 0 0 0 0
## Levels: 0 1
confusionMatrix(
  data = knn_pred_valid,
  reference = as.factor(valid_norm[[5]]),
  positive = "1"
)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 284  24
##          1   5   8
##                                           
##                Accuracy : 0.9097          
##                  95% CI : (0.8728, 0.9387)
##     No Information Rate : 0.9003          
##     P-Value [Acc > NIR] : 0.3278143       
##                                           
##                   Kappa : 0.3162          
##                                           
##  Mcnemar's Test P-Value : 0.0008302       
##                                           
##             Sensitivity : 0.25000         
##             Specificity : 0.98270         
##          Pos Pred Value : 0.61538         
##          Neg Pred Value : 0.92208         
##              Prevalence : 0.09969         
##          Detection Rate : 0.02492         
##    Detection Prevalence : 0.04050         
##       Balanced Accuracy : 0.61635         
##                                           
##        'Positive' Class : 1               
## 

7.3 Model evaluation

ROC for kNN trained on the original (unbalanced) data.

# A ROC curve needs a continuous score, so pass the kNN vote share for
# class "1". Hard class labels collapse the curve to a single operating
# point, whose "AUC" is just the balanced accuracy.
# Re-render to refresh the AUC printed below.
ROSE::roc.curve(
  valid_norm$is_legendary,
  predict(knn_model, newdata = valid_norm[, -c(5)], type = "prob")[, "1"]
)

## Area under the curve (AUC): 0.616

Performance evaluation.

library(modelplotr)

# Score the standardized validation set with the kNN model and bin the
# scores into 10 ntiles (deciles) for the gains/lift plots.
scores_and_ntiles <- prepare_scores_and_ntiles(
  datasets       = list("valid_norm"),
  dataset_labels = list("Validation data"),
  models         = list("knn_model"),
  model_labels   = list("kNN"),
  target_column  = "is_legendary",
  ntiles         = 10
)
## ... scoring caret model "knn_model" on dataset "valid_norm".
## Data preparation step 1 succeeded! Dataframe created.
# Keep only the positive target class ("1") for plotting.
plot_input <- plotting_scope(
  prepared_input = scores_and_ntiles,
  select_targetclass = "1"
)
## Data preparation step 2 succeeded! Dataframe created.
## "prepared_input" aggregated...
## Data preparation step 3 succeeded! Dataframe created.
## 
## No comparison specified, default values are used. 
## 
## Single evaluation line will be plotted: Target value "1" plotted for dataset "Validation data" and model "kNN.
## "
## -> To compare models, specify: scope = "compare_models"
## -> To compare datasets, specify: scope = "compare_datasets"
## -> To compare target classes, specify: scope = "compare_targetclasses"
## -> To plot one line, do not specify scope or specify scope = "no_comparison".

Cumulative gains for kNN trained on the original (unbalanced) data.

# Cumulative gains, annotated at the second decile.
plot_cumgains(
  data = plot_input,
  highlight_ntile = 2
)
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##  
## Plot annotation for plot: Cumulative gains
## - When we select 20% with the highest probability according to model kNN, this selection holds 78% of all 1 cases in Validation data. 
##  
## 

Cumulative lift for kNN trained on the original (unbalanced) data.

# Cumulative lift, annotated at the second decile.
plot_cumlift(
  data = plot_input,
  highlight_ntile = 2,
  custom_line_colors = "#0022AA"
)
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##  
## Plot annotation for plot: Cumulative lift
## - When we select 20% with the highest probability according to model kNN in Validation data, this selection for 1 cases is 3.9 times better than selecting without a model. 
##  
## 

8. kNN on ROSE-balanced training data

# Working copies for the balanced-data model: training rows come from the
# ROSE-balanced set, validation rows from the same validation split.
valid_norm_2 <- valid_df
train_norm_2 <- train_df_rose
colnames(train_norm_2)
## [1] "sp_attack"    "sp_defence"   "speed"        "type1"        "is_legendary"
library(caret)
# Standardize the balanced rows themselves. The numeric predictors written
# into train_norm_2 must come from train_df_rose so they stay aligned,
# row by row, with its type1 and is_legendary columns (4 and 5, which are
# excluded here). The scaling parameters are likewise learned from the data
# this model is trained on, mirroring norm_values for the unbalanced kNN.
norm_values_2 <- preProcess(
  train_df_rose[, -c(4, 5)],
  method = c("center", "scale")
)
train_norm_2[, -c(4, 5)] <- predict(
  norm_values_2,
  newdata = train_df_rose[, -c(4, 5)]
)

# Peek at the first rows of the standardized balanced training data.
head(train_norm_2, n = 6L)
##    sp_attack   sp_defence      speed    type1 is_legendary
## 1 -0.5003024 -0.176717360 -0.7288216    water            0
## 2  0.5765817  0.076597627  1.4176106 fighting            0
## 3  0.7304222 -0.357656636 -0.3881181      bug            0
## 4 -0.5003024  0.366100469  1.3153996    water            0
## 5 -0.9618241 -0.357656636 -0.8310327    grass            0
## 6  0.5765817  0.004221916  0.2251483   normal            0
# Apply the balanced-model scaling to the validation predictors.
valid_norm_2[, -c(4, 5)] <- predict(
  norm_values_2,
  newdata = valid_df[, -c(4, 5)]
)

# Peek at the first rows of the standardized validation data.
head(valid_norm_2, n = 6L)
##     sp_attack sp_defence       speed type1 is_legendary
## 2   0.2689005  0.3661005 -0.21776634 grass            0
## 3   1.5611613  1.8136147  0.46364073 grass            0
## 4  -0.3464618 -0.7195352 -0.04741458  fire            0
## 5   0.2689005 -0.1767174  0.46364073  fire            0
## 12  0.5765817  0.3661005  0.12293719   bug            0
## 13 -1.5771864 -1.8051708 -0.55846988   bug            0
# 15-nearest-neighbour classifier on the balanced training data.
knn_model_2 <- caret::knn3(is_legendary ~ ., data = train_norm_2, k = 15)
print(knn_model_2)
## 15-nearest neighbor model
## Training set outcome distribution:
## 
##   0   1 
## 231 249

8.1 Predict training set

# Class predictions for the balanced training rows (column 5 dropped).
knn_pred_train_2 <- predict(
  knn_model_2,
  newdata = train_norm_2[, -5],
  type = "class"
)
head(knn_pred_train_2)
## [1] 1 0 0 0 0 0
## Levels: 0 1
confusionMatrix(
  data = knn_pred_train_2,
  reference = as.factor(train_norm_2[[5]]),
  positive = "1"
)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 184 100
##          1  47 149
##                                           
##                Accuracy : 0.6938          
##                  95% CI : (0.6504, 0.7347)
##     No Information Rate : 0.5188          
##     P-Value [Acc > NIR] : 4.877e-15       
##                                           
##                   Kappa : 0.3917          
##                                           
##  Mcnemar's Test P-Value : 1.796e-05       
##                                           
##             Sensitivity : 0.5984          
##             Specificity : 0.7965          
##          Pos Pred Value : 0.7602          
##          Neg Pred Value : 0.6479          
##              Prevalence : 0.5188          
##          Detection Rate : 0.3104          
##    Detection Prevalence : 0.4083          
##       Balanced Accuracy : 0.6975          
##                                           
##        'Positive' Class : 1               
## 

8.2 Predict validation set

# Class predictions for the validation rows (column 5 dropped).
knn_pred_valid_2 <- predict(
  knn_model_2,
  newdata = valid_norm_2[, -5],
  type = "class"
)
head(knn_pred_valid_2)
## [1] 0 0 1 1 0 1
## Levels: 0 1
confusionMatrix(
  data = knn_pred_valid_2,
  reference = as.factor(valid_norm_2[[5]]),
  positive = "1"
)
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0 205  13
##          1  84  19
##                                           
##                Accuracy : 0.6978          
##                  95% CI : (0.6444, 0.7476)
##     No Information Rate : 0.9003          
##     P-Value [Acc > NIR] : 1               
##                                           
##                   Kappa : 0.1526          
##                                           
##  Mcnemar's Test P-Value : 1.182e-12       
##                                           
##             Sensitivity : 0.59375         
##             Specificity : 0.70934         
##          Pos Pred Value : 0.18447         
##          Neg Pred Value : 0.94037         
##              Prevalence : 0.09969         
##          Detection Rate : 0.05919         
##    Detection Prevalence : 0.32087         
##       Balanced Accuracy : 0.65155         
##                                           
##        'Positive' Class : 1               
## 

8.3 Model evaluation

ROC for kNN trained on the ROSE-balanced data.

# A ROC curve needs a continuous score, so pass the kNN vote share for
# class "1". Hard class labels collapse the curve to a single operating
# point, whose "AUC" is just the balanced accuracy.
# Re-render to refresh the AUC printed below.
ROSE::roc.curve(
  valid_norm_2$is_legendary,
  predict(knn_model_2, newdata = valid_norm_2[, -c(5)], type = "prob")[, "1"]
)

## Area under the curve (AUC): 0.652

Performance evaluation.

library(modelplotr)

# Score the validation set with the balanced-data kNN and bin the scores
# into 10 ntiles (deciles) for the gains/lift plots.
scores_and_ntiles <- prepare_scores_and_ntiles(
  datasets       = list("valid_norm_2"),
  dataset_labels = list("Validation data"),
  models         = list("knn_model_2"),
  model_labels   = list("kNN with balanced data"),
  target_column  = "is_legendary",
  ntiles         = 10
)
## ... scoring caret model "knn_model_2" on dataset "valid_norm_2".
## Data preparation step 1 succeeded! Dataframe created.
# Keep only the positive target class ("1") for plotting.
plot_input <- plotting_scope(
  prepared_input = scores_and_ntiles,
  select_targetclass = "1"
)
## Data preparation step 2 succeeded! Dataframe created.
## "prepared_input" aggregated...
## Data preparation step 3 succeeded! Dataframe created.
## 
## No comparison specified, default values are used. 
## 
## Single evaluation line will be plotted: Target value "1" plotted for dataset "Validation data" and model "kNN with balanced data.
## "
## -> To compare models, specify: scope = "compare_models"
## -> To compare datasets, specify: scope = "compare_datasets"
## -> To compare target classes, specify: scope = "compare_targetclasses"
## -> To plot one line, do not specify scope or specify scope = "no_comparison".

Cumulative gains for kNN trained on the ROSE-balanced data.

# Cumulative gains, annotated at the second decile.
plot_cumgains(
  data = plot_input,
  highlight_ntile = 2
)
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##  
## Plot annotation for plot: Cumulative gains
## - When we select 20% with the highest probability according to model kNN with balanced data, this selection holds 44% of all 1 cases in Validation data. 
##  
## 

Cumulative lift for kNN trained on the ROSE-balanced data.

# Cumulative lift, annotated at the second decile.
plot_cumlift(
  data = plot_input,
  highlight_ntile = 2,
  custom_line_colors = "#0022AA"
)
## Warning: Vectorized input to `element_text()` is not officially supported.
## Results may be unexpected or may change in future versions of ggplot2.
##  
## Plot annotation for plot: Cumulative lift
## - When we select 20% with the highest probability according to model kNN with balanced data in Validation data, this selection for 1 cases is 2.2 times better than selecting without a model. 
##  
##