Skip to content

Commit fa741d9

Browse files
topepohfrickEmilHvitfeldt
authored
Survival censoring weights (#897)
* helper functions * doc, version, and test update * small doc update * update name change * convert to a portable standalone document with unexported functions * fix issues testing tensorflow * IPCW functions * time always >= 0 * unit tests * typo * additional test * version bump * move some tests to extratests * Apply suggestions from code review Co-authored-by: Hannah Frick <[email protected]> Co-authored-by: Emil Hvitfeldt <[email protected]> * Apply suggestions from code review --------- Co-authored-by: Hannah Frick <[email protected]> Co-authored-by: Emil Hvitfeldt <[email protected]>
1 parent dff211b commit fa741d9

11 files changed

+361
-22
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.0.4.9003
3+
Version: 1.0.4.9004
44
Authors@R: c(
55
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "[email protected]", role = "aut"),

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(.censoring_weights_graf,default)
4+
S3method(.censoring_weights_graf,model_fit)
5+
S3method(.censoring_weights_graf,workflow)
36
S3method(augment,model_fit)
47
S3method(autoplot,glmnet)
58
S3method(autoplot,model_fit)
@@ -144,6 +147,7 @@ S3method(varying_args,model_spec)
144147
S3method(varying_args,recipe)
145148
S3method(varying_args,step)
146149
export("%>%")
150+
export(.censoring_weights_graf)
147151
export(.check_glmnet_penalty_fit)
148152
export(.check_glmnet_penalty_predict)
149153
export(.cols)

R/ipcw.R

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ trunc_probs <- function(probs, trunc = 0.01) {
1919
}
2020

2121
.filter_eval_time <- function(eval_time, fail = TRUE) {
22+
if (!is.null(eval_time)) {
23+
eval_time <- as.numeric(eval_time)
24+
}
2225
# will still propagate nulls:
2326
eval_time <- eval_time[!is.na(eval_time)]
2427
eval_time <- unique(eval_time)
@@ -32,3 +35,199 @@ trunc_probs <- function(probs, trunc = 0.01) {
3235
}
3336
eval_time
3437
}
38+
39+
# Attach a row index to a weights data frame.
#
# dat:  a data frame (or tibble).
# rows: optional vector with one element per row of `dat`. When `NULL`, the
#       index is produced by add_rowindex(); otherwise `rows` is stored in a
#       `.row` column.
#
# Returns `dat` with the index added. Aborts when `rows` is supplied but its
# length differs from the number of rows in `dat`.
add_dot_row_to_weights <- function(dat, rows = NULL) {
  # No user-supplied index: delegate to add_rowindex()
  if (is.null(rows)) {
    return(add_rowindex(dat))
  }

  # `m` and `n` are interpolated by glue in the message below
  m <- length(rows)
  n <- nrow(dat)
  if (m != n) {
    msg <- glue::glue(
      "The length of 'rows' ({m}) should be equal to the number of rows in 'data' ({n})"
    )
    rlang::abort(msg)
  }

  dat$.row <- rows
  dat
}
56+
57+
# Verify that a fitted model object carries a `censor_probs` element.
#
# x: a named object (the model_fit method passes the fitted model here).
#
# Returns `NULL` invisibly when an element named "censor_probs" is present;
# otherwise aborts with a message asking the user to refit the model.
.check_censor_model <- function(x) {
  has_censor_probs <- any(names(x) == "censor_probs")
  if (has_censor_probs) {
    return(invisible(NULL))
  }
  rlang::abort("Please refit the model with parsnip version 1.0.4 or greater.")
}
64+
65+
# nocov start
66+
# these are tested in extratests
67+
# ------------------------------------------------------------------------------
68+
# Brier score helpers. Most of this is based off of Graf, E., Schmoor, C.,
69+
# Sauerbrei, W. and Schumacher, M. (1999), Assessment and comparison of
70+
# prognostic classification schemes for survival data. _Statist. Med._, 18:
71+
# 2529-2545.
72+
73+
# We need to use the time of analysis to determine what time to use to evaluate
74+
# the IPCWs.
75+
76+
# For one evaluation time, determine the time at which the censoring
# probability should be evaluated for every observation, following the three
# categories of Graf et al. (1999).
#
# surv_obj:  object handed to .extract_surv_time() and .extract_surv_status()
#            (presumably a survival::Surv object -- confirm against callers).
# eval_time: the evaluation time being analysed.
# rows:      optional row index forwarded to add_dot_row_to_weights().
# eps:       small value subtracted from the chosen time so that only
#            information strictly before that time is used.
#
# Returns the result of add_dot_row_to_weights() applied to a tibble with
# columns `surv`, `weight_time` and `eval_time`. `weight_time` is always a
# double vector; it is NA for observations that make no contribution.
graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
  event_time <- .extract_surv_time(surv_obj)
  status <- .extract_surv_status(surv_obj)

  # status == 1 is treated as an observed event
  is_event_before_t <- event_time <= eval_time & status == 1
  # Regardless of status, the observed time lies beyond eval_time
  is_observed_after_t <- event_time > eval_time

  # Three possible contributions to the statistic from Graf 1999

  # Censoring time before eval_time, no contribution (Graf category 3):
  # these entries keep their NA
  weight_time <- rep(NA_real_, length(event_time))

  # A real event prior to eval_time (Graf category 1)
  weight_time[is_event_before_t] <- event_time[is_event_before_t] - eps

  # Observed time greater than eval_time (Graf category 2)
  weight_time[is_observed_after_t] <- eval_time - eps

  # Time is never negative. pmax() leaves NA untouched and always returns a
  # double, whereas ifelse() yields a logical vector when every element is NA
  # or the input is empty.
  weight_time <- pmax(weight_time, 0)

  res <- tibble::tibble(surv = surv_obj, weight_time = weight_time, eval_time)
  add_dot_row_to_weights(res, rows)
}
98+
99+
# ------------------------------------------------------------------------------
100+
#' Calculations for inverse probability of censoring weights (IPCW)
101+
#'
102+
#' The method of Graf _et al_ (1999) is used to compute weights at specific
103+
#' evaluation times that can be used to help measure a model's time-dependent
104+
#' performance (e.g. the time-dependent Brier score or the area under the ROC
105+
#' curve).
106+
#' @param data A data frame with a column containing a [survival::Surv()] object.
107+
#' @param predictors Not currently used. A potential future slot for models with
108+
#' informative censoring based on columns in `data`.
109+
#' @param rows An optional integer vector with length equal to the number of
110+
#' rows in `data` that is used to index the original data. The default is to
111+
#' use a fresh index on data (i.e. `1:nrow(data)`).
112+
#' @param eval_time A vector of finite, non-negative times at which to
113+
#' compute the probability of censoring and the corresponding weights.
114+
#' @param object A fitted parsnip model object or fitted workflow with a mode
115+
#' of "censored regression".
116+
#' @param trunc A potential lower bound for the probability of censoring to avoid
117+
#' very large weight values.
118+
#' @param eps A small value that is subtracted from the evaluation time when
119+
#' computing the censoring probabilities. See Details below.
120+
#' @return A tibble with columns `.row`, `eval_time`, `.prob_cens` (the
121+
#' probability of being censored just prior to the evaluation time), and
122+
#' `.weight_cens` (the inverse probability of censoring weight).
123+
#' @details
124+
#'
125+
#' A probability that the data are censored immediately prior to a specific
126+
#' time is computed. To do this, we must determine the time at which to
127+
#' make the prediction. There are two time values for each row of the data set:
128+
#' the observed time (either censored or not) and the time that the model is
129+
#' being evaluated at (e.g. the survival function prediction at some time point),
130+
#' which is constant across rows.
131+
#'
132+
#' From Graf _et al_ (1999) there are three cases:
133+
#'
134+
#' - If the observed time is a censoring time and that is before the
135+
#' evaluation time, the data point should make no contribution to the
136+
#' performance metric (their "category 3"). These values have a missing
137+
#' value for their probability estimate (and also for their weight column).
138+
#'
139+
#' - If the observed time corresponds to an actual event, and that time is
140+
#' prior to the evaluation time (category 1), the probability of being
141+
#' censored is predicted at the observed time (minus an epsilon).
142+
#'
143+
#' - If the observed time is _after_ the evaluation time (category 2), regardless of
144+
#' the status, the probability of being censored is predicted at the evaluation
145+
#' time (minus an epsilon).
146+
#'
147+
#' The epsilon is used since we would not have actual information at time `t`
148+
#' for a data point being predicted at time `t` (only data prior to time `t`
149+
#' should be available).
150+
#'
151+
#' After the censoring probability is computed, the `trunc` option is used to
152+
#' avoid using numbers pathologically close to zero. After this, the weight is
153+
#' computed by inverting the censoring probability.
154+
#'
155+
#' The `eps` argument is used to avoid information leakage when computing the
156+
#' censoring probability. Subtracting a small number avoids using data that
157+
#' would not be known at the time of prediction. For example, if we are making
158+
#' survival probability predictions at `eval_time = 3.0`, we would not know
159+
#' about the probability of being censored at that exact time (since it has not
160+
#' occurred yet).
161+
#'
162+
#' Note that if there are `n` rows in `data` and `t` time points, the resulting
163+
#' data has `n * t` rows. Computations will not scale well as `t` becomes
164+
#' large.
165+
#' @references Graf, E., Schmoor, C., Sauerbrei, W. and Schumacher, M. (1999),
166+
#' Assessment and comparison of prognostic classification schemes for survival
167+
#' data. _Statist. Med._, 18: 2529-2545.
168+
#' @export
169+
#' @name censoring_weights
170+
#' @keywords internal
171+
.censoring_weights_graf <- function(object, ...) {
  # S3 generic: the method is selected from the class of `object`, and all
  # remaining arguments are handed to that method unchanged.
  UseMethod(".censoring_weights_graf")
}
174+
175+
#' @export
176+
#' @rdname censoring_weights
177+
.censoring_weights_graf.default <- function(object, ...) {
  # Fallback method: always aborts, listing every class of `object` in
  # single quotes, separated by commas.
  quoted_classes <- paste0("'", class(object), "'", collapse = ", ")
  rlang::abort(
    paste(
      "There is no `.censoring_weights_graf()` method for objects with class(es):",
      quoted_classes
    )
  )
}
183+
184+
185+
#' @export
186+
#' @rdname censoring_weights
187+
.censoring_weights_graf.workflow <- function(object,
                                             data,
                                             eval_time,
                                             rows = NULL,
                                             predictors = NULL,
                                             trunc = 0.05, eps = 10^-10, ...) {
  # The model fit is read from `object$fit$fit`; without it there is nothing
  # to compute weights from.
  model_fit <- object$fit$fit
  if (is.null(model_fit)) {
    rlang::abort("The workflow does not have a model fit object.", call = FALSE)
  }

  # Re-dispatch on the extracted fit. `...` is not forwarded.
  .censoring_weights_graf(
    model_fit,
    data = data,
    eval_time = eval_time,
    rows = rows,
    predictors = predictors,
    trunc = trunc,
    eps = eps
  )
}
198+
199+
#' @export
200+
#' @rdname censoring_weights
201+
.censoring_weights_graf.model_fit <- function(object,
                                              data,
                                              eval_time,
                                              rows = NULL,
                                              predictors = NULL,
                                              trunc = 0.05, eps = 10^-10, ...) {
  # --- argument and model checks ---------------------------------------------
  rlang::check_dots_empty()
  .check_censor_model(object)
  if (!is.null(predictors)) {
    rlang::warn("The 'predictors' argument to the survival weighting function is not currently used.", call = FALSE)
  }
  eval_time <- .filter_eval_time(eval_time)

  # --- locate the outcome column ---------------------------------------------
  truth <- object$preproc$y_var
  if (length(truth) != 1) {
    # check_outcome() tests that the outcome column is a Surv object
    rlang::abort("The event time data should be in a single column with class 'Surv'", call = FALSE)
  }
  surv_data <- setNames(dplyr::select(data, dplyr::all_of(!!truth)), "surv")
  surv_col <- surv_data$surv
  .check_censored_right(surv_col)

  # --- one block of rows per evaluation time ---------------------------------
  per_time <- purrr::map(
    eval_time,
    function(.time) graf_weight_time(surv_col, .time, eps = eps, rows = rows)
  )
  stacked <- purrr::list_rbind(per_time)

  # --- censoring probability, truncation, then inverse weight ----------------
  weighted <- dplyr::mutate(
    stacked,
    .prob_cens = predict(object$censor_probs, time = weight_time, as_vector = TRUE),
    .prob_cens = trunc_probs(.prob_cens, trunc),
    .weight_cens = 1 / .prob_cens
  )
  dplyr::select(weighted, .row, eval_time, .prob_cens, .weight_cens)
}
232+
233+
# nocov end

R/parsnip-package.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ utils::globalVariables(
4444
"compute_intercept", "remove_intercept", "estimate", "term",
4545
"call_info", "component", "component_id", "func", "tunable", "label",
4646
"pkg", ".order", "item", "tunable", "has_ext", "id", "weights", "has_wts",
47-
"protect", "s"
47+
"protect", "weight_time", ".prob_cens", ".weight_cens", "s"
4848
)
4949
)
5050

man/censoring_weights.Rd

Lines changed: 116 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/helper-objects.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,16 @@ caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE)
1111
quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE)
1212

1313
run_glmnet <- utils::compareVersion('3.6.0', as.character(getRversion())) > 0
14+
15+
# ------------------------------------------------------------------------------
16+
# for skips
17+
18+
# Test-skip helper: TRUE only when tensorflow::tf_version() runs without error
# and returns a non-NULL value; FALSE otherwise.
is_tf_ok <- function() {
  tf_ver <- try(tensorflow::tf_version(), silent = TRUE)
  !inherits(tf_ver, "try-error") && !is.null(tf_ver)
}

0 commit comments

Comments
 (0)