Directions

Logistic regression to predict whether to sing a song on American Idol.

Data for demo

Back to the spellbook

1. Load Data

# Read the song history. read.csv() treats the first row as the header by
# default, so no `header` argument is needed.
songs <- read.csv(file = "american_idol_songs_v8.csv")
# Preview the first ten rows.
head(x = songs, n = 10)
##    No                Song_Title                     Artiste Song_Avg_Rtg Year
## 1   1     Stuff Like That There                Bette Midler           95 1991
## 2   2                In A Dream                    Badlands           94 1991
## 3   3     Build Me Up Buttercup             The Foundations           93 1969
## 4   4  Hemorrhage (In My Hands)                        Fuel           92 2000
## 5   5                 Solitaire                  Carpenters           92 1974
## 6   6 Will You Love Me Tomorrow               The Shirelles           92 1960
## 7   7                Chandelier                         Sia           91 2014
## 8   8   Don't Rain On My Parade            Barbra Streisand           91 1964
## 9   9         A Whole New World Peabo Bryson & Regina Belle           90 1992
## 10 10      I Don't Hurt Anymore            Dinah Washington           90 1943
##    Avg_Song_Age Advance Bottom Elimination Expectation Artiste_Rating
## 1          11.0       1      0           0        20.5           55.5
## 2          14.0       1      0           0        24.4           94.0
## 3          34.0       1      0           0        26.2           93.0
## 4           6.0       1      0           0        29.4           92.0
## 5          29.0       1      0           0        25.2           68.5
## 6          51.0       1      0           0        24.9           92.0
## 7           3.7       1      0           0        24.5           66.6
## 8          40.0       1      0           0        15.7           62.9
## 9          11.0       1      0           0        28.7           90.0
## 10         63.0       1      0           0        29.3           90.0

Check data.

# Inspect the first rows, the column types and the outcome counts.
# The outcome is imbalanced: 1339 rows have Advance == 1, only 287 have 0.
head(songs, n = 6L)
##   No                Song_Title         Artiste Song_Avg_Rtg Year Avg_Song_Age
## 1  1     Stuff Like That There    Bette Midler           95 1991           11
## 2  2                In A Dream        Badlands           94 1991           14
## 3  3     Build Me Up Buttercup The Foundations           93 1969           34
## 4  4  Hemorrhage (In My Hands)            Fuel           92 2000            6
## 5  5                 Solitaire      Carpenters           92 1974           29
## 6  6 Will You Love Me Tomorrow   The Shirelles           92 1960           51
##   Advance Bottom Elimination Expectation Artiste_Rating
## 1       1      0           0        20.5           55.5
## 2       1      0           0        24.4           94.0
## 3       1      0           0        26.2           93.0
## 4       1      0           0        29.4           92.0
## 5       1      0           0        25.2           68.5
## 6       1      0           0        24.9           92.0
str(object = songs)
## 'data.frame':    1626 obs. of  11 variables:
##  $ No            : int  1 2 3 4 5 6 7 8 9 10 ...
##  $ Song_Title    : chr  "Stuff Like That There" "In A Dream" "Build Me Up Buttercup" "Hemorrhage (In My Hands)" ...
##  $ Artiste       : chr  "Bette Midler" "Badlands" "The Foundations" "Fuel" ...
##  $ Song_Avg_Rtg  : num  95 94 93 92 92 92 91 91 90 90 ...
##  $ Year          : int  1991 1991 1969 2000 1974 1960 2014 1964 1992 1943 ...
##  $ Avg_Song_Age  : num  11 14 34 6 29 51 3.7 40 11 63 ...
##  $ Advance       : int  1 1 1 1 1 1 1 1 1 1 ...
##  $ Bottom        : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ Elimination   : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ Expectation   : num  20.5 24.4 26.2 29.4 25.2 24.9 24.5 15.7 28.7 29.3 ...
##  $ Artiste_Rating: num  55.5 94 93 92 68.5 92 66.6 62.9 90 90 ...
table(songs[["Advance"]])
## 
##    0    1 
##  287 1339

2. Preprocessing

2.1 Filter Variables

Convert the outcome `Advance` to a factor so it is treated as a class label.

# Treat the 0/1 outcome as a factor so it is modelled as a class label.
songs$Advance <- as.factor(songs$Advance)
str(songs)
## 'data.frame':    1626 obs. of  11 variables:
##  $ No            : int  1 2 3 4 5 6 7 8 9 10 ...
##  $ Song_Title    : chr  "Stuff Like That There" "In A Dream" "Build Me Up Buttercup" "Hemorrhage (In My Hands)" ...
##  $ Artiste       : chr  "Bette Midler" "Badlands" "The Foundations" "Fuel" ...
##  $ Song_Avg_Rtg  : num  95 94 93 92 92 92 91 91 90 90 ...
##  $ Year          : int  1991 1991 1969 2000 1974 1960 2014 1964 1992 1943 ...
##  $ Avg_Song_Age  : num  11 14 34 6 29 51 3.7 40 11 63 ...
##  $ Advance       : Factor w/ 2 levels "0","1": 2 2 2 2 2 2 2 2 2 2 ...
##  $ Bottom        : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ Elimination   : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ Expectation   : num  20.5 24.4 26.2 29.4 25.2 24.9 24.5 15.7 28.7 29.3 ...
##  $ Artiste_Rating: num  55.5 94 93 92 68.5 92 66.6 62.9 90 90 ...
t(t(names(songs)))
##       [,1]            
##  [1,] "No"            
##  [2,] "Song_Title"    
##  [3,] "Artiste"       
##  [4,] "Song_Avg_Rtg"  
##  [5,] "Year"          
##  [6,] "Avg_Song_Age"  
##  [7,] "Advance"       
##  [8,] "Bottom"        
##  [9,] "Elimination"   
## [10,] "Expectation"   
## [11,] "Artiste_Rating"
# Keep the four numeric predictors plus the outcome. Columns are selected by
# name rather than by position (4, 6, 10, 11, 7), so the selection stays
# correct if the CSV gains, loses or reorders columns.
songs <- songs[, c("Song_Avg_Rtg", "Avg_Song_Age", "Expectation",
                   "Artiste_Rating", "Advance")]

head(songs)
##   Song_Avg_Rtg Avg_Song_Age Expectation Artiste_Rating Advance
## 1           95           11        20.5           55.5       1
## 2           94           14        24.4           94.0       1
## 3           93           34        26.2           93.0       1
## 4           92            6        29.4           92.0       1
## 5           92           29        25.2           68.5       1
## 6           92           51        24.9           92.0       1
str(songs)
## 'data.frame':    1626 obs. of  5 variables:
##  $ Song_Avg_Rtg  : num  95 94 93 92 92 92 91 91 90 90 ...
##  $ Avg_Song_Age  : num  11 14 34 6 29 51 3.7 40 11 63 ...
##  $ Expectation   : num  20.5 24.4 26.2 29.4 25.2 24.9 24.5 15.7 28.7 29.3 ...
##  $ Artiste_Rating: num  55.5 94 93 92 68.5 92 66.6 62.9 90 90 ...
##  $ Advance       : Factor w/ 2 levels "0","1": 2 2 2 2 2 2 2 2 2 2 ...

2.2 Training-Validation Split

Split the data into training and validation sets.

Set the seed using our favourite number :-)

# Fix the RNG state so the random split below is reproducible.
set.seed(seed = 666)

Create the indices for the split. This samples row indices to divide the data into training and validation sets.

# Sample 70% of the row indices for training; sample() truncates the
# non-integer size (0.7 * 1626 = 1138.2 -> 1138). The remaining rows form the
# validation set. seq_len() replaces 1:nrow(), which would yield c(1, 0) for an
# empty data frame; for non-empty data the vector, and so the draw, is identical.
train_index <- sample(seq_len(nrow(songs)), 0.7 * nrow(songs))
valid_index <- setdiff(seq_len(nrow(songs)), train_index)

Using the indices, create the training and validation sets. This is similar in principle to splitting a data frame by row.

# Row-subset the data with the sampled indices, keeping every column.
train_df <- songs[train_index, , drop = FALSE]
valid_df <- songs[valid_index, , drop = FALSE]

It is a good habit to check after splitting.

# Confirm the split sizes (first element of dim() is the row count).
dim(train_df)[1L]
## [1] 1138
dim(valid_df)[1L]
## [1] 488

2.3 Balance data

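Balance the training set with ROSE. Only the training data is resampled; the validation set keeps its original class distribution.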

library(ROSE)
## Loaded ROSE 0.0-4
# ROSE returns a list-like object; its `data` element is a synthetic training
# set of the same size (1138 rows) with roughly balanced classes.
# NOTE: the Advance factor in the returned data has its levels in the order
# "1", "0" (see the first table below), whereas the original data uses
# "0", "1". Models and predictions built on train_rose inherit this order.
train_rose <- ROSE(Advance ~ ., data = train_df, seed = 666)[["data"]]

table(train_rose$Advance)
## 
##   1   0 
## 575 563
table(valid_df$Advance)
## 
##   0   1 
##  80 408

Check that the balanced training set and the validation set have the same columns and column types.

library(janitor)
## 
## Attaching package: 'janitor'
## The following objects are masked from 'package:stats':
## 
##     chisq.test, fisher.test
# Side-by-side column classes of both sets; they must agree before modelling.
janitor::compare_df_cols(train_rose, valid_df)
##      column_name train_rose valid_df
## 1        Advance     factor   factor
## 2 Artiste_Rating    numeric  numeric
## 3   Avg_Song_Age    numeric  numeric
## 4    Expectation    numeric  numeric
## 5   Song_Avg_Rtg    numeric  numeric

3. Logistic Regression

library(caret)
## Loading required package: ggplot2
## Loading required package: lattice
# Fit a binomial GLM on the ROSE-balanced training data via caret.
# NOTE: glm models the probability of the *second* factor level. Because the
# ROSE output has levels c("1", "0"), the coefficients below are on the
# log-odds of Advance == "0". A negative coefficient (e.g. Song_Avg_Rtg)
# therefore *raises* the predicted probability of Advance == "1".
logistic_reg <- train(
  Advance ~ Song_Avg_Rtg + Avg_Song_Age + Expectation + Artiste_Rating,
  data = train_rose,
  method = "glm"
)
summary(object = logistic_reg)
## 
## Call:
## NULL
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -1.9542  -1.0769  -0.5598   1.0642   2.1199  
## 
## Coefficients:
##                 Estimate Std. Error z value Pr(>|z|)    
## (Intercept)     1.334674   0.238639   5.593 2.23e-08 ***
## Song_Avg_Rtg   -0.027385   0.004271  -6.412 1.44e-10 ***
## Avg_Song_Age   -0.011931   0.003441  -3.467 0.000525 ***
## Expectation    -0.010597   0.004673  -2.268 0.023347 *  
## Artiste_Rating  0.002262   0.004343   0.521 0.602533    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 1577.5  on 1137  degrees of freedom
## Residual deviance: 1451.1  on 1133  degrees of freedom
## AIC: 1461.1
## 
## Number of Fisher Scoring iterations: 4
# Importance is scaled to 0-100; Artiste_Rating (p = 0.60 above) scores 0.
varImp(object = logistic_reg)
## glm variable importance
## 
##                Overall
## Song_Avg_Rtg    100.00
## Avg_Song_Age     50.02
## Expectation      29.65
## Artiste_Rating    0.00

Predict the training set.

# Class predictions on the balanced training set. predict.train() takes
# newdata as its second argument and defaults to type = "raw" (class labels).
logistic_reg_pred_train <- predict(logistic_reg, train_rose)

head(logistic_reg_pred_train)
## [1] 1 1 1 0 0 0
## Levels: 1 0
# Class probabilities: one column per outcome level, in level order ("1", "0").
logistic_reg_pred_train_prob <- predict(logistic_reg, train_rose,
                                        type = "prob")

head(logistic_reg_pred_train_prob)
##           1         0
## 1 0.5523457 0.4476543
## 2 0.6512442 0.3487558
## 3 0.5381136 0.4618864
## 4 0.4581417 0.5418583
## 5 0.3161745 0.6838255
## 6 0.4666076 0.5333924

Predict the validation set.

# Class predictions on the untouched validation set (type = "raw" default).
logistic_reg_pred_valid <- predict(logistic_reg, valid_df)

head(logistic_reg_pred_valid)
## [1] 1 1 1 1 1 1
## Levels: 1 0
# Validation probabilities; the row names are the original row numbers of
# the sampled validation rows.
logistic_reg_pred_valid_prob <- predict(logistic_reg, valid_df,
                                        type = "prob")

head(logistic_reg_pred_valid_prob)
##            1         0
## 2  0.8103799 0.1896201
## 6  0.8640055 0.1359945
## 10 0.8795857 0.1204143
## 11 0.8359845 0.1640155
## 12 0.8366183 0.1633817
## 14 0.7831463 0.2168537

4. Model Evaluation

Confusion matrix. Training set.

# Training-set confusion matrix with "1" as the positive class. The predictions
# are already a factor, so they are passed directly.
confusionMatrix(data = logistic_reg_pred_train,
                reference = train_rose$Advance,
                positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   1   0
##          1 382 197
##          0 193 366
##                                           
##                Accuracy : 0.6573          
##                  95% CI : (0.6289, 0.6849)
##     No Information Rate : 0.5053          
##     P-Value [Acc > NIR] : <2e-16          
##                                           
##                   Kappa : 0.3145          
##                                           
##  Mcnemar's Test P-Value : 0.8793          
##                                           
##             Sensitivity : 0.6643          
##             Specificity : 0.6501          
##          Pos Pred Value : 0.6598          
##          Neg Pred Value : 0.6547          
##              Prevalence : 0.5053          
##          Detection Rate : 0.3357          
##    Detection Prevalence : 0.5088          
##       Balanced Accuracy : 0.6572          
##                                           
##        'Positive' Class : 1               
## 

Confusion matrix. Validation set.

# Validation-set confusion matrix with "1" as the positive class. The
# predictions are already a factor, so they are passed directly.
confusionMatrix(data = logistic_reg_pred_valid,
                reference = valid_df$Advance,
                positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0  62 153
##          1  18 255
##                                           
##                Accuracy : 0.6496          
##                  95% CI : (0.6054, 0.6919)
##     No Information Rate : 0.8361          
##     P-Value [Acc > NIR] : 1               
##                                           
##                   Kappa : 0.2383          
##                                           
##  Mcnemar's Test P-Value : <2e-16          
##                                           
##             Sensitivity : 0.6250          
##             Specificity : 0.7750          
##          Pos Pred Value : 0.9341          
##          Neg Pred Value : 0.2884          
##              Prevalence : 0.8361          
##          Detection Rate : 0.5225          
##    Detection Prevalence : 0.5594          
##       Balanced Accuracy : 0.7000          
##                                           
##        'Positive' Class : 1               
## 

ROC curve for the validation set.

# NOTE(review): this passes hard class labels, not probabilities, so the
# "curve" has a single operating point. Its AUC (0.700) is just the balanced
# accuracy of the validation confusion matrix (0.7000). A probability-based
# AUC is computed with pROC in section 4.1.
ROSE::roc.curve(response = valid_df$Advance,
                predicted = logistic_reg_pred_valid)

## Area under the curve (AUC): 0.700

4.1 Optimal Cutoff

library(pROC)
## Type 'citation("pROC")' for a citation.
## 
## Attaching package: 'pROC'
## The following objects are masked from 'package:stats':
## 
##     cov, smooth, var
# ROC curve on the validation set from the predicted probability of "1".
# The column is picked by name: the probability columns follow the factor
# level order, which ROSE changed to c("1", "0"), so a positional index is
# fragile. levels/direction are fixed explicitly (the same values pROC
# previously auto-detected: control = "0", case = "1", controls < cases)
# rather than guessed from the data.
rocCurve <- pROC::roc(response = valid_df$Advance,
                      predictor = logistic_reg_pred_valid_prob[, "1"],
                      levels = c("0", "1"),
                      direction = "<")
rocCurve$auc
## Area under the curve: 0.7502
plot(rocCurve, print.thres = "best")

# Largest sensitivity + specificity over all thresholds (Youden's J + 1).
# The per-threshold values are stored on the roc object; there are no
# standalone `sensitivity` / `specificity` objects in this session. This is
# the criterion pROC's default "best" threshold maximises.
max(rocCurve$sensitivities + rocCurve$specificities)

# First few (threshold, specificity, sensitivity) points of the curve, then
# the single point pROC considers best.
head(coords(rocCurve), n = 6L)
##   threshold specificity sensitivity
## 1      -Inf      0.0000    1.000000
## 2 0.1718682      0.0125    1.000000
## 3 0.1816188      0.0250    1.000000
## 4 0.1907029      0.0375    1.000000
## 5 0.2026455      0.0500    1.000000
## 6 0.2147397      0.0500    0.997549
coords(rocCurve,
       x = "best",
       ret = c("threshold", "specificity", "sensitivity"))
##   threshold specificity sensitivity
## 1 0.5088966      0.8125   0.6029412
# Re-score both sets at the cutoff pROC reports as best, instead of a
# hand-rounded constant that would go stale if the data or seed changed.
# [1] guards against ties returning several "best" rows.
best_cutoff <- coords(rocCurve, x = "best")[["threshold"]][1]

# Build each prediction factor with the reference's own levels. train_rose
# uses c("1", "0"), so as.factor() on "1"/"0" strings (levels c("0", "1"))
# would force confusionMatrix() to reorder the data and emit a warning.
confusionMatrix(factor(ifelse(logistic_reg_pred_train_prob[, "1"] > best_cutoff,
                              "1", "0"),
                       levels = levels(train_rose$Advance)),
                train_rose$Advance,
                positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   1   0
##          1 363 186
##          0 212 377
##                                          
##                Accuracy : 0.6503         
##                  95% CI : (0.6218, 0.678)
##     No Information Rate : 0.5053         
##     P-Value [Acc > NIR] : <2e-16         
##                                          
##                   Kappa : 0.3008         
##                                          
##  Mcnemar's Test P-Value : 0.2102         
##                                          
##             Sensitivity : 0.6313         
##             Specificity : 0.6696         
##          Pos Pred Value : 0.6612         
##          Neg Pred Value : 0.6401         
##              Prevalence : 0.5053         
##          Detection Rate : 0.3190         
##    Detection Prevalence : 0.4824         
##       Balanced Accuracy : 0.6505         
##                                          
##        'Positive' Class : 1              
## 
confusionMatrix(factor(ifelse(logistic_reg_pred_valid_prob[, "1"] > best_cutoff,
                              "1", "0"),
                       levels = levels(valid_df$Advance)),
                valid_df$Advance,
                positive = "1")
## Confusion Matrix and Statistics
## 
##           Reference
## Prediction   0   1
##          0  65 162
##          1  15 246
##                                         
##                Accuracy : 0.6373        
##                  95% CI : (0.5929, 0.68)
##     No Information Rate : 0.8361        
##     P-Value [Acc > NIR] : 1             
##                                         
##                   Kappa : 0.239         
##                                         
##  Mcnemar's Test P-Value : <2e-16        
##                                         
##             Sensitivity : 0.6029        
##             Specificity : 0.8125        
##          Pos Pred Value : 0.9425        
##          Neg Pred Value : 0.2863        
##              Prevalence : 0.8361        
##          Detection Rate : 0.5041        
##    Detection Prevalence : 0.5348        
##       Balanced Accuracy : 0.7077        
##                                         
##        'Positive' Class : 1             
## 

5. Predict new songs

Import new songs.

New songs

# Read the unseen songs to score. Their outcome columns (Advance, Bottom,
# Elimination) are NA, and there is an extra free-text Comments column.
new_songs <- read.csv(file = "new_songs.csv")
print(new_songs)
##     No           Song_Title     Artiste Song_Avg_Rtg Year Avg_Song_Age Advance
## 1 6661 Walk With Me In Hell Lamb of God           96 2004           19      NA
## 2 6662          The Watcher  Arch Enemy           90 2022            1      NA
## 3 6663              Frantic   Metallica           28 2003           20      NA
##   Bottom Elimination Expectation Artiste_Rating
## 1     NA          NA          42            100
## 2     NA          NA          36            100
## 3     NA          NA          46            120
##                               Comments
## 1   Classic song from a legendary band
## 2 Fantastic song from a legendary band
## 3       Zzz song from a legendary band

Filter the variables.

names(new_songs)
##  [1] "No"             "Song_Title"     "Artiste"        "Song_Avg_Rtg"  
##  [5] "Year"           "Avg_Song_Age"   "Advance"        "Bottom"        
##  [9] "Elimination"    "Expectation"    "Artiste_Rating" "Comments"
# Keep exactly the predictors the model was trained on. Selecting by name
# (instead of positions 4, 6, 10, 11) means the extra Comments column or any
# column reordering in the new file cannot shift the selection.
new_songs_filter <- new_songs[, c("Song_Avg_Rtg", "Avg_Song_Age",
                                  "Expectation", "Artiste_Rating")]
new_songs_filter
##   Song_Avg_Rtg Avg_Song_Age Expectation Artiste_Rating
## 1           96           19          42            100
## 2           90            1          36            100
## 3           28           20          46            120

Predict.

# Predicted class for each new song (predict.train defaults to type = "raw").
logistic_reg_pred_new_songs <- predict(logistic_reg, new_songs_filter)

head(logistic_reg_pred_new_songs)
## [1] 1 1 0
## Levels: 1 0
# Predicted class probabilities, columns in level order ("1", "0").
logistic_reg_pred_new_songs_prob <- predict(logistic_reg, new_songs_filter,
                                            type = "prob")

head(logistic_reg_pred_new_songs_prob)
##           1         0
## 1 0.8506669 0.1493331
## 2 0.7853635 0.2146365
## 3 0.4717245 0.5282755

If I’m being honest… :-)

# One-column data frame holding the predicted class, named directly instead
# of via as.data.frame() followed by a rename.
logistic_reg_pred_new_songs_df <- data.frame(
  Prediction = logistic_reg_pred_new_songs
)

# Attach the title and artist, selected by name rather than by position, so
# each prediction is readable.
new_songs_prediction_df <- cbind(new_songs[c("Song_Title", "Artiste")],
                                 logistic_reg_pred_new_songs_df)
new_songs_prediction_df
##             Song_Title     Artiste Prediction
## 1 Walk With Me In Hell Lamb of God          1
## 2          The Watcher  Arch Enemy          1
## 3              Frantic   Metallica          0

6. Gains and Lift Charts

library(modelplotr)
## Package modelplotr loaded! Happy model plotting!
# Score the validation set and bin it into 100 ntiles (percentiles) for the
# gains/lift plots. Datasets and models are referenced by their object names,
# given as strings.
scores_and_ntiles <- prepare_scores_and_ntiles(
  datasets = list("valid_df"),
  dataset_labels = list("Validation Data"),
  models = list("logistic_reg"),
  model_labels = list("Logistic regression"),
  target_column = "Advance",
  ntiles = 100
)
## ... scoring caret model "logistic_reg" on dataset "valid_df".
## Data preparation step 1 succeeded! Dataframe created.
head(scores_and_ntiles)
##            model_label   dataset_label y_true    prob_1    prob_0 ntl_1 ntl_0
## 2  Logistic regression Validation Data      1 0.8103799 0.1896201     3    98
## 6  Logistic regression Validation Data      1 0.8640055 0.1359945     1   100
## 10 Logistic regression Validation Data      1 0.8795857 0.1204143     1   100
## 11 Logistic regression Validation Data      1 0.8359845 0.1640155     2    99
## 12 Logistic regression Validation Data      1 0.8366183 0.1633817     2    99
## 14 Logistic regression Validation Data      1 0.7831463 0.2168537     5    96
# Aggregate per ntile for target class "1" (single model, single dataset).
plot_input <- plotting_scope(
  prepared_input = scores_and_ntiles,
  select_targetclass = "1"
)
## Data preparation step 2 succeeded! Dataframe created.
## "prepared_input" aggregated...
## Data preparation step 3 succeeded! Dataframe created.
## 
## No comparison specified, default values are used. 
## 
## Single evaluation line will be plotted: Target value "1" plotted for dataset "Validation Data" and model "Logistic regression.
## "
## -> To compare models, specify: scope = "compare_models"
## -> To compare datasets, specify: scope = "compare_datasets"
## -> To compare target classes, specify: scope = "compare_targetclasses"
## -> To plot one line, do not specify scope or specify scope = "no_comparison".
head(plot_input)
##           scope         model_label   dataset_label target_class ntile neg pos
## 1 no_comparison Logistic regression Validation Data            1     0   0   0
## 2 no_comparison Logistic regression Validation Data            1     1   0   5
## 3 no_comparison Logistic regression Validation Data            1     2   0   5
## 4 no_comparison Logistic regression Validation Data            1     3   0   5
## 5 no_comparison Logistic regression Validation Data            1     4   1   4
## 6 no_comparison Logistic regression Validation Data            1     5   0   5
##   tot pct negtot postot tottot    pcttot cumneg cumpos cumtot cumpct
## 1   0  NA     NA     NA     NA        NA      0      0      0     NA
## 2   5 1.0     80    408    488 0.8360656      0      5      5   1.00
## 3   5 1.0     80    408    488 0.8360656      0     10     10   1.00
## 4   5 1.0     80    408    488 0.8360656      0     15     15   1.00
## 5   5 0.8     80    408    488 0.8360656      1     19     20   0.95
## 6   5 1.0     80    408    488 0.8360656      1     24     25   0.96
##          gain    cumgain gain_ref   gain_opt      lift  cumlift cumlift_ref
## 1 0.000000000 0.00000000     0.00 0.00000000        NA       NA           1
## 2 0.012254902 0.01225490     0.01 0.01225490 1.1960784 1.196078           1
## 3 0.012254902 0.02450980     0.02 0.02450980 1.1960784 1.196078           1
## 4 0.012254902 0.03676471     0.03 0.03676471 1.1960784 1.196078           1
## 5 0.009803922 0.04656863     0.04 0.04901961 0.9568627 1.136275           1
## 6 0.012254902 0.05882353     0.05 0.06127451 1.1960784 1.148235           1
##   legend
## 1      1
## 2      1
## 3      1
## 4      1
## 5      1
## 6      1

Cumulative gains for logistic regression.

# Cumulative gains, annotated at the 66th percentile.
plot_cumgains(
  data = plot_input,
  highlight_ntile = 66,
  custom_line_colors = "#1E9C33"
)
##  
## Plot annotation for plot: Cumulative gains
## - When we select 66% with the highest probability according to model Logistic regression, this selection holds 72% of all 1 cases in Validation Data. 
##  
## 

Cumulative lift for logistic regression.

# Cumulative lift, annotated at the 66th percentile.
plot_cumlift(
  data = plot_input,
  highlight_ntile = 66,
  custom_line_colors = "#6642f6"
)
##  
## Plot annotation for plot: Cumulative lift
## - When we select 66% with the highest probability according to model Logistic regression in Validation Data, this selection for 1 cases is 1.1 times better than selecting without a model. 
##  
## 

Response plot for logistic regression.

# Per-ntile response for the target class.
plot_response(plot_input)

Cumulative response plot for logistic regression.

# Cumulative response for the target class.
plot_cumresponse(plot_input)

Multiplot for logistic regression.

# All four evaluation plots in one grid.
plot_multiplot(plot_input)