Fit an oblique random survival forest
Usage
orsf(
data,
formula,
control = orsf_control_fast(),
weights = NULL,
n_tree = 500,
n_split = 5,
n_retry = 3,
mtry = NULL,
leaf_min_events = 1,
leaf_min_obs = 5,
split_min_events = 5,
split_min_obs = 10,
split_min_stat = 3.841459,
oobag_pred_type = "surv",
oobag_pred_horizon = NULL,
oobag_eval_every = n_tree,
oobag_fun = NULL,
importance = "anova",
group_factors = TRUE,
tree_seeds = NULL,
attach_data = TRUE,
no_fit = FALSE,
na_action = "fail",
verbose_progress = FALSE,
...
)
orsf_train(object)
Arguments
- data
a data.frame, tibble, or data.table that contains the relevant variables.
- formula
(formula) The response on the left hand side should include a time variable, followed by a status variable, and may be written inside a call to Surv (see examples). The terms on the right are names of predictor variables.
- control
(orsf_control) An object returned from one of the orsf_control functions:
orsf_control_fast (the default) uses a single iteration of Newton Raphson scoring to identify a linear combination of predictors.
orsf_control_cph uses Newton Raphson scoring until a convergence criterion is met.
orsf_control_net uses glmnet to identify linear combinations of predictors, similar to Jaeger (2019).
orsf_control_custom allows the user to apply their own function to create linear combinations of predictors.
- weights
(numeric vector) Optional. If given, this input should have length equal to nrow(data). Values in weights are treated like replication weights, i.e., a value of 2 is the same thing as having 2 observations in data, each containing a copy of the corresponding person's data.
Use weights cautiously, as orsf will count the number of observations and events prior to growing a node for a tree, so higher values in weights will lead to deeper trees.
- n_tree
(integer) the number of trees to grow. Default is n_tree = 500.
- n_split
(integer) the number of cut-points assessed when splitting a node in decision trees. Default is n_split = 5.
- n_retry
(integer) when a node can be split, but the current linear combination of inputs is unable to provide a valid split, orsf will try again with a new linear combination based on a different set of randomly selected predictors, up to n_retry times. Default is n_retry = 3. Set n_retry = 0 to prevent any retries.
- mtry
(integer) Number of predictors randomly included as candidates for splitting a node. The default is the smallest integer greater than or equal to the square root of the number of predictors, i.e., mtry = ceiling(sqrt(number of predictors)).
- leaf_min_events
(integer) minimum number of events in a leaf node. Default is leaf_min_events = 1.
- leaf_min_obs
(integer) minimum number of observations in a leaf node. Default is leaf_min_obs = 5.
- split_min_events
(integer) minimum number of events required in a node to consider splitting it. Default is split_min_events = 5.
- split_min_obs
(integer) minimum number of observations required in a node to consider splitting it. Default is split_min_obs = 10.
- split_min_stat
(double) minimum test statistic required to split a node. Default is 3.841459 for the log-rank test, which is roughly a p-value of 0.05.
- oobag_pred_type
(character) The type of out-of-bag predictions to compute while fitting the ensemble. Valid options are:
'none' : don't compute out-of-bag predictions
'risk' : predict the probability of having an event at or before oobag_pred_horizon
'surv' : predict the probability of not having an event at or before oobag_pred_horizon (i.e., 1 - risk)
'chf' : predict the cumulative hazard function
Mortality ('mort') is not implemented for out-of-bag predictions yet, but it will be in a future update.
- oobag_pred_horizon
(numeric) A numeric value indicating what time should be used for out-of-bag predictions. Default is the median of the observed times, i.e., oobag_pred_horizon = median(time).
- oobag_eval_every
(integer) The out-of-bag performance of the ensemble will be checked every oobag_eval_every trees. So, if oobag_eval_every = 10, then out-of-bag performance is checked after growing the 10th tree, the 20th tree, and so on. Default is oobag_eval_every = n_tree.
- oobag_fun
(function) to be used for evaluating out-of-bag prediction accuracy every oobag_eval_every trees. When oobag_fun = NULL (the default), Harrell's C-statistic (1982) is used to evaluate accuracy. If you use your own oobag_fun, note the following:
oobag_fun should have two inputs: y_mat and s_vec
y_mat is a two-column matrix with the first column named 'time' and the second named 'status'
s_vec is a numeric vector containing predicted survival probabilities
oobag_fun should return a numeric output of length 1
For more details, see the out-of-bag vignette and the sketch in the oobag_fun details below.
- importance
(character) Indicate the method for variable importance:
'none': no variable importance is computed.
'anova': compute analysis of variance (ANOVA) importance
'negate': compute negation importance
'permute': compute permutation importance
For details on these methods, see orsf_vi.
- group_factors
(logical) Only relevant if variable importance is being estimated. If TRUE, the importance of factor variables will be reported overall by aggregating the importance of individual levels of the factor. If FALSE, the importance of individual factor levels will be returned.
- tree_seeds
(integer vector) Optional. If specified, random seeds will be set using the values in tree_seeds[i] before growing tree i. Two forests grown with the same number of trees and the same seeds will have the exact same out-of-bag samples, making out-of-bag error estimates of the forests more comparable. If NULL (the default), no seeds are set during the training process.
- attach_data
(logical) if TRUE, a copy of the training data will be attached to the output. This is helpful if you plan on using functions like orsf_pd_oob or orsf_summarize_uni to interpret the forest using its training data. Default is TRUE.
- no_fit
(logical) if TRUE, model fitting steps are defined and saved, but training is not initiated. The object returned can be directly submitted to orsf_train() so long as attach_data is TRUE (see the sketch after this argument list).
- na_action
(character) what should happen when data contains missing values (i.e., NA values). Valid options are:
'fail' : an error is thrown if data contains NA values
'omit' : rows in data with incomplete data will be dropped
'impute_meanmode' : missing values for continuous and categorical variables in data will be imputed using the mean and mode, respectively. Note that if this option is selected and attach_data is TRUE, the data attached to the output will be the imputed version of data.
- verbose_progress
(logical) if TRUE, progress messages are printed in the console.
- ...
Further arguments passed to or from other methods (not currently used).
- object
an untrained 'aorsf' object, created by setting no_fit = TRUE in orsf().
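A minimal sketch of the no_fit and orsf_train() workflow described above, assuming the pbc_orsf data and default settings:
# define the fitting steps without training the forest
object <- orsf(pbc_orsf, Surv(time, status) ~ . - id, no_fit = TRUE)
# train the forest later (requires attach_data = TRUE, the default)
fit_trained <- orsf_train(object)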
Details
This function is based on and similar to the ORSF function in the obliqueRSF R package. The primary difference is that this function runs much faster. The speed increase is attributable to better management of memory (i.e., no unnecessary copies of inputs) and using a Newton Raphson scoring algorithm to identify linear combinations of inputs rather than performing penalized regression using routines in glmnet. The modified Newton Raphson scoring algorithm that this function applies is an adaptation of the C++ routine developed by Terry M. Therneau that fits Cox proportional hazards models (see survival::coxph() and, more specifically, survival::coxph.fit()).
Details on inputs
formula:
The response in formula can be a survival object as returned by the Surv function, but can also just be the time and status variables. I.e., Surv(time, status) ~ . works just like time + status ~ .
A . symbol on the right hand side is short-hand for using all variables in data (omitting those on the left hand side of formula) as predictors.
The order of variables in the left hand side matters. I.e., writing status + time ~ . will make orsf assume your status variable is actually the time variable.
The response variable can be a survival object stored in data. For example, y ~ . is a valid formula if data$y inherits from the Surv class.
Although you can fit an oblique random survival forest with 1 predictor variable, your formula should have at least 2 predictors. The reason for this recommendation is that a linear combination of predictors is trivial if there is only one predictor.
mtry:
The mtry
parameter may be temporarily reduced to ensure there
are at least 2 events per predictor variable. This occurs when using
orsf_control_cph because coefficients in the Newton Raphson scoring
algorithm may become unstable when the number of covariates is
greater than or equal to the number of events. This reduction does not
occur when using orsf_control_net.
oobag_fun:
If oobag_fun is specified, it will be used to compute negation importance or permutation importance, but it will not have any role for ANOVA importance. A sketch of a custom oobag_fun follows.
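This sketch assumes the documented interface: y_mat is a two-column matrix with columns 'time' and 'status', and s_vec holds predicted survival probabilities at oobag_pred_horizon. The function below returns a Brier-like score; it ignores censoring, so it illustrates the interface rather than a recommended accuracy metric:
# toy oobag_fun: mean squared difference between observed status
# and predicted risk (1 - survival); returns a numeric of length 1
oobag_brier_naive <- function(y_mat, s_vec){
  risk <- 1 - s_vec
  mean((y_mat[, 'status'] - risk)^2)
}
fit_custom_oobag <- orsf(pbc_orsf, Surv(time, status) ~ . - id,
                         oobag_fun = oobag_brier_naive)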
What is an oblique decision tree?
Decision trees are developed by splitting a set of training data into two new subsets, with the goal of having more similarity within the new subsets than between them. This splitting process is repeated on the resulting subsets of data until a stopping criterion is met. When the new subsets of data are formed based on a single predictor, the decision tree is said to be axis-based because the splits of the data appear perpendicular to the axis of the predictor. When linear combinations of variables are used instead of a single variable, the tree is oblique because the splits of the data are neither parallel nor perpendicular to the axis of any single predictor.
Figure: Decision trees for classification with axis-based splitting (left) and oblique splitting (right). Cases are orange squares; controls are purple circles. Both trees partition the predictor space defined by variables X1 and X2, but the oblique splits do a better job of separating the two classes.
What is a random forest?
Random forests are collections of de-correlated decision trees. Predictions from each tree are aggregated to make an ensemble prediction for the forest. For more details, see Breiman, 2001.
Training, out-of-bag error, and testing
In random forests, each tree is grown with a bootstrapped version of the training set. Because bootstrap samples are selected with replacement, each bootstrapped training set contains about two-thirds of the instances in the original training set. The 'out-of-bag' data are instances that are not in the bootstrapped training set. Each tree in the random forest can make predictions for its out-of-bag data, and the out-of-bag predictions can be aggregated to make an ensemble out-of-bag prediction. Since the out-of-bag data are not used to grow the tree, the accuracy of the ensemble out-of-bag predictions approximates the generalization error of the random forest. Generalization error refers to the error of a random forest's predictions when it is applied to predict outcomes for data that were not used to train it, i.e., testing data.
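A minimal sketch of tracking out-of-bag error during training, using the oobag_eval_every argument documented above (see the out-of-bag vignette for how to inspect and plot the recorded values):
# record the out-of-bag C-statistic after every 50th tree
fit_oob <- orsf(pbc_orsf, Surv(time, status) ~ . - id,
                n_tree = 500, oobag_eval_every = 50)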
Missing data
By default (na_action = 'fail'), data passed to aorsf functions are not allowed to have missing values. You can ask orsf() to drop or impute incomplete rows via the na_action argument, or impute missing values beforehand using an R package built for that purpose, such as recipes or mlr3pipelines.
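A minimal sketch of na_action at fit time; pbc_miss is a hypothetical copy of pbc_orsf with missing values added for illustration:
pbc_miss <- pbc_orsf
pbc_miss$bili[1:10] <- NA
fit_imputed <- orsf(pbc_miss, Surv(time, status) ~ . - id,
                    na_action = "impute_meanmode")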
Examples
First we load some relevant packages
set.seed(329730)
suppressPackageStartupMessages({
library(aorsf)
library(survival)
library(tidymodels)
library(tidyverse)
library(randomForestSRC)
library(ranger)
library(riskRegression)
library(obliqueRSF)
})
The entry-point into aorsf is the standard call to orsf():
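(A minimal sketch, assuming default settings and the pbc_orsf data used throughout these examples.)
fit <- orsf(data = pbc_orsf,
            formula = Surv(time, status) ~ . - id)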
Printing fit provides quick descriptive summaries:
fit
## ---------- Oblique random survival forest
##
## Linear combinations: Accelerated
## N observations: 276
## N events: 111
## N trees: 500
## N predictors total: 17
## N predictors per node: 5
## Average leaves per tree: 24
## Min observations in leaf: 5
## Min events in leaf: 1
## OOB stat value: 0.84
## OOB stat type: Harrell's C-statistic
## Variable importance: anova
##
## -----------------------------------------
Model control
For these examples we will make use of the orsf_control_ functions to build and compare models based on their out-of-bag predictions. We will also standardize the out-of-bag samples using the input argument tree_seeds.
Accelerated linear combinations
The accelerated ORSF ensemble is the default because it has a nice balance of computational speed and prediction accuracy. It runs a single iteration of Newton Raphson scoring on the Cox partial likelihood function to find linear combinations of predictors.
fit_accel <- orsf(pbc_orsf,
control = orsf_control_fast(),
formula = Surv(time, status) ~ . - id,
tree_seeds = 1:500)
Linear combinations with Cox regression
orsf_control_cph
runs Cox regression in each non-terminal node of each
survival tree, using the regression coefficients to create linear
combinations of predictors:
fit_cph <- orsf(pbc_orsf,
control = orsf_control_cph(),
formula = Surv(time, status) ~ . - id,
tree_seeds = 1:500)
Linear combinations with penalized Cox regression
orsf_control_net
runs penalized Cox regression in each non-terminal
node of each survival tree, using the regression coefficients to create
linear combinations of predictors. This can be really helpful if you
want to do feature selection within the node, but it is a lot slower
than the other options.
fit_net <- orsf(pbc_orsf,
# select 3 predictors out of 5 to be used in
# each linear combination of predictors.
control = orsf_control_net(df_target = 3),
formula = Surv(time, status) ~ . - id,
tree_seeds = 1:500)
Linear combinations with your own function
Let’s make two customized functions to identify linear combinations of predictors.
The first uses random coefficients; its definition is sketched below.
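A minimal sketch of the random-coefficient function, assuming the same (x_node, y_node, w_node) interface used by f_pca and one coefficient per column of x_node:
f_rando <- function(x_node, y_node, w_node){
  # draw a random coefficient for each predictor in the node
  matrix(runif(ncol(x_node)), ncol = 1)
}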
The second derives coefficients from principal component analysis.
f_pca <- function(x_node, y_node, w_node) {
  # estimate two principal components
  pca <- stats::prcomp(x_node, rank. = 2)
  # use the second principal component to split the node
  pca$rotation[, 2L, drop = FALSE]
}
We can plug these functions into orsf_control_custom()
, and then pass
the result into orsf()
:
fit_rando <- orsf(pbc_orsf,
Surv(time, status) ~ . - id,
control = orsf_control_custom(beta_fun = f_rando),
tree_seeds = 1:500)
fit_pca <- orsf(pbc_orsf,
Surv(time, status) ~ . - id,
control = orsf_control_custom(beta_fun = f_pca),
tree_seeds = 1:500)
So which fit seems to work best in this example? Let’s find out by evaluating the out-of-bag survival predictions.
risk_preds <- list(
accel = 1 - fit_accel$pred_oobag,
cph = 1 - fit_cph$pred_oobag,
net = 1 - fit_net$pred_oobag,
rando = 1 - fit_rando$pred_oobag,
pca = 1 - fit_pca$pred_oobag
)
sc <- Score(object = risk_preds,
formula = Surv(time, status) ~ 1,
data = pbc_orsf,
summary = 'IPA',
times = fit_accel$pred_horizon)
The AUC values, from highest to lowest:
sc$AUC$score[order(-AUC)]
## model times AUC se lower upper
## 1: net 1788 0.9107925 0.02116880 0.8693024 0.9522826
## 2: accel 1788 0.9106308 0.02178112 0.8679406 0.9533210
## 3: cph 1788 0.9072690 0.02120139 0.8657150 0.9488229
## 4: pca 1788 0.8915619 0.02335399 0.8457889 0.9373349
## 5: rando 1788 0.8900944 0.02228487 0.8464168 0.9337719
And the indices of prediction accuracy:
sc$Brier$score[order(-IPA), .(model, times, IPA)]
## model times IPA
## 1: accel 1788 0.4891448
## 2: cph 1788 0.4687734
## 3: net 1788 0.4652211
## 4: rando 1788 0.4011573
## 5: pca 1788 0.3845911
## 6: Null model 1788 0.0000000
From inspection,
the penalized regression (net) approach has the highest discrimination, with the accelerated ORSF a very close second.
the accelerated ORSF has the highest index of prediction accuracy.
the random coefficient and PCA approaches don't do as well, although a two-line custom function still comes reasonably close to the other methods.
tidymodels
This example uses tidymodels
functions but stops short of using an
official tidymodels
workflow. I am working on getting aorsf
pulled
into the censored
package and I will update this with real workflows
if that happens!
Comparing ORSF with other learners
Start with a recipe to pre-process data
imputer <- recipe(pbc_orsf, formula = time + status ~ .) %>%
step_impute_mean(all_numeric_predictors()) %>%
step_impute_mode(all_nominal_predictors())
Next create a 10-fold cross validation object and pre-process the data:
# 10-fold cross validation; make a container for the pre-processed data
analyses <- vfold_cv(data = pbc_orsf, v = 10) %>%
mutate(recipe = map(splits, ~prep(imputer, training = training(.x))),
train = map(recipe, juice),
test = map2(splits, recipe, ~bake(.y, new_data = testing(.x))))
analyses
## # 10-fold cross-validation
## # A tibble: 10 x 5
## splits id recipe train test
## <list> <chr> <list> <list> <list>
## 1 <split [248/28]> Fold01 <recipe> <tibble [248 x 20]> <tibble [28 x 20]>
## 2 <split [248/28]> Fold02 <recipe> <tibble [248 x 20]> <tibble [28 x 20]>
## 3 <split [248/28]> Fold03 <recipe> <tibble [248 x 20]> <tibble [28 x 20]>
## 4 <split [248/28]> Fold04 <recipe> <tibble [248 x 20]> <tibble [28 x 20]>
## 5 <split [248/28]> Fold05 <recipe> <tibble [248 x 20]> <tibble [28 x 20]>
## 6 <split [248/28]> Fold06 <recipe> <tibble [248 x 20]> <tibble [28 x 20]>
## 7 <split [249/27]> Fold07 <recipe> <tibble [249 x 20]> <tibble [27 x 20]>
## 8 <split [249/27]> Fold08 <recipe> <tibble [249 x 20]> <tibble [27 x 20]>
## 9 <split [249/27]> Fold09 <recipe> <tibble [249 x 20]> <tibble [27 x 20]>
## 10 <split [249/27]> Fold10 <recipe> <tibble [249 x 20]> <tibble [27 x 20]>
Define functions for a ‘workflow’ with randomForestSRC
, ranger
, and
aorsf
.
rfsrc_wf <- function(train, test, pred_horizon){
# rfsrc does not like tibbles, so cast input data into data.frames
train <- as.data.frame(train)
test <- as.data.frame(test)
rfsrc(formula = Surv(time, status) ~ ., data = train) %>%
predictRisk(newdata = test, times = pred_horizon) %>%
as.numeric()
}
ranger_wf <- function(train, test, pred_horizon){
ranger(Surv(time, status) ~ ., data = train) %>%
predictRisk(newdata = test, times = pred_horizon) %>%
as.numeric()
}
aorsf_wf <- function(train, test, pred_horizon){
train %>%
orsf(Surv(time, status) ~ .) %>%
predict(new_data = test, pred_horizon = pred_horizon) %>%
as.numeric()
}
Run the ‘workflows’ on each fold:
# 5 year risk prediction
ph <- 365.25 * 5
results <- analyses %>%
transmute(test,
pred_aorsf = map2(train, test, aorsf_wf, pred_horizon = ph),
pred_rfsrc = map2(train, test, rfsrc_wf, pred_horizon = ph),
pred_ranger = map2(train, test, ranger_wf, pred_horizon = ph))
Next unnest each column to get back a tibble
with all of the testing
data and predictions.
results <- results %>%
unnest(everything())
glimpse(results)
## Rows: 276
## Columns: 23
## $ id <int> 2, 16, 27, 66, 79, 97, 107, 116, 136, 137, 158, 189, 193, ~
## $ trt <fct> d_penicill_main, placebo, placebo, d_penicill_main, d_peni~
## $ age <dbl> 56.44627, 40.44353, 54.43943, 46.45311, 46.51608, 71.89322~
## $ sex <fct> f, f, f, m, f, m, f, f, f, f, f, f, f, f, f, f, f, f, f, f~
## $ ascites <fct> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0~
## $ hepato <fct> 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1~
## $ spiders <fct> 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1~
## $ edema <fct> 0, 0, 0.5, 0, 0, 0.5, 0, 0.5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0~
## $ bili <dbl> 1.1, 0.7, 21.6, 1.4, 0.8, 2.0, 0.6, 3.0, 0.8, 1.1, 3.4, 1.~
## $ chol <int> 302, 204, 175, 427, 315, 420, 212, 458, 263, 399, 450, 360~
## $ albumin <dbl> 4.14, 3.66, 3.31, 3.70, 4.24, 3.26, 4.03, 3.63, 3.35, 3.60~
## $ copper <int> 54, 28, 221, 105, 13, 62, 10, 74, 27, 79, 32, 52, 267, 76,~
## $ alk.phos <dbl> 7394.8, 685.0, 3697.4, 1909.0, 1637.0, 3196.0, 648.0, 1588~
## $ ast <dbl> 113.52, 72.85, 101.91, 182.90, 170.50, 77.50, 71.30, 106.9~
## $ trig <int> 88, 58, 168, 171, 70, 91, 77, 382, 69, 152, 118, 164, 157,~
## $ platelet <int> 221, 198, 80, 123, 426, 344, 316, 438, 206, 344, 313, 256,~
## $ protime <dbl> 10.6, 10.8, 12.0, 11.0, 10.9, 11.4, 17.1, 9.9, 9.8, 10.1, ~
## $ stage <ord> 3, 3, 4, 3, 3, 3, 1, 3, 2, 2, 2, 3, 4, 4, 2, 2, 3, 3, 4, 4~
## $ time <int> 4500, 3672, 77, 4191, 3707, 611, 3388, 3336, 3098, 2990, 2~
## $ status <dbl> 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0~
## $ pred_aorsf <dbl> 0.21650571, 0.01569191, 0.93095617, 0.36737089, 0.12868206~
## $ pred_rfsrc <dbl> 0.15202784, 0.01104486, 0.81913559, 0.20173550, 0.13806608~
## $ pred_ranger <dbl> 0.11418963, 0.02130315, 0.77073269, 0.22130305, 0.18419972~
And finish by aggregating the predictions and computing performance in the testing data. Note that I am computing one statistic for all predictions instead of computing one statistic for each fold. This approach is fine when you have smaller testing sets and/or small event counts.
Score(
object = list(aorsf = results$pred_aorsf,
rfsrc = results$pred_rfsrc,
ranger = results$pred_ranger),
formula = Surv(time, status) ~ 1,
data = results,
summary = 'IPA',
times = ph
)
##
## Metric AUC:
##
## Results by model:
##
## model times AUC lower upper
## 1: aorsf 1826 90.1 85.7 94.6
## 2: rfsrc 1826 89.4 85.0 93.7
## 3: ranger 1826 90.1 85.9 94.3
##
## Results of model comparisons:
##
## times model reference delta.AUC lower upper p
## 1: 1826 rfsrc aorsf -0.7 -2.3 0.8 0.4
## 2: 1826 ranger aorsf -0.0 -1.7 1.6 1.0
## 3: 1826 ranger rfsrc 0.7 -0.4 1.8 0.2
##
## NOTE: Values are multiplied by 100 and given in %.
## NOTE: The higher AUC the better.
##
## Metric Brier:
##
## Results by model:
##
## model times Brier lower upper IPA
## 1: Null model 1826.25 20.5 18.1 22.9 0.0
## 2: aorsf 1826.25 11.1 8.8 13.4 45.8
## 3: rfsrc 1826.25 12.0 9.8 14.1 41.6
## 4: ranger 1826.25 11.8 9.7 13.9 42.5
##
## Results of model comparisons:
##
## times model reference delta.Brier lower upper p
## 1: 1826.25 aorsf Null model -9.4 -12.1 -6.6 2.423961e-11
## 2: 1826.25 rfsrc Null model -8.5 -10.8 -6.2 2.104905e-13
## 3: 1826.25 ranger Null model -8.7 -11.0 -6.4 1.802417e-13
## 4: 1826.25 rfsrc aorsf 0.9 -0.0 1.7 5.277607e-02
## 5: 1826.25 ranger aorsf 0.7 -0.1 1.5 1.008730e-01
## 6: 1826.25 ranger rfsrc -0.2 -0.7 0.3 4.550782e-01
##
## NOTE: Values are multiplied by 100 and given in %.
## NOTE: The lower Brier the better, the higher IPA the better.
From inspection,
aorsf obtained slightly higher discrimination (AUC)
aorsf obtained a higher index of prediction accuracy (IPA)
Way to go, aorsf
mlr3 pipelines
Warning: this code may or may not run depending on your current
version of mlr3proba
. First we load some additional mlr3
libraries.
suppressPackageStartupMessages({
library(mlr3verse)
library(mlr3proba)
library(mlr3extralearners)
library(mlr3viz)
library(mlr3benchmark)
})
Next we’ll define some tasks for our learners to engage with.
# Mayo Clinic Primary Biliary Cholangitis Data
task_pbc <-
TaskSurv$new(
id = 'pbc',
backend = select(pbc_orsf, -id) %>%
mutate(stage = as.numeric(stage)),
time = "time",
event = "status"
)
# Veteran's Administration Lung Cancer Trial
data(veteran, package = "randomForestSRC")
task_veteran <-
TaskSurv$new(
id = 'veteran',
backend = veteran,
time = "time",
event = "status"
)
# NKI 70 gene signature
data_nki <- OpenML::getOMLDataSet(data.id = 1228)
task_nki <-
TaskSurv$new(
id = 'nki',
backend = data_nki$data,
time = "time",
event = "event"
)
# Gene Expression-Based Survival Prediction in Lung Adenocarcinoma
data_lung <- OpenML::getOMLDataSet(data.id = 1245)
task_lung <-
TaskSurv$new(
id = 'lung',
backend = data_lung$data %>%
mutate(OS_event = as.numeric(OS_event) -1),
time = "OS_years",
event = "OS_event"
)
# Chemotherapy for Stage B/C colon cancer
# (there are two rows per person, one for death
# and the other for recurrence, hence the two tasks)
task_colon_death <-
TaskSurv$new(
id = 'colon_death',
backend = survival::colon %>%
filter(etype == 2) %>%
drop_na() %>%
# drop id, redundant variables
select(-id, -study, -node4, -etype),
time = "time",
event = "status"
)
task_colon_recur <-
TaskSurv$new(
id = 'colon_recur',
backend = survival::colon %>%
filter(etype == 1) %>%
drop_na() %>%
# drop id, redundant variables
select(-id, -study, -node4, -etype),
time = "time",
event = "status"
)
# putting them all together
tasks <- list(task_pbc,
task_veteran,
task_nki,
task_lung,
task_colon_death,
task_colon_recur,
# add a few more pre-made ones
tsk("actg"),
tsk('gbcs'),
tsk('grace'),
tsk("unemployment"),
tsk("whas"))
Now we can make a benchmark designed to compare our three favorite learners:
# Learners with default parameters
learners <- lrns(c("surv.ranger", "surv.rfsrc", "surv.aorsf"))
# Brier (Graf) score, c-index and training time as measures
measures <- msrs(c("surv.graf", "surv.cindex", "time_train"))
# Benchmark with 5-fold CV
design <- benchmark_grid(
tasks = tasks,
learners = learners,
resamplings = rsmps("cv", folds = 5)
)
benchmark_result <- benchmark(design)
bm_scores <- benchmark_result$score(measures, predict_sets = "test")
Let’s look at the overall results:
bm_scores %>%
select(task_id, learner_id, surv.graf, surv.cindex, time_train) %>%
group_by(learner_id) %>%
filter(!is.infinite(surv.graf)) %>%
summarize(
across(
.cols = c(surv.graf, surv.cindex, time_train),
.fns = mean,
na.rm = TRUE
)
)
## # A tibble: 3 x 4
## learner_id surv.graf surv.cindex time_train
## <chr> <dbl> <dbl> <dbl>
## 1 surv.aorsf 0.151 0.729 0.345
## 2 surv.ranger 0.167 0.706 2.54
## 3 surv.rfsrc 0.156 0.715 0.783
From inspection,
aorsf appears to have a higher expected value for 'surv.cindex' (higher is better)
aorsf appears to have a lower expected value for 'surv.graf' (lower is better)
aorsf has the lowest training time.
The lower training time for aorsf is likely due to the fact that there are many unique event times in the benchmark tasks. ranger and rfsrc create grids of time points based on each unique event time in each leaf of each decision tree, whereas aorsf also uses a grid but restricts it to the unique event times among observations in the current leaf.
References
Harrell FE, Califf RM, Pryor DB, Lee KL, Rosati RA. Evaluating the Yield of Medical Tests. JAMA 1982; 247(18):2543-2546. DOI: 10.1001/jama.1982.03320430047030
Breiman L. Random forests. Machine Learning 2001 Oct; 45(1):5-32. DOI: 10.1023/A:1010933404324
Ishwaran H, Kogalur UB, Blackstone EH, Lauer MS. Random survival forests. Annals of Applied Statistics 2008 Sep; 2(3):841-60. DOI: 10.1214/08-AOAS169
Jaeger BC, Long DL, Long DM, Sims M, Szychowski JM, Min YI, McClure LA, Howard G, Simon N. Oblique random survival forests. Annals of Applied Statistics 2019 Sep; 13(3):1847-83. DOI: 10.1214/19-AOAS1261
Jaeger BC, Welden S, Lenoir K, Speiser JL, Segar MW, Pandey A, Pajewski NM. Accelerated and interpretable oblique random survival forests. arXiv e-prints 2022 Aug; arXiv-2208. URL: https://arxiv.org/abs/2208.01129