Listing Mortality Prediction

Author

Michael D. Porter

Published

July 23, 2024

Data

  1. Load training and test data
Show the code
data_train = read_rds(file.path(dir_data, "model_data_train.rds")) 
data_test = read_rds(file.path(dir_data, "model_data_test.rds")) 
  2. Clean Data
    • Remove outliers in Creatinine (and eGFR).
    • Remove outliers in median_refusals.
    • Impute missing height, weight, BMI, and BSA using Age and Gender.
    • Remove lengthy Functional Status descriptions; keep only the percentage.
Show the code
#: table of sizes from data (by AGE and GENDER)
impute_size_data = bind_rows(data_train, data_test) %>% 
  group_by(AGE, GENDER) %>% 
  summarize(
    WEIGHT_KG = median(WEIGHT_KG, na.rm=TRUE), 
    HEIGHT_CM = median(HEIGHT_CM, na.rm=TRUE), 
    .groups = "drop"
  )
#: function to impute missing height and weight (using AGE, GENDER)
impute_size <- function(var, AGE, GENDER){
  var = match.arg(var, c("HEIGHT_CM", "WEIGHT_KG"))
  X = tibble(AGE, GENDER) %>% 
    left_join(impute_size_data, by = c("AGE", "GENDER"))
  if(var == "HEIGHT_CM") X$HEIGHT_CM else X$WEIGHT_KG
}

# outliers in CREAT and median refusals
clean_data <- function(df){
  df %>% 
    mutate(
      outlier = MOST_RCNT_CREAT > 8,
      eGFR = ifelse(outlier, eGFR*MOST_RCNT_CREAT/8, eGFR),
      MOST_RCNT_CREAT = pmin(MOST_RCNT_CREAT, 8)
    ) %>% 
    select(-outlier) %>% 
    #: outliers in median refusals
    mutate(
      median_refusals = pmin(median_refusals, 20)
    ) %>% 
    #: impute height, weight, and bmi
    mutate(
      HEIGHT_CM = coalesce(HEIGHT_CM, impute_size("HEIGHT_CM", AGE, GENDER)),
      WEIGHT_KG = coalesce(WEIGHT_KG, impute_size("WEIGHT_KG", AGE, GENDER)),
      BMI = coalesce(BMI, WEIGHT_KG / (HEIGHT_CM/100)^2),
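      # BSA (m^2) imputed with the Mosteller formula: sqrt(height_cm * weight_kg / 3600)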
      BSA = coalesce(BSA, sqrt(HEIGHT_CM * WEIGHT_KG / 3600)),
    ) %>% 
    #: remove Functional Status Descriptions; only keep %
    mutate(
      FUNC_STAT_CAND_REG = str_replace(FUNC_STAT_CAND_REG, "(\\d+%).+", "\\1")
    )
}

Update data.

Show the code
data_train = data_train %>% clean_data()
data_test = data_test %>% clean_data()
  3. Create 10-fold CV folds of the training data.
Show the code
set.seed(2024)
cv_folds = rsample::vfold_cv(data_train, v = 10, strata = outcome)

Predictive Modeling

Create baseline preprocessing recipe and set predictor variables. All models start with this recipe.

  • Sets predictors and outcome variables
  • Converts Diabetes to {Yes, No}
  • Converts Functional Status to numeric; adds indicators for infants (< 1 year) and unknown status
  • Cleans eGFR: truncates outliers, adds a coded version based on chronic kidney disease stage, and adds a binary indicator for eGFR < 60
  • Creates a binary indicator for Albumin < 3
Show the code
library(tidymodels)

base_rec = 
  #: Set formula
  recipe(outcome ~ ., data = head(data_train)) %>%
  # step_naomit(outcome) %>%       # remove rows with missing outcome
  step_mutate(outcome = factor(outcome, levels = c(1, 0)), skip = TRUE) %>%  # skip: applied during prep only, not when baking new data
  #: Remove variables from all models
  step_rm(WL_ID_CODE) %>% 
  step_rm(
    matches("DONCRIT_.+_AGE"), 
    matches("DONCRIT_.+_HGT"),  
    matches("DONCRIT_.+_WGT"), 
    matches("DONCRIT_.+_MILE"), 
    matches("DONCRIT_"), # remove all DONCRIT variables. They are primarily
    # recorded for a few UNOS REGIONS. 
    # --- not clinical
    CITIZENSHIP, 
    HEMODYNAMICS_CO, # lots of missing
    CEREB_VASC, 
    # ---
    CAND_DIAG_LISTING, # use CAND_DIAG instead
    CAND_DIAG_CODE,    # use CAND_DIAG instead 
    MOST_RCNT_CREAT,   # use eGFR instead
    VAD_DEVICE_TY_TCR, # not enough info
    WL_DT,             # use LIST_YR for temporal information
    #------------------------------------- Substitutes
    # LC_effect,
    p_refusals,          # use median_refusals instead
    #-------------------------------------
    # LISTING_CTR_CODE, # Let individual models choose
    REGION,           # Some regions only have 1-2 centers
    LIFE_SUPPORT_OTHER,
    PGE_TCR,
    LIST_YR,
  ) %>% 
  #: Additional cleaning
  step_mutate(
    # convert Diabetes to Yes = 1, No = 0
    DIAB = case_match(DIAB, 
                      "None" ~ 0L, 
                      "Unknown" ~ 0L, 
                      .default = 1L)
  ) %>% 
  #: Convert Functional status to numeric; add indicators for missing and baby
  step_mutate(
    FUNC_STAT_NUM = str_extract(FUNC_STAT_CAND_REG, "(\\d+)%", group = 1) %>%
      as.numeric() %>% coalesce(0),
    FUNC_STAT_UNKNOWN = ifelse(FUNC_STAT_CAND_REG == "Unknown", 1L, 0L),
    CAND_UNDER_1 = ifelse(FUNC_STAT_CAND_REG == "Not Applicable (patient < 1 year old)", 1L, 0L),
  ) %>% 
  step_rm(FUNC_STAT_CAND_REG) %>% # remove the original variable
  #: Cutoffs for eGFR (Kidney Disease) and Albumin (Nutrition)
  step_mutate(
    eGFR = pmin(eGFR, 250),  # fix outliers for non-tree models
    eGFR_CODED = case_when(
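      # cut-points follow the KDIGO CKD stages: G1 >= 90, G2 60-89, G3a 45-59, G3b 30-44, G4 15-29, G5 < 15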
      eGFR > 120 ~ 0,
      eGFR >= 90 ~ 1, 
      eGFR >= 60 ~ 2,
      eGFR >= 45 ~ 3, 
      eGFR >= 30 ~ 3.5, 
      eGFR >= 15 ~ 4, 
      eGFR <  15 ~ 5, 
      .default = 0),  # if missing, assume eGFR is good
    eGFR_UNDER_60 = ifelse(eGFR < 60, 1, 0) %>% 
      coalesce(0), # if missing, assume eGFR is good (above 60)
    ALBUM_UNDER_3 = ifelse(TOT_SERUM_ALBUM < 3, 1, 0) %>% 
      coalesce(0) # if missing, assume Albumin is good (above 3)
  )
  #: outliers in median refusals
  # step_mutate(median_refusals = pmin(median_refusals, 20)) 

Logistic Regression

1. Tidymodels specification

Lasso logistic regression model.

  1. Keep LISTING_CTR_CODE (its removal is commented out, so the model can use listing center)
  2. Add new missing indicator feature for all variables with missing
  3. Convert Functional Status to number {0, 10, …, 100}. Add indicator for Unknown status.
  4. One-hot encode all categorical predictors. For variables with {Yes, No, Unknown}, only keep the Yes column. This lumps the Unknown with No. 
  5. Truncate median_refusals to 20.
  6. Impute all missing values with median.
  7. Create new features by coding eGFR into stages of chronic kidney failure.
  8. Create binary TOT_SERUM_ALBUM < 3 indicator.
  9. Add polynomial terms for the numeric features.
Show the code
library(tidymodels)

# Model specification: Lasso penalized logistic regression
lasso_spec = 
  logistic_reg() %>%
  set_engine("glmnet") %>%
  set_args(
    mixture = 1,    # 1 = lasso, 0 = ridge
    penalty = tune()
  ) 
  
# Recipe:
lasso_rec = 
  base_rec %>% 
  #: Remove additional variables for this model
  step_rm(
    # LISTING_CTR_CODE,
  ) %>%
  #: Add additional variables to represent missing predictors
  step_indicate_na(all_predictors()) %>% 
  #: Convert categorical predictors to dummy 
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>% # one-hot
  step_dummy(all_ordered_predictors(), one_hot = TRUE) %>% # one-hot
  step_rm(ends_with("_No")) %>%       # removes the No level (Yes is binary)
  step_rm(ends_with("_Unknown")) %>%  # removes the Unknown level (Yes is binary).
                                # This effectively treats "Unknown" same as "No"
  #: Impute missing values
  step_impute_median(all_numeric_predictors()) %>% 
  #: Add polynomial terms
  step_poly(
    AGE, WEIGHT_KG, HEIGHT_CM, BMI, BSA,
    eGFR,
    # TOT_SERUM_ALBUM,
    FUNC_STAT_NUM,
    # pedhrtx_prev_yr, 
    # median_refusals,
    # LC_effect
    degree = 2,
  ) %>% 
  #: Remove all zero variance predictors (e.g., from step_indicate_na() )
  step_zv(all_predictors()) %>%      
  #: Scale all predictors for variable importance scoring
  step_scale(all_numeric_predictors()) 

# Workflow:
lasso_wflow = 
  workflow(preprocessor = lasso_rec, spec = lasso_spec)

2. Preprocess training data

Show the code
# Pre-process training data
lasso_rec_fitted = prep(lasso_rec, data_train)
X_train_lasso = bake(lasso_rec_fitted, new_data = NULL, 
               all_predictors(), composition = "matrix")
Y_train_lasso = bake(lasso_rec_fitted, new_data = NULL, 
               all_outcomes()) %>% pull()

3. Tune lambda using cross-validation

Tuning the \(\lambda\) (or penalty) parameter using cross-validation to maximize the AUC.
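
For reference, with mixture = 1 (pure lasso) the glmnet fit minimizes the penalized negative log-likelihood, roughly (in glmnet's parameterization)
\[
\min_{\beta_0,\,\beta} \; -\frac{1}{n}\sum_{i=1}^{n} \log p\!\left(y_i \mid \beta_0 + x_i^\top \beta\right) + \lambda \lVert \beta \rVert_1,
\]
so larger \(\lambda\) forces more coefficients exactly to zero. cv.glmnet() reports both the AUC-maximizing lambda.min and the more parsimonious lambda.1se.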

Show the code
# Create 10 fold cv indices
# set.seed(100)
# folds = sample(rep(1:10, length = nrow(X_train_lasso)))

# Map the rsample folds to a glmnet-style foldid vector:
# for each training row, record which fold holds it out (its assessment set)
folds = cv_folds %>% broom::tidy() %>% filter(Data == "Assessment") %>%
  arrange(Row) %>% mutate(Fold = readr::parse_number(Fold)) %>% pull(Fold)

# Run 10-fold CV on training data to estimate lambda
library(glmnet)
cv_fit = cv.glmnet(
  X_train_lasso, Y_train_lasso,
  family = "binomial",
  alpha = 1,    # lasso
  relax = FALSE,
  foldid = folds,
  type.measure = "auc")

4. CV performance

Show the code
# plot(cv_fit)
cv_fit

Call:  cv.glmnet(x = X_train_lasso, y = Y_train_lasso, type.measure = "auc",      foldid = folds, relax = FALSE, family = "binomial", alpha = 1) 

Measure: AUC 

      Lambda Index Measure      SE Nonzero
min 0.003094    31  0.7478 0.01300      58
1se 0.015045    14  0.7370 0.01325      11

5. Test Performance

Show the code
X_test_lasso = bake(lasso_rec_fitted, new_data = data_test, 
              all_predictors(), composition = "matrix")

data_test %>% 
  transmute(
    WL_ID_CODE, 
    outcome = factor(outcome, c(1,0)), 
    lasso.min =  1-predict(cv_fit, X_test_lasso, type = "response", 
                           s = "lambda.min")[,1], 
    lasso.1se = 1-predict(cv_fit, X_test_lasso, type = "response", 
                          s = "lambda.1se")[,1],
    lasso.unpenalized = 1-predict(cv_fit, X_test_lasso, type = "response", 
                          s = 0)[,1],
  ) %>% 
  pivot_longer(starts_with("lasso")) %>% group_by(name) %>% 
  reframe(calc_metrics(outcome, value)) %>% 
  arrange(-auc) %>% 
  mutate_all(\(x) digits(x, 3))

6. Variable Importance

Show the code
vip::vi(cv_fit, lambda = cv_fit$lambda.min) %>% 
  filter(Importance > 1E-8) 
Show the code
# Plot at (1se or min) lambda 
vip::vi(cv_fit, lambda = cv_fit$lambda.min) %>%   # cv_fit$lambda.1se
  filter(Importance > 1E-8) %>% 
  mutate(
    Variable = Variable %>% 
      str_replace("_poly_1", " (linear)") %>% 
      str_replace("_poly_2", " (quadratic)") %>% 
      str_replace("_poly_3", " (cubic)") %>% 
      # str_replace_all("_", " ") %>% 
      str_replace("(CAND_DIAG)_(.+)", "\\1: \\2") %>% 
      str_replace_all("\\.", " ") %>% 
      str_wrap(30, whitespace_only = FALSE),
    Variable = fct_reorder(Variable, abs(Importance)),
    Importance = ifelse(Sign == "NEG", -Importance, Importance),
    ) %>% 
  ggplot(aes(Importance, Variable, color = Sign)) + 
  geom_point() + 
  geom_segment(aes(xend = 0, yend = Variable)) + 
  scale_color_brewer(type = "qual", palette = 2) + 
  labs(y=  "", title = "Predicting Waitlist Survival")

7. Additive Effects Plot

Show the code
#: Get final fit with tuned lambda/penalty
lasso_fit = lasso_wflow %>% 
  finalize_workflow(tibble(penalty = cv_fit$lambda.min)) %>% 
  fit(data_train)
Show the code
get_raw_variable_names <- function(x){
  x %>% 
      str_remove("na_ind_") %>% 
      str_remove("_poly_\\d") %>% 
      str_remove("_Unknown") %>% 
      str_remove("_Yes") %>% 
      str_replace("FUNC_STAT_NUM", "FUNC_STAT_CAND_REG") %>% 
      str_replace("ABO(_.+)", "ABO") %>% 
      str_replace("CAND_DIAG_CODE(_.+)", "CAND_DIAG_CODE") %>% 
      {ifelse(. == "CAND_DIAG_CODE", ., str_replace(.,"CAND_DIAG(_.+)", "CAND_DIAG"))} %>%
      str_replace("LIFE_SUPPORT_CAND_REG(_.+)", "LIFE_SUPPORT_CAND_REG") %>% 
      str_replace("eGFR(_.+)", "eGFR") %>% # since eGFR_CODED isn't in data
      str_replace("RACE(_.+)", "RACE") %>% 
      str_replace("GENDER(_.+)", "GENDER") %>%
      str_replace("STATUS(_.+)", "STATUS") %>%
      str_replace("LISTING_CTR_CODE(_X.+)", "LISTING_CTR_CODE") %>% 
      str_replace("ALBUM_UNDER_3",  "TOT_SERUM_ALBUM") %>% 
      str_replace("CAND_UNDER_1",  "AGE") %>% 
      str_replace("REGION(_X.+)", "REGION") %>% 
      str_replace("LISTING_CTR_CODE(_X.+)", "LISTING_CTR_CODE")
}
Show the code
imp_features = vip::vi(cv_fit, lambda = cv_fit$lambda.min) %>% 
  filter(Importance > 1E-8) %>% 
  mutate(var_raw = get_raw_variable_names(Variable)) %>% 
  mutate(.by = var_raw, total_importance = mean(Importance)) %>% 
  arrange(-total_importance, -Importance)

imp_vars = imp_features %>% distinct(var_raw) %>% pull()

walk(imp_vars, plot_additive_effects, model = lasso_fit)

GAM

1. Tidymodels recipe

Show the code
# ?details_gen_additive_mod_mgcv
library(mgcv)
# Recipe:
gam_rec = 
  base_rec %>% 
  #: Remove additional variables for this model
  # step_rm(
  #   LISTING_CTR_CODE,
  # ) %>% 
  #: Add additional variables to represent missing predictors
  step_indicate_na(all_predictors()) %>% 
  #: Impute missing values
  step_impute_median(all_numeric_predictors()) %>% 
  #: Remove all zero variance predictors (e.g., from step_indicate_na() )
  step_zv(all_predictors()) %>%      
  #: Scale all predictors for variable importance scoring
  step_scale(all_numeric_predictors()) 

#: train recipe and create X matrix
rec_fitted = prep(gam_rec, data_train)
data_train_gam = bake(rec_fitted, new_data = NULL)

2. Fit GAM

Show the code
fit_gam = mgcv::gam(outcome ~ 
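                # s() terms are thin-plate smooths for numeric predictors; bs = "re" gives random-effect smooths for factors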
                s(HEIGHT_CM) + 
                s(AGE) + 
                # s(AGE,HEIGHT_CM) + 
                # s(AGE, WEIGHT_KG) + 
                # s(BMI) + s(WEIGHT_KG) + 
                # s(GENDER, bs = "re") + 
                # s(HEIGHT_CM_PERC) +
                # s(WEIGHT_KG_PERC) + 
                s(RACE, bs = "re") + 
                # s(CITIZENSHIP, bs = "re") + 
              # s(STATUS, bs = "re") + 
                s(ABO, bs = "re") + 
                s(LIFE_SUPPORT_CAND_REG, bs = "re") +
                # s(LIFE_SUPPORT_OTHER, bs = "re") +
                # s(PGE_TCR, bs = "re") +
                ECMO_CAND_REG + #s(ECMO_CAND_REG, bs = "re") +
                s(VAD_CAND_REG, bs = "re") +
                VENTILATOR_CAND_REG +  #s(VENTILATOR_CAND_REG, bs = "re") +
               # s(FUNC_STAT_CAND_REG, bs = "re") +
                s(FUNC_STAT_NUM) + 
                s(FUNC_STAT_UNKNOWN, bs = "re") + 
                s(CAND_UNDER_1, bs = "re") + 
                # s(WL_OTHER_ORG, bs = "re") +
                # s(CEREB_VASC, bs = "re") +
                # s(DIAB, bs = "re") +
                s(DIALYSIS_CAND, bs = "re") +
                # s(HEMODYNAMICS_CO,  bs = "re") +
                # s(IMPL_DEFIBRIL, bs = "re") +
                s(INOTROP_VASO_CO_REG, bs = "re") +
                # s(INOTROPES_TCR, bs = "re") +
                # s(MOST_RCNT_CREAT) +
                eGFR_CODED + 
                # I(eGFR < 60) + 
                # s(eGFR) +
                I(TOT_SERUM_ALBUM < 3) +
                # s(TOT_SERUM_ALBUM) +
                s(CAND_DIAG, bs = "re") +
                # s(WL_OTHER_ORG, bs = "re") + 
               s(LISTING_CTR_CODE, bs = "re") +
                # s(LIST_YR) +
                # s(REGION, bs = "re") +
                # s(LC_effect, k=4) + 
                # s(median_wait_days_1A, k = 3) + 
                # s(median_refusals, k = 3) + 
                # s(pedhrtx_prev_yr, k = 3) + 
                LC_effect +  
                median_refusals +  
                pedhrtx_prev_yr,
              # select=TRUE,
              method = "GCV.Cp", 
              data = data_train_gam, 
              family = binomial())

3. Variable Importance

Show the code
summary(fit_gam)

Family: binomial 
Link function: logit 

Formula:
outcome ~ s(HEIGHT_CM) + s(AGE) + s(RACE, bs = "re") + s(ABO, 
    bs = "re") + s(LIFE_SUPPORT_CAND_REG, bs = "re") + ECMO_CAND_REG + 
    s(VAD_CAND_REG, bs = "re") + VENTILATOR_CAND_REG + s(FUNC_STAT_NUM) + 
    s(FUNC_STAT_UNKNOWN, bs = "re") + s(CAND_UNDER_1, bs = "re") + 
    s(DIALYSIS_CAND, bs = "re") + s(INOTROP_VASO_CO_REG, bs = "re") + 
    eGFR_CODED + I(TOT_SERUM_ALBUM < 3) + s(CAND_DIAG, bs = "re") + 
    s(LISTING_CTR_CODE, bs = "re") + LC_effect + median_refusals + 
    pedhrtx_prev_yr

Parametric coefficients:
                           Estimate Std. Error z value Pr(>|z|)    
(Intercept)                 3.87164    5.69559   0.680    0.497    
ECMO_CAND_REG              -0.16982    0.04295  -3.954 7.69e-05 ***
VENTILATOR_CAND_REG        -0.23719    0.05141  -4.614 3.96e-06 ***
eGFR_CODED                 -0.07255    0.05804  -1.250    0.211    
I(TOT_SERUM_ALBUM < 3)TRUE -0.01011    0.22197  -0.046    0.964    
LC_effect                   0.09804    0.07272   1.348    0.178    
median_refusals            -0.28426    0.06617  -4.296 1.74e-05 ***
pedhrtx_prev_yr             0.08437    0.06836   1.234    0.217    
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Approximate significance of smooth terms:
                               edf Ref.df Chi.sq  p-value    
s(HEIGHT_CM)             3.145e+00  4.013 29.518 6.65e-06 ***
s(AGE)                   1.002e+00  1.003 11.465 0.000726 ***
s(RACE)                  1.525e+00  4.000  6.729 0.009994 ** 
s(ABO)                   1.028e+00  3.000  1.739 0.180734    
s(LIFE_SUPPORT_CAND_REG) 3.241e-04  2.000  0.000 0.494820    
s(VAD_CAND_REG)          1.135e-04  1.000  0.000 0.821894    
s(FUNC_STAT_NUM)         5.983e+00  7.111  5.097 0.656946    
s(FUNC_STAT_UNKNOWN)     5.932e-04  1.000  0.000 0.513660    
s(CAND_UNDER_1)          6.774e-01  1.000  3.017 0.028265 *  
s(DIALYSIS_CAND)         1.807e+00  2.000 13.702 0.000854 ***
s(INOTROP_VASO_CO_REG)   1.841e+00  2.000  9.991 0.008351 ** 
s(CAND_DIAG)             3.840e+00  7.000 60.813  < 2e-16 ***
s(LISTING_CTR_CODE)      1.823e+01 86.000 28.701 0.012485 *  
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

R-sq.(adj) =  0.105   Deviance explained = 14.8%
UBRE = -0.43727  Scale est. = 1         n = 4523
Show the code
bind_rows(tidy(fit_gam), tidy(fit_gam, parametric=TRUE)) %>% 
  arrange(p.value) %>% 
  transmute(
    var = str_replace(term, "s\\((.+)\\)", "\\1"), 
    # term, 
    edf = coalesce(edf, 1) %>% digits(3), 
    p.value = digits(p.value,4)
  ) 

4. Test Performance

Show the code
data_test_gam = bake(rec_fitted, new_data = data_test)

data_test_gam %>%
  transmute(
    outcome = factor(outcome, c(1,0)), 
    p_gam =  1-predict(fit_gam, ., type = "response") %>% as.numeric
  ) %>% 
  reframe(calc_metrics(outcome, p_gam)) %>% 
  mutate_all(\(x) digits(x, 3))

5. Partial Dependence Plots

Show the code
library(gratia)
gam_plot_data = bind_rows(
    # Smooth effects
    gratia::smooth_estimates(fit_gam, unnest = FALSE) %>% 
    mutate(var = map_chr(data, \(x) tail(colnames(x),1)), .before = 1) %>% 
    select(-.smooth), 
    # unpenalized
    gratia::parametric_effects(fit_gam, unnest = FALSE) %>% 
    mutate(.type = "unpenalized") %>% 
    rename(var = .term)
  ) %>% 
  # Add edf and p.value
  left_join(
    bind_rows(tidy(fit_gam), tidy(fit_gam, parametric=TRUE)) %>% 
      arrange(p.value) %>% 
      transmute(
        var = str_replace(term, "s\\((.+)\\)", "\\1") %>% 
          str_remove("TRUE"), 
        edf = coalesce(edf, 1) %>% digits(3), 
        p.value = digits(p.value,4)
      ), 
    by = "var"
  ) %>% 
  arrange(p.value)
Show the code
plot_gam_effects <- function(select = 1){
  df = gam_plot_data %>% 
    slice(!!select) %>% 
    unnest(data)
  var_name = df$var[1]
  if(df$.type[1] == "unpenalized"){
    df = df %>% mutate( !!var_name := .value, .estimate = .partial)
  }
  categorical = is.character(df[[var_name]]) | is.factor(df[[var_name]]) | nrow(df) < 6
  if(categorical) {
    plt = plot_categorical_effects(df[[var_name]], df[[".estimate"]], xlab = var_name)
  } else{
    plt = plot_numeric_effects(df[[var_name]], df[[".estimate"]], xlab = var_name)
  }
  
  print(
    plt +
    labs(y = "partial effect") +
    scale_y_continuous(breaks = seq(-10, 10, by = .25)) +
    coord_cartesian(ylim = c(-1,1))
  )  
  
}
Show the code
walk(1:nrow(gam_plot_data), \(i) plot_gam_effects(i))

Random Forest

1. Tidymodels specification

Show the code
# Model specification: Random Forest
rf_spec = 
  rand_forest() %>%
  set_mode("classification") %>%
  set_engine("ranger", 
    seed = 2024, 
    importance = "impurity", #"none", 
    num.threads = 8
  ) %>% 
  set_args(
    mtry = tune(), 
    trees = 2000, 
    min_n = 2
  ) 

# Recipe:
rf_rec = 
  base_rec %>% 
  #: Remove additional variables for this model
  step_rm(
    # LISTING_CTR_CODE,
  ) %>% 
  #: Add additional variables to represent missing predictors
  step_indicate_na(all_predictors()) %>% 
  #: Convert categorical predictors to dummy 
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>% # one-hot
  # step_dummy(all_ordered_predictors(), one_hot = TRUE) %>% # one-hot
  step_rm(ends_with("_No")) %>%       # removes the No level (Yes is binary)
  step_rm(ends_with("_Unknown")) %>%  # removes the Unknown level (Yes is binary).
                                # This effectively treats "Unknown" same as "No"
  #: Impute missing values
  step_impute_median(all_numeric_predictors()) %>% 
  #: Remove all zero variance predictors (e.g., from step_indicate_na() )
  step_zv(all_predictors())

# Workflow:
rf_wflow = 
  workflow(preprocessor = rf_rec, spec = rf_spec)

2. Tune mtry

Use the out-of-bag (OOB) Brier score to tune mtry instead of a full cross-validated grid search; the selected value is then evaluated with cross-validation below.

Show the code
# pre-tuning: use the OOB Brier score to screen mtry values
#             (much cheaper than a full cross-validated grid search)
oob_brier <- function(mtry){
  # according to help(ranger), brier metric is used for oob error
  fit = rf_wflow %>% 
    finalize_workflow(list(mtry = mtry)) %>% 
    fit(data_train) 
  tibble(
    mtry, 
    oob_error = fit %>% extract_fit_engine() %>% pluck("prediction.error")
  )
}
num_cols = prep(rf_rec, data_train)$term_info %>% nrow()    # number of features
mtry_max = min(num_cols, 2*sqrt(num_cols)) %>% floor() # max mtry to try
mtry_oob_grid = seq(1, mtry_max, length=50) %>% floor() %>% unique()
oob_perf = map_df(mtry_oob_grid, oob_brier) # get oob performance
mtry_grid = oob_perf %>% 
  slice_min(oob_error, n = 5) %>% # keep best 5 mtry values
  select(mtry)

3. CV performance

Show the code
#: select from oob
rf_tune = mtry_grid %>% slice_min(mtry, n = 1, with_ties = FALSE)

set.seed(1000)
fit_resamples(
  object = rf_wflow %>% finalize_workflow(rf_tune),
  resamples = cv_folds,
  metrics = metric_set(roc_auc, brier_class, mn_log_loss, accuracy), 
  control = control_grid(verbose=FALSE)  
) %>% collect_metrics()

4. Test performance

Fit random forest

Show the code
rf_fit = rf_wflow %>% 
  finalize_workflow(rf_tune) %>% 
  fit(data_train)

Test performance

Show the code
data_test %>% 
  transmute(
    WL_ID_CODE, 
    outcome = factor(outcome, c(1,0)), 
    p_rf =  predict(rf_fit, data_test, type = "prob")$.pred_1
  ) %>% 
  reframe(calc_metrics(outcome, p_rf)) %>% mutate_all(\(x) digits(x, 3))

5. Variable Importance

Show the code
rf_fit
══ Workflow [trained] ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()

── Preprocessor ──────────────────────────────────────────────────────────────────────────
14 Recipe Steps

• step_mutate()
• step_rm()
• step_rm()
• step_mutate()
• step_mutate()
• step_rm()
• step_mutate()
• step_rm()
• step_indicate_na()
• step_dummy()
• ...
• and 4 more steps.

── Model ─────────────────────────────────────────────────────────────────────────────────
Ranger result

Call:
 ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~3,      x), num.trees = ~2000, min.node.size = min_rows(~2, x), seed = ~2024,      importance = ~"impurity", num.threads = ~8, verbose = FALSE,      probability = TRUE) 

Type:                             Probability estimation 
Number of trees:                  2000 
Sample size:                      4523 
Number of independent variables:  142 
Mtry:                             3 
Target node size:                 2 
Variable importance mode:         impurity 
Splitrule:                        gini 
OOB prediction error (Brier s.):  0.08213557 
Show the code
rf_fit %>% extract_fit_parsnip() %>% vip::vi() %>% filter(Importance > 1E-8)

XGBoost

1. Tidymodels specification

Fix learn_rate (eta) at 0.10 and sample_size at 0.80, and tune tree_depth and trees. Fixing the learning rate and tuning the number of trees is efficient because of the multi-predict capability: a single boosted fit can be evaluated at every smaller number of trees.

  1. Keep LISTING_CTR_CODE (its removal is commented out)
  2. One-hot encode all nominal predictors. For variables with {Yes, No, Unknown}, only keep the Yes column. This lumps the Unknown with No. 
  3. Let xgboost handle missing values internally (no imputation; each split learns a default direction for missing values)
Show the code
library(xgboost)
library(tidymodels)

# Model specification: XGBoost
bt_spec = 
  boost_tree() %>% 
  set_mode("classification") %>% 
  set_engine("xgboost",  nthread = 8) %>% 
  set_args(
    trees = tune(), 
    tree_depth = tune(),
    learn_rate = 0.10, # fixed 
    sample_size = .80, # fixed
  ) 

# Recipe:
## Let xgboost handle missing values internally; does not impute
bt_rec = 
  base_rec %>% 
  step_rm(
    # LISTING_CTR_CODE,
  ) %>% 
  # step_dummy(LISTING_CTR_CODE, one_hot = TRUE) %>% 
  #: Convert categorical predictors to dummy 
  step_dummy(all_nominal_predictors(), one_hot = TRUE) %>% # one-hot
  # step_dummy(all_ordered_predictors(), one_hot = TRUE) %>% # one-hot
  step_rm(ends_with("_No")) %>%       # removes the No level (Yes is binary)
  step_rm(ends_with("_Unknown")) %>%  # removes the Unknown level (Yes is binary).
                                # This effectively treats "Unknown" same as "No"
  #: Remove all zero variance predictors (e.g., from step_indicate_na() )
  step_zv(all_predictors())    

# Workflow:
bt_wflow = 
  workflow(preprocessor = bt_rec, spec = bt_spec)

2. Tuning

Show the code
# Tuning grid: use fewer trees for larger tree depths
bt_grid = 
  bind_rows(
    tibble(trees = seq(25, 300, by = 5), tree_depth = 1), 
    tibble(trees = seq(10, 150, by = 5), tree_depth = 2),
    tibble(trees = seq(10, 75, by = 5), tree_depth = 3), 
    tibble(trees = seq(10, 50, by = 5), tree_depth = 4), 
  )
# expand_grid(trees = seq(25, 250, by = 10), tree_depth = 1:4)

#: don't use bayes here since it won't exploit the multi-predict efficiency
set.seed(1000)
tune_res = tune_grid(
  object = bt_wflow,
  resamples = cv_folds,
  grid = bt_grid,
  metrics = metric_set(roc_auc, brier_class, mn_log_loss, accuracy), 
  control = control_grid(verbose=FALSE)
)

3. CV performance

Show the code
tune_res %>% show_best(metric = "roc_auc")
Show the code
tune_res %>% collect_metrics() %>% filter(.metric == "roc_auc") %>% 
  arrange(trees) %>% 
  ggplot(aes(trees, mean, color = factor(tree_depth))) + 
  geom_point() + 
  geom_line() + 
  labs(x = "Number of Trees", y = "Avg AUC", color = "tree depth")

4. Final model fit

Show the code
# Final model fit
(bt_tune = tune_res %>% select_best(metric = "roc_auc"))
Show the code
set.seed(1234)
bt_fit = finalize_workflow(bt_wflow, bt_tune) %>% fit(data_train)

5. Variable Importance

Show the code
bt_fit
══ Workflow [trained] ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: boost_tree()

── Preprocessor ──────────────────────────────────────────────────────────────────────────
12 Recipe Steps

• step_mutate()
• step_rm()
• step_rm()
• step_mutate()
• step_mutate()
• step_rm()
• step_mutate()
• step_rm()
• step_dummy()
• step_rm()
• ...
• and 2 more steps.

── Model ─────────────────────────────────────────────────────────────────────────────────
##### xgb.Booster
raw: 156 Kb 
call:
  xgboost::xgb.train(params = list(eta = 0.1, max_depth = 1, gamma = 0, 
    colsample_bytree = 1, colsample_bynode = 1, min_child_weight = 1, 
    subsample = 0.8), data = x$data, nrounds = 205, watchlist = x$watchlist, 
    verbose = 0, nthread = 8, objective = "binary:logistic")
params (as set within xgb.train):
  eta = "0.1", max_depth = "1", gamma = "0", colsample_bytree = "1", colsample_bynode = "1", min_child_weight = "1", subsample = "0.8", nthread = "8", objective = "binary:logistic", validate_parameters = "TRUE"
xgb.attributes:
  niter
callbacks:
  cb.evaluation.log()
# of features: 138 
niter: 205
nfeatures : 138 
evaluation_log:
    iter training_logloss
       1        0.6309374
       2        0.5796070
---                      
     204        0.2706729
     205        0.2706168
Show the code
bt_fit %>% extract_fit_parsnip() %>% vip::vi() %>% filter(Importance > 1E-8)
Show the code
data_test %>% 
  transmute(
    WL_ID_CODE, 
    outcome = factor(outcome, c(1,0)), 
    p_bt =  predict(bt_fit, data_test, type = "prob")$.pred_1
  ) %>% 
  reframe(calc_metrics(outcome, p_bt)) %>% mutate_all(\(x) digits(x, 3))

6. SHAP Dependence Plot

Get SHAP values

Show the code
data_train_xgb = bt_rec %>% 
  prep(data_train) %>% 
  bake(data_train, all_predictors(), composition = "matrix")

bt_shap = predict(
  bt_fit %>% extract_fit_engine(), 
  data_train_xgb, 
  predcontrib = TRUE) %>% 
  as_tibble()

SHAP Importance (mean absolute SHAP value)
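
Here the importance of feature \(j\) is the mean absolute SHAP value across the training rows,
\[
I_j = \frac{1}{n}\sum_{i=1}^{n} \lvert \phi_{ij} \rvert,
\]
computed from the predcontrib output above (the BIAS column, the model's base value, is excluded).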

Show the code
bt_shap_imp_features = bt_shap %>% select(-BIAS) %>% 
  map_dbl(\(x) mean(abs(x))) %>% enframe(name = "feature", value="shap_imp") %>% 
  filter(shap_imp > 0) %>% arrange(-shap_imp)

bt_shap_imp_features

SHAP dependence plots (NOTE: not showing categorical)

Show the code
plot_shap_effects <- function(var, shap=bt_shap, data = data_train_xgb){
  var = rlang::ensym(var)
  df = data %>% as_tibble() %>% 
    select({{var}}) %>% 
    mutate(
      shap = -pull(shap, {{var}})
    ) 
  rug_data = df %>% count({{var}}) %>% rename(x = {{var}})
  df = distinct(df)
  n_bks = nrow(rug_data)
  categorical = is.character(df[[1]]) | is.factor(df[[1]])
  if(n_bks > 6 & !categorical ){
    plt = plot_numeric_effects(df[[1]], df[[2]], var, rug_data)
  } else{
    plt = plot_categorical_effects(df[[1]], df[[2]], var, rug_data)
  }
  
  print(
    plt +
    labs(y = "partial effect") +
    scale_y_continuous(breaks = seq(-10, 10, by = .25)) +
    coord_cartesian(ylim = c(-1,1))
  )
}


# plot_shap_effects("CAND_DIAG_Congenital.Heart.Disease.With.Surgery")

walk(bt_shap_imp_features$feature, plot_shap_effects)

Details on eGFR impact

Show the code
# eGFR deep dive
plot_shap_effects(eGFR) + scale_x_continuous(breaks = seq(0, 1000, by = 15)) +
  geom_vline(xintercept = c(15, 30, 45, 60, 90), color = "purple", alpha = .25)

Show the code
data_train_xgb %>% as_tibble() %>% 
  select(eGFR) %>% 
  mutate(
    shap = -pull(bt_shap, eGFR)    
  ) %>% 
  distinct() %>% arrange(eGFR) %>% 
  filter(lag(shap, default=0) != shap)

Details on Albumin impact

Show the code
# ALBUM deep dive
plt = plot_shap_effects(TOT_SERUM_ALBUM) 

Show the code
plt + scale_x_continuous(breaks = seq(0, 10, by=1))

Show the code
data_train_xgb %>% as_tibble() %>% 
  select(TOT_SERUM_ALBUM) %>% 
  mutate(
    shap = -pull(bt_shap, TOT_SERUM_ALBUM)    
  ) %>% 
  distinct() %>% arrange(TOT_SERUM_ALBUM) %>% 
  filter(lag(shap, default=0) != shap)

7. Partial Dependence Plots

Show the code
imp_features = bt_fit %>% 
  extract_fit_parsnip() %>% vip::vi() %>% filter(Importance > 1E-8)

vars = imp_features %>% 
  pull(Variable) %>% 
  str_remove("na_ind_") %>% #str_remove("_X.+") %>% 
  str_remove("_Unknown") %>% str_remove("_Yes") %>% 
  str_replace("FUNC_STAT_NUM", "FUNC_STAT_CAND_REG") %>% 
  str_replace("ABO(_.+)", "ABO") %>% 
  str_replace("CAND_DIAG_CODE(_.+)", "CAND_DIAG_CODE") %>% 
  {ifelse(. == "CAND_DIAG_CODE", ., str_replace(.,"CAND_DIAG(_.+)", "CAND_DIAG"))} %>%
  str_replace("LIFE_SUPPORT_CAND_REG(_.+)", "LIFE_SUPPORT_CAND_REG") %>% 
str_replace("eGFR_CODED", "eGFR") %>% # since eGFR_CODED isn't in data
  str_replace("RACE(_.+)", "RACE") %>% 
  str_replace("LISTING_CTR_CODE(_X.+)", "LISTING_CTR_CODE") %>% 
  unique() 


walk(vars, plot_additive_effects, model = bt_fit)

Listing Center Only

Show the code
base_survival = 1-mean(data_train$outcome)
k = 5 # shrinkage/laplace parameter
LC_train = data_train %>% 
  group_by(LISTING_CTR_CODE) %>% 
  summarize(
    n = n(),
    p = 1 - mean(outcome),
    p_survival = (p * n  + base_survival * k) / (n+k)
  )

predict_LC <- function(LISTING_CTR_CODE){
  # look up each center's smoothed survival; unseen centers fall back to the base rate
  LC_train$p_survival[match(LISTING_CTR_CODE, LC_train$LISTING_CTR_CODE)] %>% 
    coalesce(base_survival)
}
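
The smoothed center estimate is a simple shrinkage (Laplace-style) estimator: with \(n_c\) listings at center \(c\), observed survival rate \(\bar{p}_c\), overall training survival rate \(\bar{p}_0\), and shrinkage parameter \(k = 5\),
\[
\hat{p}_c = \frac{n_c \, \bar{p}_c + k \, \bar{p}_0}{n_c + k},
\]
so small centers are pulled toward the overall rate and unseen centers default to it. The Status Only model below applies the same smoothing by STATUS.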

Status Only

Show the code
base_survival = 1-mean(data_train$outcome)
k = 5 # shrinkage/laplace parameter
STATUS_train = data_train %>% 
  group_by(STATUS) %>% 
  summarize(
    n = n(),
    p = 1 - mean(outcome),
    p_survival = (p * n  + base_survival * k) / (n+k)
  )

predict_STATUS <- function(STATUS){
  # look up each status's smoothed survival; unseen statuses fall back to the base rate
  STATUS_train$p_survival[match(STATUS, STATUS_train$STATUS)] %>% 
    coalesce(base_survival)
}

Model Comparison

Test Performance

Show the code
data_test %>% 
  transmute(
    WL_ID_CODE, 
    outcome = factor(outcome, c(1,0)), 
    lasso =  1-predict(cv_fit, X_test_lasso, type = "response", 
                           s = "lambda.min")[,1], 
    xgboost =  predict(bt_fit, data_test, type = "prob")$.pred_1, 
    RF =  predict(rf_fit, data_test, type = "prob")$.pred_1,
    GAM = 1-predict(fit_gam, data_test_gam, type = "response") %>% as.numeric,
    ensemble = (lasso + xgboost + RF + GAM) / 4,
    LC = 1-predict_LC(data_test$LISTING_CTR_CODE),
    STATUS = 1-predict_STATUS(data_test$STATUS)
  ) %>% 
  pivot_longer(c(lasso, xgboost, RF, GAM, ensemble, LC)) %>% 
  group_by(name) %>% 
  reframe(calc_metrics(outcome, value)) %>% 
  arrange(-auc) %>% 
  mutate_all(\(x) digits(x, 3))

Additive Effects

Show the code
features = prep(base_rec, data_train) %>% 
  bake(new_data=NULL, all_predictors()) %>% 
  colnames() %>% intersect(colnames(data_train))


plot_multi_effects <- function(var){
  var = rlang::ensym(var)
  df = bind_rows(
    lasso = get_additive_effects(var, model = lasso_fit, data=data_train),
    xgboost = get_additive_effects(var, model = bt_fit, data=data_train), 
    .id = "model"
  ) %>% 
    rename(x = !!var, y = eta)
  
  rug_data = df %>% filter(!is.na(n)) 
  n_bks = nrow(rug_data)
  categorical = is.character(df$x) | is.factor(df$x)
  if(n_bks > 6 & !categorical ){
    # plt = plot_numeric_effects(df[[1]], df[[2]], var, rug_data)
    plt = ggplot(df) + 
    geom_hline(yintercept = 0, color = "orange") +
    geom_line(aes(x, y, color=model))  
  } else{
    # plt = plot_categorical_effects(df[[1]], df[[2]], var, rug_data)
    rug_data = rug_data %>% mutate(x=as.factor(x))
    plt = ggplot(rug_data) + 
    geom_hline(yintercept = 0, color = "orange") +
    geom_col(aes(x, y, fill = model), width = 1/3, 
             position = "dodge") + 
    scale_x_discrete(label = scales::label_wrap(15)) 
  }

    print(
      plt + 
      geom_rug(data = rug_data %>% distinct(x,n), 
       aes(x, linewidth = n), 
       show.legend = FALSE,
       sides = "b", alpha = .25) + 
      labs(x = as.character(var), y = "partial effect") +
      scale_y_continuous(breaks = seq(-10, 10, by = .25)) +
      scale_color_brewer(type = "qual", palette = "Dark2") +
      scale_fill_brewer(type = "qual", palette = "Dark2") +
      coord_cartesian(ylim = c(-1,1)) + 
      theme(
        legend.position = c(0.02, 0.98),
        legend.justification = c("left", "top"),
        legend.title=element_blank()
      )
    )
}


walk(features, plot_multi_effects)

Feature Importance

The feature importance metric is the mean absolute additive effect (analogous to SHAP importance).
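
Concretely, for a feature with partial effect \(\eta_i\) at observed value \(i\) and count \(n_i\),
\[
\text{Importance} = \frac{\sum_i n_i \, \lvert \eta_i \rvert}{\sum_i n_i},
\]
which matches the sum(n*abs(y))/sum(n) computation in the code below.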

Show the code
# Mean Absolute Effect
feature_importance <- function(var){
  var = rlang::ensym(var)
  bind_rows(
    lasso = get_additive_effects(var, model = lasso_fit, data=data_train),
    xgboost = get_additive_effects(var, model = bt_fit, data=data_train), 
    .id = "model"
  ) %>% 
    rename(x = !!var, y = eta) %>% 
    filter(!is.na(n)) %>% 
    group_by(model) %>% 
    summarize(Importance = sum(n*abs(y))/sum(n))
}

map(set_names(features), feature_importance) %>% 
  bind_rows(.id = "Variable") %>% 
  mutate(Variable = fct_reorder(Variable, Importance, .fun = "mean")) %>% 
  ggplot(aes(x=Importance, y=Variable, fill = model)) + 
  geom_col(position = "dodge") + 
  scale_fill_brewer(type = "qual", palette = "Dark2")

Center Level Performance

This shows the predictive bias (test survival rate minus mean predicted survival) by listing center. The centers at the top have the largest volume (n); centers with \(n \le 2\) are removed.

As an example, the third center from the top (19468) has a bias of about -0.10, meaning the actual survival rate at this center was about 10 percentage points lower than the models predicted.

While I’m not too concerned about the bias in the low-volume centers (bottom of the plot), there do appear to be modest unaccounted-for center effects.
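
In terms of the quantities computed below, the per-center bias is \(\text{bias}_c = \bar{p}_c - \bar{\hat{p}}_c\): the observed test survival rate minus the mean predicted survival for that center's listings.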

Show the code
test_perf = data_test %>% 
  transmute(
    WL_ID_CODE, 
    LISTING_CTR_CODE,
    REGION,
    outcome = factor(outcome, c(1,0)), 
    lasso =  1-predict(cv_fit, X_test_lasso, type = "response", 
                           s = "lambda.min")[,1], 
    xgboost =  predict(bt_fit, data_test, type = "prob")$.pred_1, 
    RF =  predict(rf_fit, data_test, type = "prob")$.pred_1,
    GAM = 1-predict(fit_gam, data_test_gam, type = "response") %>% as.numeric,
    ensemble = (lasso + xgboost + RF + GAM) / 4,
  ) %>% 
  pivot_longer(c(lasso, xgboost, RF, GAM, ensemble)) %>% 
  mutate(value = 1-value) %>%  # convert to survival predictions
  group_by(name, LISTING_CTR_CODE, REGION) %>% 
  reframe(n = n(), n0 = sum(outcome == 0), p = mean(outcome == 0),
          mean_pred = mean(value), 
          calc_metrics(outcome, 1-value))

test_perf %>% 
  filter(n > 2) %>% 
  mutate(
    bias = p - mean_pred,
    n_diff = n*bias,
    LISTING_CTR_CODE = str_c(LISTING_CTR_CODE, ": n=", n),
    LISTING_CTR_CODE = fct_reorder(LISTING_CTR_CODE, n) #abs(n)
  ) %>% 
  ggplot(aes(x = bias, y = LISTING_CTR_CODE, fill = name)) + 
  geom_col(position = "dodge") + 
  # facet_wrap(~REGION, scales = "free_y", drop=TRUE) +
  labs(fill = "model")

Show the code
test_perf %>% 
  filter(name == "ensemble") %>% 
  filter(n > 10) %>% 
  mutate(
    bias = p - mean_pred,
    LISTING_CTR_CODE = fct_reorder(LISTING_CTR_CODE, n)
  ) %>% 
  ggplot(aes(x = mean_pred, y = p)) + geom_abline() + 
  geom_point(aes(size=n), shape=1) +
  scale_size_area() + 
  scale_x_continuous(breaks = seq(0, 1, by = .025)) + 
  scale_y_continuous(breaks = seq(0, 1, by = .025)) +
  labs(x = "avg survival prediction", y = "actual survival rate")