Skip to content

Commit

Permalink
reorganize modelling folder and start vignettes
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeSydlowski authored and tanho63 committed Jan 16, 2022
1 parent 8e68373 commit 0b3bdc6
Show file tree
Hide file tree
Showing 24 changed files with 204 additions and 57 deletions.
21 changes: 21 additions & 0 deletions modelling/dalex_from_xgboost_ex.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# testing xgb dalex -------------------------------------------------------
# Builds the DALEX explainer directly from the raw xgboost booster instead of
# going through the tidymodels fit. Every call is namespace-qualified so the
# script runs without any prior library() calls.

# `.load_model_objs()` is reached with `:::`, i.e. as an internal helper of
# ffopportunity. The returned object is used below through its `$model` and
# `$blueprint` elements.
model_obj <- ffopportunity:::.load_model_objs("rushing_yards", version = "latest")

# Set the booster's objective field before handing it to DALEXtra.
# NOTE(review): presumably DALEXtra inspects this field to decide the model
# type - confirm against the DALEXtra version in use.
model_obj$model$params$objective <- "reg"

preprocessed_pbp <- ffopportunity::ep_preprocess(nflreadr::load_pbp(2021))

# Apply the stored blueprint to the rushing plays so the columns match what
# the booster expects.
rush_df <- hardhat::forge(
  new_data = preprocessed_pbp$rush_df,
  blueprint = model_obj$blueprint
)

# The explainer is given the forged predictors as a plain matrix.
ep_load <- as.matrix(rush_df$predictors)

rush_yards_explainer <-
  DALEXtra::explain_xgboost(
    model = model_obj$model,
    data = ep_load,
    # data = preprocessed_pbp$rush_df %>% as.matrix(),
    y = preprocessed_pbp$rush_df$rushing_yards
  )
60 changes: 60 additions & 0 deletions modelling/dalex_plot_helpers.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# gg helpers ---------------------------------------------------------------

# Box-plot of permutation-based variable importance for one or more inputs.
#
# ... : one or more data-frame-like objects carrying a "loss_name" attribute
#       and the columns `variable`, `label` and `dropout_loss`
#       (presumably DALEX::model_parts() results - confirm with callers).
#
# Returns a ggplot object. With several inputs the plot is faceted and
# coloured by `label`; with a single input one fixed fill colour is used.
ggplot_imp <- function(...) {
  vip_list <- list(...)

  # The x-axis label is derived from the loss name of the first input only.
  x_label <- paste(
    attr(vip_list[[1]], "loss_name"),
    "after permutations\n(higher indicates more important)"
  )

  # Stack every input and discard the "_baseline_" rows.
  vip_all <- filter(bind_rows(vip_list), variable != "_baseline_")

  # Mean loss of the "_full_model_" rows per label, drawn as a dashed
  # vertical reference line.
  full_model_loss <- summarise(
    group_by(filter(vip_all, variable == "_full_model_"), label),
    dropout_loss = mean(dropout_loss)
  )

  # Remaining variables, ordered by their dropout loss.
  vip_vars <- mutate(
    filter(vip_all, variable != "_full_model_"),
    variable = fct_reorder(variable, dropout_loss)
  )

  layers <- if (length(vip_list) > 1) {
    list(
      facet_wrap(vars(label)),
      geom_vline(data = full_model_loss,
                 aes(xintercept = dropout_loss, color = label),
                 size = 1.4, lty = 2, alpha = 0.7),
      geom_boxplot(aes(color = label, fill = label), alpha = 0.2)
    )
  } else {
    list(
      geom_vline(data = full_model_loss,
                 aes(xintercept = dropout_loss),
                 size = 1.4, lty = 2, alpha = 0.7),
      geom_boxplot(fill = "#91CBD765", alpha = 0.4)
    )
  }

  # Adding a list to a ggplot adds each element in turn.
  ggplot(vip_vars, aes(dropout_loss, variable)) +
    layers +
    theme(legend.position = "none") +
    labs(x = x_label, y = NULL, fill = NULL, color = NULL)
}

# Profile plot: faint per-observation lines with the aggregated profile(s)
# drawn on top.
#
# obj : object with `agr_profiles` and `cp_profiles` elements
#       (presumably a DALEX::model_profile() result - confirm with callers).
# x   : unquoted name of the column in `cp_profiles` to put on the x-axis.
#
# Returns a ggplot object.
ggplot_pdp <- function(obj, x) {
  # Aggregated profiles, with everything up to and including the first
  # underscore stripped from `_label_`.
  aggregated <- mutate(
    as_tibble(obj$agr_profiles),
    `_label_` = stringr::str_remove(`_label_`, "^[^_]*_")
  )
  individual <- as_tibble(obj$cp_profiles)

  # The label count is taken from the untouched labels in `obj`.
  label_count <- n_distinct(obj$agr_profiles$`_label_`)

  # One coloured line per label, or a single fixed-colour line.
  aggregate_layer <- if (label_count > 1) {
    geom_line(aes(color = `_label_`), size = 1.2, alpha = 0.8)
  } else {
    geom_line(color = "midnightblue", size = 1.2, alpha = 0.8)
  }

  ggplot(aggregated, aes(`_x_`, `_yhat_`)) +
    geom_line(data = individual,
              aes(x = {{ x }}, group = `_ids_`),
              size = 0.5, alpha = 0.05, color = "gray50") +
    # facet_wrap(~`_groups_`)
    aggregate_layer
}
156 changes: 99 additions & 57 deletions modelling/save_dalex_rush_yards.R
Original file line number Diff line number Diff line change
@@ -1,73 +1,115 @@
library(dplyr)
library(DALEXtra)
library(here)
library(tidymodels)
library(tidyverse)

