Partial dependence (PD)
Partial dependence (PD) shows the expected prediction from a model as a function of a single predictor or multiple predictors. The expectation is marginalized over the values of all other predictors, giving something like a multivariable adjusted estimate of the model’s prediction.
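The marginalization described above can be written out by hand for any model with a `predict()` method. This is a model-agnostic sketch in base R, not aorsf's implementation. `pd_manual()` is a hypothetical helper and `mtcars` is only a stand-in dataset. For each grid value, it overwrites the predictor in every row, predicts, and averages the predictions.

```r
# Hypothetical helper (not part of aorsf): manual partial dependence.
# For each value in `values`, set `var` to that value for every row of
# `data`, predict, and average over rows (marginalizing other predictors).
pd_manual <- function(model, data, var, values) {
  vapply(values, function(v) {
    data_v <- data
    data_v[[var]] <- v                        # same value for all observations
    mean(predict(model, newdata = data_v))    # average over the other predictors
  }, numeric(1))
}
# Illustration with a linear model on the built-in mtcars data
fit_lm <- lm(mpg ~ wt + hp, data = mtcars)
pd_wt <- pd_manual(fit_lm, mtcars, "wt", values = c(2, 3, 4))
pd_wt
```

For an additive linear model like this one, the PD curve is a straight line whose slope equals the `wt` coefficient, which makes it a convenient check on the helper.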
Begin by fitting an ORSF ensemble. We set a prediction horizon of 5 years when we fit the ensemble so that any aorsf function that we pass this ensemble to will assume we want to compute predictions at 5 years.
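library(aorsf)
# 5-year prediction horizon, in days
pred_horizon <- 365.25 * 5
# split pbc_orsf into 150 training rows and a held-out test set
set.seed(329730)
index_train <- sample(nrow(pbc_orsf), 150)
pbc_orsf_train <- pbc_orsf[index_train, ]
pbc_orsf_test <- pbc_orsf[-index_train, ]
# oobag_pred_horizon sets the default horizon used by downstream functions
fit <- orsf(data = pbc_orsf_train,
            formula = Surv(time, status) ~ . - id,
            n_tree = 50,
            oobag_pred_horizon = pred_horizon)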
fit
#> ---------- Oblique random survival forest
#>
#> Linear combinations: Accelerated Cox regression
#> N observations: 150
#> N events: 52
#> N trees: 50
#> N predictors total: 17
#> N predictors per node: 5
#> Average leaves per tree: 10.14
#> Min observations in leaf: 5
#> Min events in leaf: 1
#> OOB stat value: 0.83
#> OOB stat type: Harrell's C-index
#> Variable importance: anova
#>
#> -----------------------------------------
Three ways to compute PD
You can compute PD three ways with aorsf:
- using in-bag predictions for the training data:
pd_inb <- orsf_pd_inb(fit, pred_spec = list(bili = 1:5))
pd_inb
#>    pred_horizon bili      mean        lwr       medn       upr
#> 1:      1826.25    1 0.2009424 0.01590826 0.09399154 0.8021860
#> 2:      1826.25    2 0.2333312 0.02694801 0.13328302 0.8219914
#> 3:      1826.25    3 0.2616408 0.03156470 0.16551050 0.8367423
#> 4:      1826.25    4 0.3070941 0.05820874 0.21875904 0.8547467
#> 5:      1826.25    5 0.3463971 0.09854666 0.26899889 0.8547467
- using out-of-bag predictions for the training data:
pd_oob <- orsf_pd_oob(fit, pred_spec = list(bili = 1:5))
pd_oob
#>    pred_horizon bili      mean        lwr      medn       upr
#> 1:      1826.25    1 0.2116003 0.01202018 0.1133762 0.7639032
#> 2:      1826.25    2 0.2440849 0.01766977 0.1684559 0.8025264
#> 3:      1826.25    3 0.2750214 0.02522565 0.2008733 0.8114771
#> 4:      1826.25    4 0.3194273 0.03177694 0.2427813 0.8252211
#> 5:      1826.25    5 0.3542926 0.05696909 0.2757469 0.8256011
- using predictions for a new set of data:
pd_test <- orsf_pd_new(fit, new_data = pbc_orsf_test, pred_spec = list(bili = 1:5))
pd_test
#>    pred_horizon bili      mean        lwr      medn       upr
#> 1:      1826.25    1 0.2373354 0.02018629 0.1706567 0.7643862
#> 2:      1826.25    2 0.2694548 0.03099876 0.2111509 0.7859409
#> 3:      1826.25    3 0.3024748 0.04413066 0.2502721 0.8045073
#> 4:      1826.25    4 0.3471510 0.06308585 0.2960804 0.8126026
#> 5:      1826.25    5 0.3834739 0.10542180 0.3578699 0.8536859
- In-bag PD indicates relationships that the model learned during training. This is helpful if your goal is to interpret the model.
- Out-of-bag PD also reflects relationships learned during training, but because each out-of-bag prediction comes only from trees that did not see that observation, it simulates applying the model to new data. This is helpful if you want to test your model’s reliability or fairness in new data but you don’t have access to a large testing set.
- New-data PD shows how the model predicts outcomes for observations it has not seen. This is helpful if you want to test your model’s reliability or fairness.
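To make the in-bag/out-of-bag distinction concrete, here is a generic bootstrap sketch in base R. It assumes each tree is grown on a bootstrap sample of the training rows, which is the classic random-forest setup; aorsf's own resampling options may differ. Rows left out of a tree's sample are out-of-bag for that tree.

```r
# Generic bootstrap sketch (assumption: one tree's in-bag sample is a
# bootstrap draw of the training rows; aorsf's exact scheme may differ)
set.seed(329730)
n <- 150                                       # size of the training set used above
in_bag <- sample(n, size = n, replace = TRUE)  # rows drawn, with repeats
in_bag_rows <- unique(in_bag)                  # rows this tree is trained on
oob_rows <- setdiff(seq_len(n), in_bag_rows)   # rows this tree never sees
# On average, a fraction (1 - 1/n)^n, roughly 0.37, of rows is out-of-bag
length(oob_rows) / n
```

An out-of-bag prediction for a given row aggregates only the trees for which that row fell in `oob_rows`.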
Automatic variable values
Use pred_spec_auto() if you know the variables you want to check out but you don’t have a specific set of values in mind:
orsf_pd_inb(fit, pred_spec_auto(bili))
#> pred_horizon bili mean lwr medn upr
#> 1: 1826.25 0.600 0.1916845 0.01393675 0.09277438 0.7862712
#> 2: 1826.25 0.700 0.1931983 0.01407259 0.09367646 0.7862712
#> 3: 1826.25 1.300 0.2111518 0.01875820 0.10058691 0.8056634
#> 4: 1826.25 3.175 0.2707361 0.03183542 0.18335844 0.8402006
#> 5: 1826.25 7.110 0.4345519 0.22022836 0.36615264 0.8610052
pred_spec_auto() lets you specify a variable in your model with or without quotes, and then assigns values for that variable on your behalf. For continuous variables, it uses quantiles (the 10th, 25th, 50th, 75th, and 90th percentiles). For nominal variables, it uses the unique categories.
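The value-selection rule just described can be mimicked in a few lines of base R. `auto_values()` is a hypothetical helper written for illustration, not aorsf code, and it assumes R's default quantile type.

```r
# Hypothetical helper (not aorsf code) mimicking the rule described above:
# five quantiles for numeric variables, unique categories otherwise.
auto_values <- function(x) {
  if (is.numeric(x)) {
    unname(quantile(x, probs = c(0.10, 0.25, 0.50, 0.75, 0.90)))
  } else {
    unique(x)
  }
}
auto_values(0:100)                      # 10 25 50 75 90
auto_values(factor(c("m", "f", "f")))   # categories in order of appearance: m, f
```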
Let’s re-fit our ORSF model to all available data before proceeding to the next sections.
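A minimal sketch of that re-fit, assuming the same settings as the first fit but with data = pbc_orsf so that every row is used (the ICE output later in this article indexes 276 observations). No seed is set here, so results may differ slightly from the outputs printed below.

```r
library(aorsf)
# Same 5-year horizon as before
pred_horizon <- 365.25 * 5
# Assumed re-fit: identical settings, but all rows of pbc_orsf
fit <- orsf(data = pbc_orsf,
            formula = Surv(time, status) ~ . - id,
            n_tree = 50,
            oobag_pred_horizon = pred_horizon)
```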
One variable, one horizon
Computing PD for a single variable is straightforward:
pd_sex <- orsf_pd_oob(fit, pred_spec = pred_spec_auto(sex))
pd_sex
#> pred_horizon sex mean lwr medn upr
#> 1: 1826.25 m 0.3420353 0.01698975 0.2493781 0.9219013
#> 2: 1826.25 f 0.3078882 0.01046657 0.1877641 0.9212616
The output shows that the expected predicted mortality risk at 5 years after baseline is higher for men than for women (mean 0.342 versus 0.308).
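To put that comparison in absolute and relative terms, the mean risks can be copied from the pd_sex output above and compared directly:

```r
# Mean 5-year predicted risk, copied from the pd_sex output above
risk_m <- 0.3420353
risk_f <- 0.3078882
risk_m - risk_f   # absolute difference, about 0.034 (3.4 percentage points)
risk_m / risk_f   # ratio, about 1.11 (roughly 11% higher for men)
```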
One variable, moving horizon
What if the effect of a predictor varies over time? PD can show this.
library(ggplot2)
# PD for sex at each day from year 1 through year 5
pd_sex_tv <- orsf_pd_oob(fit, pred_spec = pred_spec_auto(sex),
                         pred_horizon = seq(365, 365*5))
ggplot(pd_sex_tv, aes(x = pred_horizon, y = mean, color = sex)) +
  geom_line() +
  labs(x = 'Time since baseline',
       y = 'Expected risk')
From inspection, we can see that males have higher risk than females and the difference in that risk grows over time. This can also be seen by viewing the ratio of expected risk over time:
library(data.table)
ratio_tv <- pd_sex_tv[
, .(ratio = mean[sex == 'm'] / mean[sex == 'f']), by = pred_horizon
]
ggplot(ratio_tv, aes(x = pred_horizon, y = ratio)) +
geom_line(color = 'grey') +
geom_smooth(color = 'black', se = FALSE) +
labs(x = 'time since baseline',
y = 'ratio in expected risk for males versus females')
#> `geom_smooth()` using method = 'gam' and formula = 'y ~ s(x, bs = "cs")'
Multiple variables, marginally
If you want to compute PD marginally for multiple variables, just list the variable values in pred_spec and specify expand_grid = FALSE.
pd_two_vars <-
orsf_pd_oob(fit,
pred_spec = pred_spec_auto(sex, bili),
expand_grid = FALSE)
pd_two_vars
#> pred_horizon variable value level mean lwr medn upr
#> 1: 1826.25 sex NA m 0.3420353 0.01698975 0.2493781 0.9219013
#> 2: 1826.25 sex NA f 0.3078882 0.01046657 0.1877641 0.9212616
#> 3: 1826.25 bili 0.600 <NA> 0.2455223 0.01045620 0.1643707 0.8391940
#> 4: 1826.25 bili 0.800 <NA> 0.2505248 0.01046657 0.1666401 0.8360690
#> 5: 1826.25 bili 1.400 <NA> 0.2643678 0.01693072 0.1840421 0.8501067
#> 6: 1826.25 bili 3.525 <NA> 0.3841995 0.08224197 0.3260132 0.8888114
#> 7: 1826.25 bili 7.250 <NA> 0.5193563 0.18367045 0.5021500 0.9141342
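The difference between the marginal and joint settings shows up in the size of the prediction grid. This base-R sketch uses the sex and bili values printed above. It assumes that expand_grid = TRUE corresponds to every combination of the listed values, as base R's expand.grid() produces.

```r
# Values taken from the pd_two_vars output above
sex_values  <- c("m", "f")
bili_values <- c(0.600, 0.800, 1.400, 3.525, 7.250)
# expand_grid = FALSE: each variable is varied on its own, one after the other
n_marginal <- length(sex_values) + length(bili_values)   # 7 rows
# Assumed meaning of expand_grid = TRUE: one row per combination of values
joint_grid <- expand.grid(sex = sex_values, bili = bili_values)
nrow(joint_grid)                                          # 10 rows
```

The joint grid grows multiplicatively with each added variable, which is one reason joint PD is usually limited to a few predictors.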
To get a view of partial dependence for any number of variables in the training data, use orsf_summarize_uni(). This function computes out-of-bag PD for the most important n_variables and returns a nicely formatted view of the output:
pd_smry <- orsf_summarize_uni(fit, n_variables = 4)
pd_smry
#>
#> -- ascites (VI Rank: 1) ------------------------
#>
#> |---------------- Risk ----------------|
#> Value Mean Median 25th % 75th %
#> 0 0.3094268 0.2048280 0.07392497 0.5228134
#> 1 0.5036637 0.4638581 0.32705425 0.6882956
#>
#> -- bili (VI Rank: 2) ---------------------------
#>
#> |---------------- Risk ----------------|
#> Value Mean Median 25th % 75th %
#> 0.80 0.2505248 0.1666401 0.06233100 0.3964744
#> 1.40 0.2643678 0.1840421 0.07164115 0.4040525
#> 3.52 0.3841995 0.3260132 0.19553863 0.5321580
#>
#> -- edema (VI Rank: 3) --------------------------
#>
#> |---------------- Risk ----------------|
#> Value Mean Median 25th % 75th %
#> 0 0.3020595 0.1994849 0.06741347 0.5174422
#> 0.5 0.3808198 0.3181430 0.14075530 0.5974875
#> 1 0.4503680 0.3985665 0.24753838 0.6162484
#>
#> -- copper (VI Rank: 4) -------------------------
#>
#> |---------------- Risk ----------------|
#> Value Mean Median 25th % 75th %
#> 42.8 0.2754068 0.1726686 0.06179491 0.4282714
#> 74.0 0.2903369 0.1897076 0.07682290 0.4545510
#> 129 0.3412883 0.2659864 0.11419673 0.5411165
#>
#> Predicted risk at time t = 1826.25 for top 4 predictors
This ‘summary’ object can be converted into a data.table for downstream plotting and tables.
head(as.data.table(pd_smry))
#> variable importance Value Mean Median 25th % 75th %
#> 1: ascites 0.5000000 0 0.3094268 0.2048280 0.07392497 0.5228134
#> 2: ascites 0.5000000 1 0.5036637 0.4638581 0.32705425 0.6882956
#> 3: bili 0.3880597 0.80 0.2505248 0.1666401 0.06233100 0.3964744
#> 4: bili 0.3880597 1.40 0.2643678 0.1840421 0.07164115 0.4040525
#> 5: bili 0.3880597 3.52 0.3841995 0.3260132 0.19553863 0.5321580
#> 6: edema 0.3423269 0 0.3020595 0.1994849 0.06741347 0.5174422
#> pred_horizon level
#> 1: 1826.25 0
#> 2: 1826.25 1
#> 3: 1826.25 <NA>
#> 4: 1826.25 <NA>
#> 5: 1826.25 <NA>
#> 6: 1826.25 0
Multiple variables, jointly
PD can show the expected value of a model’s predictions as a function of a specific predictor, or as a function of multiple predictors. For instance, we can estimate predicted risk as a joint function of bili, edema, and trt:
pred_spec = pred_spec_auto(bili, edema, trt)
pd_bili_edema <- orsf_pd_oob(fit, pred_spec)
library(ggplot2)
ggplot(pd_bili_edema, aes(x = bili, y = medn, col = trt, linetype = edema)) +
geom_line() +
labs(y = 'Expected predicted risk')
From inspection:
- The model’s predictions indicate slightly lower risk for the placebo group, and these do not seem to change much at different values of bili or edema.
- There is a clear increase in predicted risk with higher levels of edema and with higher levels of bili.
- The slope of predicted risk as a function of bili appears highest among patients with edema of 0.5. Is the effect of bili modified by edema being 0.5? A quick sanity check with coxph suggests there is.
library(survival)
pbc_orsf$edema_05 <- ifelse(pbc_orsf$edema == '0.5', 'yes', 'no')
fit_cph <- coxph(Surv(time, status) ~ edema_05 * bili, data = pbc_orsf)
anova(fit_cph)
#> Analysis of Deviance Table
#>  Cox model: response is Surv(time, status)
#> Terms added sequentially (first to last)
#>
#>                loglik   Chisq Df Pr(>|Chi|)    
#> NULL          -550.19                          
#> edema_05      -546.83  6.7248  1   0.009508 ** 
#> bili          -513.59 66.4689  1  3.555e-16 ***
#> edema_05:bili -510.54  6.1112  1   0.013433 *  
#> ---
#> Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
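The Chisq column of this sequential anova() table is a likelihood-ratio statistic: twice the increase in log-likelihood when a term is added. The arithmetic below uses the rounded log-likelihoods printed above and recovers the interaction test.

```r
# Log-likelihoods from the anova() table above (rounded to 2 decimals there)
loglik_bili     <- -513.59
loglik_interact <- -510.54
# Likelihood-ratio statistic for adding edema_05:bili
lr_stat <- 2 * (loglik_interact - loglik_bili)   # about 6.10 (table: 6.1112)
# p-value: upper tail of a chi-square distribution with 1 degree of freedom
p_interact <- pchisq(6.1112, df = 1, lower.tail = FALSE)
p_interact                                        # about 0.0134
```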
Individual conditional expectations (ICE)
Unlike partial dependence, which shows the expected prediction as a function of one or multiple predictors, individual conditional expectations (ICE) show the prediction for an individual observation as a function of a predictor.
Just like PD, we can compute ICE using in-bag, out-of-bag, or testing data, and the same principles apply. We’ll use out-of-bag estimates here.
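The relationship between ICE and PD can be sketched with any model that has a `predict()` method. This is a model-agnostic base-R illustration on `mtcars`, not aorsf's implementation: each row of `ice` is one observation's curve, and PD is their column-wise average. The `wt:am` interaction is included so that curves can differ between observations.

```r
# Model-agnostic ICE sketch (not aorsf's implementation)
fit_int <- lm(mpg ~ wt * am, data = mtcars)
wt_grid <- c(2, 3, 4)
# ICE: one row per observation, one column per grid value of wt
ice <- sapply(wt_grid, function(v) {
  data_v <- mtcars
  data_v$wt <- v                  # set wt to v for every observation
  predict(fit_int, newdata = data_v)
})
# PD is the average of the ICE curves at each grid value
pd <- colMeans(ice)
# Each curve's slope is b_wt + b_interaction * am, so two distinct slopes exist
ice_slopes <- ice[, 2] - ice[, 1]
length(unique(round(ice_slopes, 8)))   # 2
```

Averaging hides exactly this kind of heterogeneity, which is why ICE curves are worth inspecting alongside PD.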
Visualizing ICE curves
Inspecting the ICE curves for each observation can help identify whether there is heterogeneity in a model’s predictions. I.e., does the effect of the variable follow the same pattern for all the data, or are there groups where the variable impacts risk differently?
I am going to turn off boundary checking in orsf_ice_oob by setting boundary_checks = FALSE, which will allow me to generate ICE curves that go beyond the 90th percentile of bili.
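Before extrapolating, it can help to see which requested values fall above the 90th percentile of the observed variable. `above_p90()` is a hypothetical helper that only mirrors the idea described above; it is not aorsf's boundary check, and the observed values here are made up.

```r
# Hypothetical helper (not aorsf's boundary check): return requested values
# that lie above the 90th percentile of the observed variable.
above_p90 <- function(values, x) {
  values[values > quantile(x, probs = 0.90)]
}
x_observed <- 1:10                     # made-up predictor; 90th percentile is 9.1
grid <- seq(1, 10, length.out = 25)    # same grid shape as the bili example
above_p90(grid, x_observed)            # 9.250 9.625 10.000
```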
pred_spec <- list(bili = seq(1, 10, length.out = 25))
ice_oob <- orsf_ice_oob(fit, pred_spec, boundary_checks = FALSE)
ice_oob
#> id_variable id_row pred_horizon bili pred
#> 1: 1 1 1826.25 1 0.9000491
#> 2: 1 2 1826.25 1 0.7728780
#> 3: 1 3 1826.25 1 0.6262776
#> 4: 1 4 1826.25 1 0.8016667
#> 5: 1 5 1826.25 1 0.5701027
#> ---
#> 6896: 25 272 1826.25 10 0.6589833
#> 6897: 25 273 1826.25 10 0.4139382
#> 6898: 25 274 1826.25 10 0.2736971
#> 6899: 25 275 1826.25 10 0.5723036
#> 6900: 25 276 1826.25 10 0.5832437
- id_variable is an identifier for the current value of the variable(s) that are in the data. It is redundant if you only have one variable, but helpful if there are multiple variables.
- id_row is an identifier for the observation in the original data. It is used to group an observation’s predictions together in plots.
For plots, it is helpful to center the ICE data. I subtract the initial value of predicted risk (i.e., when bili = 1) from each observation’s conditional expectation values, so that
- every curve starts at 0, and
- the plot shows the change in predicted risk as a function of bili.
# baseline (bili = 1) prediction per observation, repeated for the 25 grid values;
# this relies on ice_oob being sorted by id_variable, then id_row
ice_oob[, pred_subtract := rep(pred[id_variable == 1], times = 25)]
ice_oob[, pred := pred - pred_subtract]
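Centering with rep(..., times = 25) depends on the row order of ice_oob. A lookup by id_row avoids that dependency. The sketch below uses a small made-up table in the same long format (2 observations, 3 grid values) and base R's match().

```r
# Made-up ICE table in the same long format as ice_oob
ice <- data.frame(
  id_variable = rep(1:3, each = 2),
  id_row      = rep(1:2, times = 3),
  pred        = c(0.20, 0.50, 0.25, 0.55, 0.40, 0.60)
)
# Each observation's baseline is its prediction at id_variable == 1
is_base  <- ice$id_variable == 1
baseline <- ice$pred[is_base][match(ice$id_row, ice$id_row[is_base])]
# Centered predictions; correct regardless of how rows are ordered
ice$pred_centered <- ice$pred - baseline
ice$pred_centered   # 0.00 0.00 0.05 0.05 0.20 0.10
```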
Now we can visualize the curves.
library(ggplot2)
ggplot(ice_oob, aes(x = bili,
y = pred,
group = id_row)) +
geom_line(alpha = 0.15) +
labs(y = 'Change in predicted risk') +
geom_smooth(se = FALSE, aes(group = 1))
#> `geom_smooth()` using method = 'gam' and formula = 'y ~ s(x, bs = "cs")'
From inspection of the figure:
- Most of the individual slopes cluster around the overall trend. Good!
- A small number of individual slopes appear to be flat. It may be helpful to investigate this further.
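One way to start that investigation is to flag curves whose predictions barely move across the grid. The sketch below uses a made-up table with the same id_row and centered pred columns; the 0.01 threshold is an arbitrary choice for illustration, not a recommendation.

```r
# Made-up centered ICE data: 3 observations x 4 grid values
ice <- data.frame(
  id_row = rep(1:3, times = 4),
  pred   = c(0.00, 0.00,  0.00,
             0.05, 0.002, 0.10,
             0.12, 0.004, 0.18,
             0.20, 0.005, 0.25)
)
threshold <- 0.01   # arbitrary cutoff for "flat"
# Range of each observation's curve across the grid
curve_range <- tapply(ice$pred, ice$id_row, function(p) diff(range(p)))
flat_ids <- as.integer(names(curve_range)[curve_range < threshold])
flat_ids   # 2
```

The flagged id_row values can then be used to pull those observations from the original data and compare their characteristics with the rest.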
Limitations of PD
Partial dependence has a number of known limitations and assumptions that users should be aware of (see Hooker, 2021). In particular, partial dependence is less intuitive when more than two predictors are examined jointly, and it assumes that the feature(s) for which partial dependence is computed are not correlated with other features (this is likely not true in many cases). Accumulated local effect (ALE) plots can be used when feature independence is not a valid assumption.
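Why correlation matters can be seen in a small simulation. When two features move together, PD still pairs each grid value with every observed value of the other feature, producing combinations that never occur in the data. This base-R sketch uses simulated data only.

```r
set.seed(1)
x1 <- runif(200, min = 0, max = 10)
x2 <- x1 + runif(200, min = -0.25, max = 0.25)   # x2 closely tracks x1
# In the observed data, x1 and x2 never differ by more than 0.25
gap_observed <- max(abs(x1 - x2))
# PD at x1 = max(x1) pairs that value with every observed x2,
# including x2 values near 0: combinations the model never saw
v <- max(x1)
gap_pd <- max(abs(v - x2))
c(gap_observed = gap_observed, gap_pd = gap_pd)
```

Predictions for such unseen combinations are extrapolations, so the PD curve can reflect model behavior in regions with no data.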
References
- Giles Hooker, Lucas Mentch, Siyu Zhou. Unrestricted Permutation forces Extrapolation: Variable Importance Requires at least One More Model, or There Is No Free Variable Importance. arXiv e-prints 2021 Oct; arXiv-1905. URL: https://doi.org/10.48550/arXiv.1905.03151