Skip to content

Survival censoring weights #897

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Mar 15, 2023
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.0.4.9001
Version: 1.0.4.9002
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Generated by roxygen2: do not edit by hand

S3method(.censoring_weights_graf,default)
S3method(.censoring_weights_graf,model_fit)
S3method(.censoring_weights_graf,workflow)
S3method(augment,model_fit)
S3method(extract_fit_engine,model_fit)
S3method(extract_parameter_dials,model_spec)
Expand Down Expand Up @@ -138,6 +141,7 @@ S3method(varying_args,model_spec)
S3method(varying_args,recipe)
S3method(varying_args,step)
export("%>%")
export(.censoring_weights_graf)
export(.check_glmnet_penalty_fit)
export(.check_glmnet_penalty_predict)
export(.cols)
Expand Down
186 changes: 186 additions & 0 deletions R/ipcw.R
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,189 @@ trunc_probs <- function(probs, trunc = 0.01) {
}
eval_time
}

add_dot_row_to_weights <- function(dat, rows = NULL) {
if (is.null(rows)) {
dat <- add_rowindex(dat)
} else {
m <- length(rows)
n <- nrow(dat)
if (m != n) {
rlang::abort(
glue::glue(
"The length of 'rows' ({m}) should be equal to the number of rows in 'data' ({n})"
)
)
}
dat$.row <- rows
}
dat
}

.check_censor_model <- function(x) {
nms <- names(x)
if (!any(nms == "censor_probs")) {
rlang::abort("Please refit the model with parsnip version 1.0.4 or greater.")
}
invisible(NULL)
}

# nocov start
# these are tested in extratests
# ------------------------------------------------------------------------------
# Brier score helpers. Most of this is based off of Graf, E., Schmoor, C.,
# Sauerbrei, W. and Schumacher, M. (1999), Assessment and comparison of
# prognostic classification schemes for survival data. _Statist. Med._, 18:
# 2529-2545.

# We need to use the time of analysis to determine what time to use to evaluate
# the IPCWs.

graf_weight_time <- function(surv_obj, eval_time, rows = NULL, eps = 10^-10) {
event_time <- .extract_surv_time(surv_obj)
status <- .extract_surv_status(surv_obj)
is_event_before_t <- event_time <= eval_time & status == 1
is_censored <- event_time > eval_time

# Three possible contributions to the statistic from Graf 1999

# Censoring time before eval_time, no contribution (Graf category 3)
weight_time <- rep(NA_real_, length(event_time))

# A real event prior to predict time (Graf category 1)
weight_time[is_event_before_t] <- event_time[is_event_before_t] - eps

# Observed time greater than eval_time (Graf category 2)
weight_time[is_censored] <- eval_time - eps

weight_time <- ifelse(weight_time < 0, 0, weight_time)

res <- tibble::tibble(surv = surv_obj, weight_time = weight_time, eval_time)
add_dot_row_to_weights(res, rows)
}

# ------------------------------------------------------------------------------
#' Calculations for inverse probability of censoring weights (IPCW)
#'
#' The method of Graf _et al_ (1999) is used to compute weights at specific
#' evaluation times that can be used to help measure a model's time-dependent
#' performance (e.g. the time-dependent Brier score or the area under the ROC
#' curve).
#' @param data A data frame with a column containing a [survival::Surv()] object.
#' @param predictors Not currently used. A potential future slot for models with
#' informative censoring based on columns in `data`.
#' @param rows An optional integer vector with length equal to the number of
#' rows in `data` that is used to index the original data. The default is to
#' use a fresh index on data (i.e. `1:nrow(data)`).
#' @param eval_time A vector of non-negative times at which we should
#' compute the probability of censoring and the corresponding weights.
#' @param object A fitted parsnip model object or fitted workflow with a mode
#' of "censored regression".
#' @param trunc A potential lower bound for the probability of censoring to avoid
#' very large weight values.
#' @param eps A small value that is subtracted from the evaluation time when
#' computing the censoring probabilities. In doing so, the censoring probability
#' prediction avoids information leakage by avoiding data that would not be
#' known at the time of prediction.
#' @return A tibble with columns `.row`, `eval_time`, `.prob_cens` (the
#' probability of being censored just prior to the evaluation time), and
#' `.weight_cens` (the inverse probability of censoring weight).
#' @details
#'
#' A probability that the data are censored immediately prior to a specific
#' time is computed. To do this, we must determine what time to
#' make the prediction. There are two time values for each row of the data set:
#' the observed time (either censored or not) and the time that the model is
#' being evaluated at (e.g. the survival function prediction at some time point),
#' which is constant across rows. .
#'
#' From Graf _et al_ (1999) there are three cases:
#'
#' - If the observed time is a censoring time and that is before the
#' evaluation time, the data point should make no contribution to the
#' performance metric (their "category 3"). These values have a missing
#' value for their probability estimate (and also for their weight column).
#'
#' - If the observed time corresponds to an actual event, and that time is
#' prior to the evaluation time (category 1), the probability of being
#' censored is predicted at the observed time (minus an epsilon).
#'
#' - If the observed time corresponds to an actual event, and it is _after_
#' the evaluation time (category 2), the probability of being
#' censored is predicted at the evaluation time (minus an epsilon).
#'
#' The epsilon is used since, we would not have actual information at time `t`
#' for a data point being predicted at time `t` (only data prior to time `t`
#' should be available).
#'
#' After the censoring probability is computed, the `trunc` option is used to
#' avoid using numbers pathologically close to zero. After this, the weight is
#' computed by inverting the censoring probability.
#'
#' Note that if there are `n` rows in `data` and `t` time points, the resulting
#' data has `n * t` rows. Computations will not easily scale well as `t` becomes
#' large.
#' @references Graf, E., Schmoor, C., Sauerbrei, W. and Schumacher, M. (1999),
#' Assessment and comparison of prognostic classification schemes for survival
#' data. _Statist. Med._, 18: 2529-2545.
#' @export
#' @name censoring_weights
#' @keywords internal
.censoring_weights_graf <- function(object, ...) {
UseMethod(".censoring_weights_graf")
}