# Work from the project root: every path in this script is relative to it.
setwd(here())

# Provides ggplot_imp() and ggplot_pdp().
source("./modelling/dalex_plot_helpers.R")

# Model object saved by an earlier step (presumably a fitted tidymodels
# workflow, given that it is passed to explain_tidymodels() - confirm).
rush_yards_tidymodel <- readRDS("./modelling/fit_rush_yards.RDS")

preprocessed <- ffopportunity::ep_preprocess(nflreadr::load_pbp(2006:2021))

# The outcome column is removed from `data` and supplied separately as `y`.
rush_yards_explainer <-
  DALEXtra::explain_tidymodels(
    rush_yards_tidymodel,
    data = dplyr::select(preprocessed$rush_df, -rushing_yards),
    y = preprocessed$rush_df$rushing_yards
  )

plot(feature_importance(rush_yards_explainer))

# Profile of rusher_age, split by position.
pdp_time <-
  model_profile(
    rush_yards_explainer,
    variables = "rusher_age",
    groups = "position"
  )

plot(pdp_time)


# testing xgb dalex -------------------------------------------------------

# Load a stored xgboost model and its companion .rds object from the
# ffopportunity user cache.
#
# variable : file stem; "<variable>.xgb" and "<variable>.rds" are read.
# version  : name of the sub-folder of the cache directory to read from.
#
# Stops if either file is missing. Returns list(model = , blueprint = ).
.load_model_objs <- function(variable, version) {
  version_dir <- file.path(
    rappdirs::user_cache_dir("ffopportunity", "ffverse"),
    version
  )

  model_path <- file.path(version_dir, paste0(variable, ".xgb"))
  blueprint_path <- file.path(version_dir, paste0(variable, ".rds"))

  # Both files must be present before anything is read.
  stopifnot(file.exists(model_path), file.exists(blueprint_path))

  list(
    model = xgboost::xgb.load(model_path),
    blueprint = readRDS(blueprint_path)
  )
}



# Load the model/blueprint pair stored for "rushing_yards", version "1.0.0".
model_obj <- .load_model_objs("rushing_yards", version = "1.0.0")

# Set the booster's objective field before building the explainer.
# NOTE(review): presumably DALEXtra reads this field to decide the model
# type - confirm.
model_obj$model$params$objective <- "reg"

preprocessed_pbp <- ep_preprocess(nflreadr::load_pbp(2021))

# Apply the stored blueprint to the 2021 rushing plays.
rush_df <-
preprocessed_pbp$rush_df %>%
hardhat::forge(new_data = ., blueprint = model_obj$blueprint)

# Forged predictors with the observed rushing_yards appended as an extra
# column, converted to a matrix.
ep_load <- rush_df$predictors %>%
dplyr::mutate(rushing_yards = preprocessed_pbp$rush_df$rushing_yards) %>%
as.matrix()

# NOTE(review): this assignment overwrites the `ep_load` built just above, so
# that value is never used. `load_ep_pbp_rush()` is not defined in this
# script - confirm which of the two data sources is intended.
ep_load <- load_ep_pbp_rush(2018:2021) %>%
as.matrix()

# Rebinds `rush_yards_explainer`, this time around the raw xgboost model.
# NOTE(review): the outcome is taken by position (column 14) and `data` is
# the full matrix including that column - confirm that column 14 is
# rushing_yards and that passing it as a predictor is intended.
rush_yards_explainer <-
DALEXtra::explain_xgboost(
model = model_obj$model,
data = ep_load,
y = ep_load[,14])
mod_parts <- DALEX::model_parts(rush_yards_explainer)

# Mean dropout loss per variable, sent to view() for inspection
# (presumably the interactive data viewer - not meant for batch runs).
mod_parts %>%
group_by(variable) %>%
summarise(mean_dropout_loss = mean(dropout_loss, na.rm = TRUE)) %>%
ungroup() %>%
view()

# Write the feature-importance chart used by the vignette.
png("./vignettes/plots/rush_yards_feat_imp.png", width = 1000, height = 592)
# A ggplot is only drawn when it is printed. Top-level auto-printing does not
# happen when this file is run through source(), which would leave the PNG
# empty, so the plot is printed explicitly.
print(
  ggplot_imp(
    mod_parts %>%
      # Keep the two reference rows plus a hand-picked set of variables.
      filter(variable %in% c("_baseline_",
                             "_full_model_",
                             "yardline_100",
                             "run_gap",
                             "xpass",
                             "position",
                             "qb_dropback",
                             "ydstogo",
                             "half_seconds_remaining",
                             "vegas_wp",
                             "game_seconds_remaining",
                             "implied_total"))
  ) +
    tantastic::theme_uv() +
    labs(title = "Feature Importance for Expected Rushing Yards",
         subtitle = "Distance to Endzone, Position, Expected Pass Rate, and Run Gap are the most important factors")
)
dev.off()

