class: title-slide, center <span class="fa-stack fa-4x"> <i class="fa fa-circle fa-stack-2x" style="color: #ffffff;"></i> <strong class="fa-stack-1x" style="color:#E7553C;">2</strong> </span> # Classifying ## Introduction to Machine Learning in the Tidyverse ### Alison Hill · Garrett Grolemund #### [https://conf20-intro-ml.netlify.com/](https://conf20-intro-ml.netlify.com/) · [https://rstd.io/conf20-intro-ml](https://rstd.io/conf20-intro-ml) --- class: middle, center, frame # Goal of Machine Learning -- ## 🔨 construct .display[models] that -- ## 🎯 generate .display[accurate predictions] -- ## 🆕 for .display[future, yet-to-be-seen data] -- .footnote[Max Kuhn & Kjell Johnson, http://www.feat.engineering/] --- class: inverse, middle, center A model doesn't have to be a straight line... <img src="02-Classifying_files/figure-html/lm-fig-1.svg" width="504" style="display: block; margin: auto;" /> --- class: inverse, middle, center .pull-left[ <img src="02-Classifying_files/figure-html/unnamed-chunk-1-1.svg" width="504" /> ] .pull-right[ <img src="02-Classifying_files/figure-html/unnamed-chunk-2-1.svg" width="504" style="display: block; margin: auto;" /> ] --- class: middle, frame, center # Decision Trees To predict the outcome of a new data point: Uses rules learned from splits Each split maximizes information gain --- class: middle, center ![](https://media.giphy.com/media/gj4ZruUQUnpug/source.gif) --- <img src="02-Classifying_files/figure-html/unnamed-chunk-4-1.png" width="504" style="display: block; margin: auto;" /> --- <img src="02-Classifying_files/figure-html/unnamed-chunk-5-1.png" width="504" style="display: block; margin: auto;" /> --- class: middle, center # Quiz How do we assess predictions here?
-- RMSE --- <img src="02-Classifying_files/figure-html/rt-test-resid-1.png" width="504" style="display: block; margin: auto;" /> --- class: middle, center .pull-left[ ### LM RMSE = 53884.78 ] -- .pull-right[ ### Tree RMSE = 61687.24 <img src="02-Classifying_files/figure-html/unnamed-chunk-8-1.png" width="504" /> ] --- class: inverse, middle, center .pull-left[ <img src="02-Classifying_files/figure-html/unnamed-chunk-9-1.svg" width="504" /> ] .pull-right[ <img src="02-Classifying_files/figure-html/dt-fig-1.svg" width="504" style="display: block; margin: auto;" /> ] --- class: middle, center, inverse # What is a model? --- class: middle, center, frame # K Nearest Neighbors (KNN) To predict the outcome of a new data point: Find the K most similar old data points Take the average/mode/etc. outcome --- ```r knn_spec <- nearest_neighbor(neighbors = 5) %>% set_engine("kknn") %>% set_mode("regression") set.seed(100) fit_split(Sale_Price ~ ., model = knn_spec, split = ames_split) %>% collect_metrics() # A tibble: 2 x 3 .metric .estimator .estimate <chr> <chr> <dbl> 1 rmse standard 35870. 
2 rsq standard 0.812 ``` --- <img src="02-Classifying_files/figure-html/knn-home1-1.png" width="504" style="display: block; margin: auto;" /> --- <img src="02-Classifying_files/figure-html/knn-home2-1.png" width="504" style="display: block; margin: auto;" /> --- <img src="02-Classifying_files/figure-html/knn-home2-10-1.png" width="504" style="display: block; margin: auto;" /> --- <img src="02-Classifying_files/figure-html/knn-home2-25-1.png" width="504" style="display: block; margin: auto;" /> --- <img src="02-Classifying_files/figure-html/knn-home2-50-1.png" width="504" style="display: block; margin: auto;" /> --- .pull-left[ <img src="02-Classifying_files/figure-html/unnamed-chunk-13-1.png" width="504" style="display: block; margin: auto;" /> ] .pull-right[ <img src="02-Classifying_files/figure-html/underfit-knn-1.png" width="504" /> ] --- .pull-left[ <img src="02-Classifying_files/figure-html/unnamed-chunk-14-1.png" width="504" style="display: block; margin: auto;" /> ] .pull-right[ <img src="02-Classifying_files/figure-html/fit-knn-1.png" width="504" /> ] --- exclude: true class: inverse .pull-left[ ![](figs/01-Predicting/lm-fig-1.svg) ] .pull-right[ ![](figs/01-Predicting/dt-fig-1.svg) ] --- class: inverse, middle, center exclude: true .pull-left[ ![](figs/01-Predicting/lm-fig-1.svg) ] -- .pull-right[ <img src="02-Classifying_files/figure-html/lr-fig-1.svg" width="504" /> ] .footnote[[Why is logistic regression considered a linear model?](https://sebastianraschka.com/faq/docs/logistic_regression_linear.html)] --- exclude: true class: inverse, middle, center .pull-left[ ![](01-Prediction/dt-fig-1.svg) ] -- .pull-right[ ] --- class: middle, center <img src="https://raw.githubusercontent.com/EmilHvitfeldt/blog/master/static/blog/2019-08-09-authorship-classification-with-tidymodels-and-textrecipes_files/figure-html/unnamed-chunk-18-1.png" width="70%" /> https://www.hvitfeldt.me/blog/authorship-classification-with-tidymodels-and-textrecipes/ --- class: middle,
center <img src="https://www.kaylinpavlik.com/content/images/2019/12/dt-1.png" width="50%" /> https://www.kaylinpavlik.com/classifying-songs-genres/ --- class: middle, center <img src="images/sing-tree.png" width="607" /> [The Science of Singing Along](http://www.doc.gold.ac.uk/~mas03dm/papers/PawleyMullensiefen_Singalong_2012.pdf) --- class: middle, center <img src="https://a3.typepad.com/6a0105360ba1c6970c01b7c95c61fb970b-pi" width="40%" /> .footnote[[tweetbotornot2](https://github.com/mkearney/tweetbotornot2)] --- name: guess-the-animal class: middle, center, inverse <img src="http://www.atarimania.com/8bit/screens/guess_the_animal.gif" width="100%" /> --- class: your-turn # Your turn 1 Get in your teams. Have one member think of an animal; other members try to guess it by asking *yes/no* questions about it. Go! Write down how many questions it takes your team.
05:00
--- class: your-turn # Your turn 2 In your teams, discuss what qualities made for a good versus a bad question.
02:00
--- class: middle, center # What makes a good guesser? -- High information gain per question (can it fly?) -- Clear features (feathers vs. is it "small"?) -- Order matters --- class: inverse, middle, center # Congratulations! You just built a decision tree 🎉 --- background-image: url(images/aus-standard-animals.png) background-size: cover .footnote[[Australian Computing Academy](https://aca.edu.au/resources/decision-trees-classifying-animals/)] --- background-image: url(images/aus-standard-tree.png) background-size: cover .footnote[[Australian Computing Academy](https://aca.edu.au/resources/decision-trees-classifying-animals/)] --- background-image: url(images/annotated-tree/annotated-tree.001.png) background-size: cover --- background-image: url(images/annotated-tree/annotated-tree.002.png) background-size: cover --- background-image: url(images/annotated-tree/annotated-tree.003.png) background-size: cover --- background-image: url(images/annotated-tree/annotated-tree.004.png) background-size: cover --- background-image: url(images/annotated-tree/annotated-tree.005.png) background-size: cover --- background-image: url(images/bonsai-anatomy.jpg) background-size: cover --- background-image: url(images/bonsai-anatomy-flip.jpg) background-size: cover --- class: center, middle # Quiz Name that variable type! <img src="images/vartypes_quiz.png" width="50%" style="display: block; margin: auto;" /> --- <img src="images/vartypes_answers.png" width="80%" style="display: block; margin: auto;" /> --- <img src="images/vartypes_unicorn.jpeg" width="80%" style="display: block; margin: auto;" /> --- class: center, middle # Show of hands How many people have .display[fit] a logistic regression model with `glm()`? 
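---

For anyone who hasn't: a minimal sketch of fitting a logistic regression with `glm()`, using simulated data (the variables `x` and `y` below are made up for illustration; they are not from the workshop datasets):

```r
# Simulate a binary outcome whose log-odds depend linearly on x,
# then fit a logistic regression with glm()
set.seed(100)
x <- rnorm(200)
y <- rbinom(200, size = 1, prob = plogis(-1 + 2 * x))

fit <- glm(y ~ x, family = binomial)
coef(fit)  # estimated intercept and slope, on the log-odds scale

# Predicted probability of y = 1 at x = 0
p <- predict(fit, newdata = data.frame(x = 0), type = "response")
p
```

`family = binomial` is what makes this a logistic (rather than ordinary linear) regression.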
--- exclude: true --- class: middle, center, inverse .pull-left[ <img src="02-Classifying_files/figure-html/unnamed-chunk-26-1.png" width="1695" /> ] .pull-right[ <img src="02-Classifying_files/figure-html/unnamed-chunk-27-1.png" width="1695" /> ] --- .pull-left[ ```r uni_train %>% count(unicorn) # A tibble: 2 x 2 unicorn n <fct> <int> 1 0 100 2 1 50 ``` ] .pull-right[ <img src="02-Classifying_files/figure-html/unnamed-chunk-29-1.png" width="504" /> ] --- <img src="02-Classifying_files/figure-html/unnamed-chunk-30-1.png" width="504" style="display: block; margin: auto;" /> --- <img src="02-Classifying_files/figure-html/unnamed-chunk-32-1.png" width="504" style="display: block; margin: auto;" /> --- <img src="02-Classifying_files/figure-html/unnamed-chunk-33-1.png" width="504" style="display: block; margin: auto;" /> --- <img src="02-Classifying_files/figure-html/unnamed-chunk-34-1.png" width="504" style="display: block; margin: auto;" /> --- class: middle, center <img src="02-Classifying_files/figure-html/unnamed-chunk-35-1.png" width="504" style="display: block; margin: auto;" /> --- ``` parsnip model object Fit time: 4ms n= 150 node), split, n, loss, yval, (yprob) * denotes terminal node 1) root 150 50 0 (0.6666667 0.3333333) 2) n_butterflies>=29.5 93 16 0 (0.8279570 0.1720430) * 3) n_butterflies< 29.5 57 23 1 (0.4035088 0.5964912) 6) n_kittens>=62.5 18 6 0 (0.6666667 0.3333333) * 7) n_kittens< 62.5 39 11 1 (0.2820513 0.7179487) * ``` --- class: middle, center <img src="02-Classifying_files/figure-html/unnamed-chunk-38-1.png" width="720" style="display: block; margin: auto;" /> --- ``` nn ..y 0 1 cover 2 0 [.83 .17] when n_butterflies >= 30 62% 6 0 [.67 .33] when n_butterflies < 30 & n_kittens >= 63 12% 7 1 [.28 .72] when n_butterflies < 30 & n_kittens < 63 26% ``` --- .pull-left[ <img src="02-Classifying_files/figure-html/unnamed-chunk-41-1.png" width="504" style="display: block; margin: auto;" /> ] -- .pull-right[ <img 
src="02-Classifying_files/figure-html/unnamed-chunk-42-1.png" width="504" style="display: block; margin: auto;" /> ] --- .pull-left[ <img src="02-Classifying_files/figure-html/unnamed-chunk-43-1.png" width="504" /> ] -- .pull-right[ <img src="02-Classifying_files/figure-html/unnamed-chunk-44-1.png" width="504" /> ] -- ### .center[🦋 split wins] --- .pull-left[ <img src="02-Classifying_files/figure-html/unnamed-chunk-45-1.png" width="504" /> ] -- .pull-right[ <img src="02-Classifying_files/figure-html/unnamed-chunk-46-1.png" width="504" /> ] -- ### .center[🐱 split wins] --- class: middle, center # Sadly, we are not classifying unicorns today <img src="images/sad_unicorn.png" width="20%" style="display: block; margin: auto;" /> --- background-image: url(images/copyingandpasting-big.png) background-size: contain background-position: center class: middle, center --- background-image: url(images/so-dev-survey.png) background-size: contain background-position: center class: middle, center --- <img src="https://github.com/juliasilge/supervised-ML-case-studies-course/blob/master/img/remote_size.png?raw=true" width="80%" /> .footnote[[Julia Silge](https://supervised-ml-course.netlify.com/)] ??? Notes: The specific question we are going to address is what makes a developer more likely to work remotely. Developers can work in their company offices or they can work remotely, and it turns out that there are specific characteristics of developers, such as the size of the company that they work for, how much experience they have, or where in the world they live, that affect how likely they are to be a remote developer. 
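---

# Quantifying a "winning" split

The 🦋 and 🐱 comparisons above can be made concrete: each candidate split is scored by how much it reduces node impurity (rpart uses Gini impurity by default). A sketch in base R, using the class counts read off the printed unicorn fit earlier (100 non-unicorns and 50 unicorns at the root; 93 and 57 rows in the two child nodes):

```r
# Gini impurity of a node: 1 - sum of squared class proportions
gini <- function(counts) 1 - sum((counts / sum(counts))^2)

gini(c(100, 50))  # root impurity, about 0.44

# After splitting on n_butterflies >= 29.5:
left  <- c(77, 16)  # 93 rows, mostly class 0
right <- c(23, 34)  # 57 rows, mostly class 1
(93 * gini(left) + 57 * gini(right)) / 150  # weighted impurity, about 0.36
```

The weighted impurity drops from about 0.44 to about 0.36, which is why this split "wins" over candidates that reduce impurity less.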
--- # StackOverflow Data ```r glimpse(stackoverflow) Observations: 1,150 Variables: 21 Groups: remote [2] $ country <fct> United States, United States, Un… $ salary <dbl> 63750.00, 93000.00, 40625.00, 45… $ years_coded_job <int> 4, 9, 8, 3, 8, 12, 20, 17, 20, 4… $ open_source <dbl> 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1,… $ hobby <dbl> 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1,… $ company_size_number <dbl> 20, 1000, 10000, 1, 10, 100, 20,… $ remote <fct> Remote, Remote, Remote, Remote, … $ career_satisfaction <int> 8, 8, 5, 10, 8, 10, 9, 7, 8, 7, … $ data_scientist <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,… $ database_administrator <dbl> 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0,… $ desktop_applications_developer <dbl> 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0,… $ developer_with_stats_math_background <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,… $ dev_ops <dbl> 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,… $ embedded_developer <dbl> 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0,… $ graphic_designer <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,… $ graphics_programming <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,… $ machine_learning_specialist <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,… $ mobile_developer <dbl> 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0,… $ quality_assurance_engineer <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,… $ systems_administrator <dbl> 1, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0,… $ web_developer <dbl> 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1,… ``` --- # .center[`initial_split()`] .center["Splits" data randomly into a single testing and a single training set; extract `training` and `testing` sets from an rsplit] ```r set.seed(100) # Important! so_split <- initial_split(stackoverflow, strata = remote) so_train <- training(so_split) so_test <- testing(so_split) ``` --- class: your-turn # Your turn 3 Using the `so_train` and `so_test` datasets, how many individuals in our training set are remote? How about in the testing set?
02:00
--- ```r so_train %>% count(remote) # A tibble: 2 x 2 # Groups: remote [2] remote n <fct> <int> 1 Remote 432 2 Not remote 432 so_test %>% count(remote) # A tibble: 2 x 2 # Groups: remote [2] remote n <fct> <int> 1 Remote 143 2 Not remote 143 ``` --- .pull-left[ ```r so_train %>% count(remote) # A tibble: 2 x 2 # Groups: remote [2] remote n <fct> <int> 1 Remote 432 2 Not remote 432 so_test %>% count(remote) # A tibble: 2 x 2 # Groups: remote [2] remote n <fct> <int> 1 Remote 143 2 Not remote 143 ``` ] .pull-right[ <img src="02-Classifying_files/figure-html/unnamed-chunk-54-1.png" width="504" /> ] --- class: inverse, middle, center # How would we fit a tree with parsnip? <img src="https://raw.githubusercontent.com/rstudio/hex-stickers/master/PNG/parsnip.png" width="20%" /> --- class: middle, frame # .center[To specify a model with parsnip] .right-column[ 1\. Pick a .display[model] 2\. Set the .display[engine] 3\. Set the .display[mode] (if needed) ] --- class: middle, center # 1\. Pick a .display[model] All available models are listed at <tidymodels.github.io/parsnip/articles/articles/Models.html> <iframe src="https://tidymodels.github.io/parsnip/articles/articles/Models.html" width="504" height="400px"></iframe> --- class: middle, center # 2\. Set the .display[engine] We'll use `rpart` for building `C`lassification `A`nd `R`egression `T`rees ```r set_engine("rpart") ``` --- class: middle, center # 3\. Set the .display[mode] A character string for the model type (e.g. "classification" or "regression") ```r set_mode("classification") ``` --- class: middle, frame # .center[To specify a model with parsnip] ```r decision_tree() %>% set_engine("rpart") %>% set_mode("classification") ``` --- class: middle # .center[`fit_split()`] .center[.fade[Trains and tests a model with split data. Returns a tibble.]] ```r fit_split( formula, model, split ) ``` --- class: your-turn # Your turn 4 Fill in the blanks. Use the `tree_spec` model provided and `fit_split()` to: 1.
Train a CART-based model with the formula `remote ~ years_coded_job + salary`. 1. Predict remote status with the testing data. 1. Remind yourself what the output looks like! 1. Keep `set.seed(100)` at the start of your code.
03:00
--- ```r tree_spec <- decision_tree() %>% set_engine("rpart") %>% set_mode("classification") set.seed(100) # Important! tree_fit <- fit_split(remote ~ years_coded_job + salary, model = tree_spec, split = so_split) tree_fit # # Monte Carlo cross-validation (0.75/0.25) with 1 resamples # A tibble: 1 x 6 splits id .metrics .notes .predictions .workflow * <list> <chr> <list> <list> <list> <list> 1 <split [864/… train/test … <tibble [2 ×… <tibble [0… <tibble [286 ×… <workflo… ``` --- class: middle, center # Volunteer How can we expand a list column to see what is in it? -- `tidyr::unnest()` .footnote[https://tidyr.tidyverse.org/reference/unnest.html] --- ```r tree_fit %>% unnest(.predictions) # A tibble: 286 x 10 splits id .metrics .notes .pred_Remote `.pred_Not remo… .row .pred_class <list> <chr> <list> <list> <dbl> <dbl> <int> <fct> 1 <spli… trai… <tibble… <tibb… 0.687 0.313 2 Remote 2 <spli… trai… <tibble… <tibb… 0.687 0.313 12 Remote 3 <spli… trai… <tibble… <tibb… 0.368 0.632 15 Not remote 4 <spli… trai… <tibble… <tibb… 0.368 0.632 19 Not remote 5 <spli… trai… <tibble… <tibb… 0.687 0.313 23 Remote 6 <spli… trai… <tibble… <tibb… 0.368 0.632 28 Not remote 7 <spli… trai… <tibble… <tibb… 0.687 0.313 38 Remote 8 <spli… trai… <tibble… <tibb… 0.368 0.632 46 Not remote 9 <spli… trai… <tibble… <tibb… 0.687 0.313 53 Remote 10 <spli… trai… <tibble… <tibb… 0.368 0.632 56 Not remote # … with 276 more rows, and 2 more variables: remote <fct>, .workflow <list> ``` --- class: middle, center # `collect_predictions()` Unnest the predictions column from a tidymodels `fit_split()` ```r tree_fit %>% collect_predictions() ``` --- ```r tree_fit %>% collect_predictions() # A tibble: 286 x 6 id .pred_Remote `.pred_Not remote` .row .pred_class remote <chr> <dbl> <dbl> <int> <fct> <fct> 1 train/test split 0.687 0.313 2 Remote Remote 2 train/test split 0.687 0.313 12 Remote Remote 3 train/test split 0.368 0.632 15 Not remote Remote 4 train/test split 0.368 0.632 19 Not remote Remote 5
train/test split 0.687 0.313 23 Remote Remote 6 train/test split 0.368 0.632 28 Not remote Remote 7 train/test split 0.687 0.313 38 Remote Remote 8 train/test split 0.368 0.632 46 Not remote Remote 9 train/test split 0.687 0.313 53 Remote Remote 10 train/test split 0.368 0.632 56 Not remote Remote # … with 276 more rows ``` --- class: middle, center, frame # Goal of Machine Learning ## 🔨 construct .display[models] that .fade[ ## 🔮 generate accurate .display[predictions] ## 🆕 for .display[future, yet-to-be-seen data] ] --- class: middle, center, frame # Goal of Machine Learning .fade[ ## 🔨 construct .display[models] that ## 🔮 generate accurate .display[predictions] ] ## 🆕 for .display[future, yet-to-be-seen data] --- class: middle, center, frame # Goal of Machine Learning .fade[ ## 🔨 construct .display[models] that ] ## 🔮 generate accurate .display[predictions] .fade[ ## 🆕 for .display[future, yet-to-be-seen data] ] --- class: middle, center, frame # Goal of Machine Learning .fade[ ## 🔨 construct .display[models] that ] ## 🎯 generate .display[accurate predictions] .fade[ ## 🆕 for .display[future, yet-to-be-seen data] ] --- class: your-turn # Your turn 5 Use `collect_predictions()` and `count()` to count the number of individuals (i.e., rows) by their true and predicted remote status. In groups, answer the following questions: 1. How many predictions did we make? 2. How many times is "remote" status predicted? 3. How many respondents are actually remote? 4. How many predictions did we get right? *Hint: You can create a 2x2 table using* `count(var1, var2)`
05:00
--- ```r tree_fit %>% collect_predictions() %>% count(.pred_class, truth = remote) # A tibble: 4 x 3 .pred_class truth n <fct> <fct> <int> 1 Remote Remote 89 2 Remote Not remote 40 3 Not remote Remote 54 4 Not remote Not remote 103 ``` --- class: middle, center # `conf_mat()` Creates confusion matrix, or truth table, from a data frame with observed and predicted classes. ```r conf_mat(data, truth = remote, estimate = .pred_class) ``` --- ```r tree_fit %>% collect_predictions() %>% conf_mat(truth = remote, estimate = .pred_class) Truth Prediction Remote Not remote Remote 89 40 Not remote 54 103 ``` --- ```r tree_fit %>% collect_predictions() %>% conf_mat(truth = remote, estimate = .pred_class) %>% autoplot(type = "heatmap") ``` <img src="02-Classifying_files/figure-html/unnamed-chunk-70-1.png" width="40%" style="display: block; margin: auto;" /> --- class: middle, center # Confusion matrix <img src="images/conf-matrix/conf-matrix.001.jpeg" width="853" /> --- class: middle, center # Confusion matrix <img src="images/conf-matrix/conf-matrix.002.jpeg" width="853" /> --- class: middle, center # Confusion matrix <img src="images/conf-matrix/conf-matrix.003.jpeg" width="853" /> --- class: middle, center # Confusion matrix <img src="images/conf-matrix/conf-matrix.004.jpeg" width="853" /> --- class: middle, center # Accuracy <img src="images/conf-matrix/conf-matrix.007.jpeg" width="853" /> --- class: middle, center # Accuracy <img src="images/conf-matrix/conf-matrix.008.jpeg" width="853" /> --- class: middle, center # Accuracy <img src="images/conf-matrix/conf-matrix.009.jpeg" width="853" /> --- class: center background-image: url(images/conf-matrix/sens-spec.jpeg) background-size: 80% background-position: bottom # Sensitivity vs. 
Specificity --- name: sens class: center background-image: url(images/conf-matrix/sens.jpeg) background-size: 80% background-position: bottom # Sensitivity --- template: sens .pull-right[ True positive rate *Out of all **actual positives**, how many did you predict correctly?* ] --- name: spec class: center background-image: url(images/conf-matrix/spec.jpeg) background-size: 80% background-position: bottom # Specificity --- template: spec .pull-left[ True negative rate *Out of all **actual negatives**, how many did you predict correctly?* ] --- class: middle, center # `collect_metrics()` Unnest the metrics column from a tidymodels `fit_split()` ```r tree_fit %>% collect_metrics() ``` --- ```r tree_fit %>% collect_metrics() # A tibble: 2 x 3 .metric .estimator .estimate <chr> <chr> <dbl> 1 accuracy binary 0.671 2 roc_auc binary 0.678 ``` --- class: middle, center <iframe src="https://tidymodels.github.io/yardstick/articles/metric-types.html#metrics" width="504" height="400px"></iframe> <https://tidymodels.github.io/yardstick/articles/metric-types.html#metrics> --- class: middle # .center[`fit_split()`] .center[.fade[Trains and tests a model with split data. Returns a tibble.]] ```r fit_split( formula, model, split, * metrics = NULL ) ``` If `NULL`, defaults to `accuracy` and `roc_auc` when mode = "classification" --- class: middle, center # `metric_set()` A helper function for selecting yardstick metric functions. ```r metric_set(accuracy, sens, spec) ``` -- .footnote[Warning! Make sure you load `tidymodels` *after* `tidyverse`, as the `yardstick::spec` function has a name conflict.] --- class: middle # .center[`fit_split()`] .center[.fade[Trains and tests a model with split data.
Returns a tibble.]] ```r fit_split( formula, model, split, * metrics = metric_set(accuracy, sens, spec) ) ``` --- ```r fit_split(remote ~ years_coded_job + salary, model = tree_spec, split = so_split, metrics = metric_set(accuracy, sens, spec)) %>% collect_metrics() # A tibble: 3 x 3 .metric .estimator .estimate <chr> <chr> <dbl> 1 accuracy binary 0.671 2 sens binary 0.622 3 spec binary 0.720 ``` --- ```r fit_split(remote ~ years_coded_job + salary, model = tree_spec, split = so_split, metrics = metric_set(accuracy, roc_auc)) %>% collect_metrics() # A tibble: 2 x 3 .metric .estimator .estimate <chr> <chr> <dbl> 1 accuracy binary 0.671 2 roc_auc binary 0.678 ``` --- ```r fit_split(remote ~ years_coded_job + salary, model = tree_spec, split = so_split) %>% collect_metrics() # A tibble: 2 x 3 .metric .estimator .estimate <chr> <chr> <dbl> 1 accuracy binary 0.671 2 roc_auc binary 0.678 ``` --- class: middle, center # `roc_curve()` Takes predictions from `fit_split()`. Returns a tibble with sensitivity and specificity at each threshold. ```r roc_curve(data, truth = remote, estimate = .pred_Remote) ``` Truth = the .display[true] class Estimate = the .display[probability] of the target response --- ```r fit_split(remote ~ years_coded_job + salary, model = tree_spec, split = so_split) %>% collect_predictions() %>% roc_curve(truth = remote, estimate = .pred_Remote) # A tibble: 5 x 3 .threshold specificity sensitivity <dbl> <dbl> <dbl> 1 -Inf 0 1 2 0.368 0 1 3 0.6 0.720 0.622 4 0.687 0.762 0.573 5 Inf 1 0 ``` --- class: your-turn # Your turn 6 Use `collect_predictions()` and `roc_curve()` to calculate the data needed to construct the full ROC curve. What is the threshold for achieving specificity > .75?
--- ```r tree_fit <- fit_split(remote ~ years_coded_job + salary, model = tree_spec, split = so_split) tree_fit %>% collect_predictions() %>% roc_curve(truth = remote, estimate = .pred_Remote) # A tibble: 5 x 3 .threshold specificity sensitivity <dbl> <dbl> <dbl> 1 -Inf 0 1 2 0.368 0 1 3 0.6 0.720 0.622 4 0.687 0.762 0.573 5 Inf 1 0 ``` --- .pull-left[ ```r tree_fit %>% collect_predictions() %>% roc_curve(truth = remote, estimate = .pred_Remote) %>% ggplot(aes(x = 1 - specificity, y = sensitivity)) + geom_line( color = "midnightblue", size = 1.5 ) + geom_abline( lty = 2, alpha = 0.5, color = "gray50", size = 1.2 ) ``` ] .pull-right[ <img src="02-Classifying_files/figure-html/unnamed-chunk-90-1.png" width="504" /> ] --- ```r tree_fit %>% collect_predictions() %>% roc_curve(truth = remote, estimate = .pred_Remote) %>% autoplot() ``` <img src="02-Classifying_files/figure-html/unnamed-chunk-91-1.png" width="40%" style="display: block; margin: auto;" /> --- ## Area under the curve .pull-left[ <img src="images/roc-pretty-good.png" width="1037" /> ] .pull-right[ * AUC = 0.5: random guessing * AUC = 1: perfect classifier * In general, an AUC above 0.8 is considered "good" ] --- <img src="images/roc-guessing.png" width="80%" /> --- <img src="images/roc-perfect.png" width="80%" /> --- <img src="images/roc-bad.png" width="80%" /> --- <img src="images/roc-ok.png" width="80%" /> --- <img src="images/roc-pretty-good.png" width="80%" /> --- class: your-turn # Your turn 7 Add `autoplot()` to visualize the ROC AUC.
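---

One possible answer to Your turn 7 mirrors the `autoplot()` call shown a few slides back: pipe the output of `roc_curve()` into `autoplot()`. As a sanity check on what `roc_curve()` reports, its 0.6-threshold row can be reproduced by hand from the `conf_mat()` counts earlier — a sketch in base R:

```r
# Confusion-matrix counts from conf_mat() above
tp <- 89   # predicted Remote,     actually Remote
fp <- 40   # predicted Remote,     actually Not remote
fn <- 54   # predicted Not remote, actually Remote
tn <- 103  # predicted Not remote, actually Not remote

sens <- tp / (tp + fn)  # 0.622, matches the roc_curve() row
spec <- tn / (tn + fp)  # 0.720
```

Each threshold on `.pred_Remote` induces its own confusion matrix, and sensitivity/specificity follow from its counts; sweeping the threshold traces out the full ROC curve.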