#' @export
#' @rdname censoring_weights
.censoring_weights_graf.default <- function(object, ...) {
cls <- paste0("'", class(object), "'", collapse = ", ")
msg <- paste("There are no `.censoring_weights_graf` for objects with class(es):",
cls)
rlang::abort(msg)
}


#' @export
#' @rdname censoring_weights
.censoring_weights_graf.workflow <- function(object,
data,
eval_time,
rows = NULL,
predictors = NULL,
trunc = 0.05, eps = 10^-10, ...) {
if (is.null(object$fit$fit)) {
rlang::abort("The workflow does not have a model fit object.", call = FALSE)
}
.censoring_weights_graf(object$fit$fit, data, eval_time, rows, predictors, trunc, eps)
}

#' @export
#' @rdname censoring_weights
.censoring_weights_graf.model_fit <- function(object,
data,
eval_time,
rows = NULL,
predictors = NULL,
trunc = 0.05, eps = 10^-10, ...) {
rlang::check_dots_empty()
.check_censor_model(object)
if (!is.null(predictors)) {
rlang::warn("The 'predictors' argument to the survival weighting function is not currently used.", call = FALSE)
}
eval_time <- .filter_eval_time(eval_time) # TODO maybe this should be the check function

truth <- object$preproc$y_var
surv_data <- dplyr::select(data, dplyr::all_of(!!truth)) %>% setNames("surv")
.check_censored_right(surv_data$surv)

purrr::map_dfr(eval_time,
~ graf_weight_time(surv_data$surv, .x, eps = eps, rows = rows)) %>%
dplyr::mutate(
.prob_cens = predict(object$censor_probs, time = weight_time, as_vector = TRUE),
.prob_cens = trunc_probs(.prob_cens, trunc),
.weight_cens = 1 / .prob_cens
) %>%
dplyr::select(.row, eval_time, .prob_cens, .weight_cens)
}

# nocov end
3 changes: 2 additions & 1 deletion R/parsnip-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ utils::globalVariables(
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators",
"compute_intercept", "remove_intercept", "estimate", "term",
"call_info", "component", "component_id", "func", "tunable", "label",
"pkg", ".order", "item", "tunable", "has_ext", "id", "weights", "has_wts", "protect"
"pkg", ".order", "item", "tunable", "has_ext", "id", "weights", "has_wts",
"protect", "weight_time", ".prob_cens", ".weight_cens"
)
)

Expand Down
111 changes: 111 additions & 0 deletions man/censoring_weights.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 13 additions & 0 deletions tests/testthat/helper-objects.R
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,16 @@ caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE)
quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE)

run_glmnet <- utils::compareVersion('3.6.0', as.character(getRversion())) > 0

# ------------------------------------------------------------------------------
# for skips

is_tf_ok <- function() {
tf_ver <- try(tensorflow::tf_version(), silent = TRUE)
if (inherits(tf_ver, "try-error")) {
res <- FALSE
} else {
res <- !is.null(tf_ver)
}
res
}
16 changes: 13 additions & 3 deletions tests/testthat/test-ipcw.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ test_that('probability truncation', {
)
})


test_that('time filtering', {
times_1 <- 0:10
times_2 <- c(Inf, NA, -3, times_1, times_1)
Expand All @@ -36,7 +35,18 @@ test_that('time filtering', {
expect_null(parsnip:::.filter_eval_time(NULL))
})

test_that('probability truncation', {
probs_1 <- (0:10) / 20
probs_2 <- probs_1
probs_2[3] <- NA_real_

expect_equal(parsnip:::trunc_probs(probs_1, 0), probs_1)
expect_equal(parsnip:::trunc_probs(probs_2, 0), probs_2)
expect_equal(
parsnip:::trunc_probs(probs_1, 0.1),
ifelse(probs_1 < 0.05 / 2, 0.05 / 2, probs_1)
)
expect_equal(min(parsnip:::trunc_probs(probs_2, 0.1), na.rm = TRUE), 0.05 / 2)
expect_equal(is.na(parsnip:::trunc_probs(probs_2, 0.1)),is.na(probs_2))



})
Loading