# Profile of yardline_100, grouped by run_gap, computed with N = 500.
pdp_yds <- DALEX::model_profile(
  rush_yards_explainer,
  groups = "run_gap",
  N = 500,
  variables = "yardline_100"
)

png("./vignettes/plots/rush_yards_pdp_yards.png", width = 1000, height = 592)
# Printed explicitly: a ggplot at top level is not drawn when the script is
# run through source(), which would leave the PNG empty.
print(
  ggplot_pdp(pdp_yds, yardline_100) +
    tantastic::theme_uv() +
    labs(title = "How do Expected Yards change with distance to the end zone?",
         subtitle = "Outside rushes have higher expected yardage until you get inside the 5",
         x = "Yards from End Zone",
         y = "Expected Yardage",
         color = "Run Gap") +
    scale_x_continuous(breaks = seq(0, 100, 10)) +
    scale_y_continuous(breaks = seq(0, 10, 2))
)
dev.off()

# Profile of xpass, grouped by position, computed with N = 500.
pdp_xpass <- DALEX::model_profile(
  rush_yards_explainer,
  groups = "position",
  N = 500,
  variables = "xpass"
)

png("./vignettes/plots/rush_yards_pdp_xpass.png", width = 1000, height = 592)
# Printed explicitly: a ggplot at top level is not drawn when the script is
# run through source(), which would leave the PNG empty.
print(
  ggplot_pdp(pdp_xpass, xpass) +
    tantastic::theme_uv() +
    labs(title = "How do Expected Yards change with likelihood to pass?",
         subtitle = "All positions can expect higher rushing yards when the offense is more likely to pass",
         x = "Expected Pass Rate",
         y = "Expected Yardage",
         color = "Position") +
    scale_x_continuous(breaks = seq(0, 1, .2)) +
    scale_y_continuous(breaks = seq(0, 10, 2))
)
dev.off()

# Per-variable breakdown of the prediction for one hand-picked play. Numeric
# columns are rounded to 3 decimals before the play is selected.
example_rush <- DALEX::predict_parts(
  rush_yards_explainer,
  new_observation =
    preprocessed$rush_df %>%
    mutate(across(where(is.numeric), ~round(.x, 3))) %>%
    filter(play_id == 3553, game_id == "2021_18_SEA_ARI")
)
# new_observation =
#   preprocessed$rush_df %>%
#   mutate(across(where(is.numeric), ~round(.x,3))) %>%
#   slice_sample(n = 1))

png("./vignettes/plots/rush_yards_breakdown.png", width = 1000, height = 592)
# Printed explicitly: a ggplot at top level is not drawn when the script is
# run through source(), which would leave the PNG empty.
print(
  plot(example_rush,
       digits = 2,
       vcolors = c("purple", "darkgreen", "black"),
       max_features = 10) +
    tantastic::theme_uv() +
    theme(legend.position = "none",
          panel.grid.minor.y = element_blank(),
          strip.text = element_blank()) +
    # The subtitle is the play's game_id and desc values, read back out of
    # the breakdown table.
    labs(title = "How does each component affect the predicted yards?",
         subtitle = paste(
           example_rush %>% filter(variable_name == "game_id") %>% pull(variable_value),
           example_rush %>% filter(variable_name == "desc") %>% pull(variable_value))) +
    # The axis breaks are hard-coded for the play selected above.
    scale_y_continuous(name = "Expected Yards", breaks = seq(4.2, 4.9, 0.1))
)
dev.off()

# model_perf <- DALEX::model_diagnostics(rush_yards_explainer)
#
# plot(model_perf, variable = "y", yvariable = "y_hat") +
# geom_abline(colour = "red", intercept = 0, slope = 1) +
# xlim(0,20)
#
# plot(model_perf, variable = "y_hat", yvariable = "residuals")
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Binary file added vignettes/plots/rush_yards_breakdown.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added vignettes/plots/rush_yards_feat_imp.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added vignettes/plots/rush_yards_pdp_xpass.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added vignettes/plots/rush_yards_pdp_yards.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
24 changes: 24 additions & 0 deletions vignettes/rushing_yards_vignette.Rmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
---
title: "Expected Rushing Yards"
author: "Joe Sydlowski"
date: "`r Sys.Date()`"
output: rmarkdown::html_vignette
vignette: >
  %\VignetteIndexEntry{Expected Rushing Yards}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---

### Feature Importance

![feature importance](plots/rush_yards_feat_imp.png)

### Partial Dependence Plots

![pdp yards](plots/rush_yards_pdp_yards.png)

![pdp xpass](plots/rush_yards_pdp_xpass.png)

### Breakdown

![breakdown](plots/rush_yards_breakdown.png)


0 comments on commit 0b3bdc6

Please sign in to comment.