Skip to content

Commit fc351d7

Browse files
committed
add ipcw and initial performance function
1 parent 7706503 commit fc351d7

File tree

2 files changed

+41
-8
lines changed

2 files changed

+41
-8
lines changed

R/grid_helpers.R

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,12 @@ predict_model <- function(split, workflow, grid, metrics, submodels = NULL,
88
x_vals <- forged$predictors
99
y_vals <- forged$outcomes
1010

11+
# TODO patch since parsnip does not record the column names when Surv objects
12+
# are used with fit_xy()
13+
if (model$spec$mode == "censored regression") {
14+
model$preproc$y_var <- names(y_vals)
15+
}
16+
1117
orig_rows <- as.integer(split, data = "assessment")
1218

1319
if (length(orig_rows) != nrow(x_vals)) {
@@ -92,9 +98,28 @@ predict_model <- function(split, workflow, grid, metrics, submodels = NULL,
9298
}
9399
}
94100

95-
tibble::as_tibble(res)
101+
# TODO do we need this?
102+
# res <- tibble::as_tibble(res)
103+
maybe_add_ipcw(res, model, eval_time, types)
96104
}
97105

106+
maybe_add_ipcw <- function(.data, model, eval_time, types) {
107+
if (!any(types == "survival")) {
108+
return(.data)
109+
}
110+
res <-
111+
tidyr::unnest(.data, cols = .pred) %>%
112+
dplyr::rename(eval_time = .time) %>%
113+
dplyr::full_join(
114+
# TODO is the outcome name enforced or the original name?
115+
parsnip::.censoring_weights_graf(model, .data, eval_time = eval_time),
116+
by = c(".row", "eval_time")
117+
)
118+
res
119+
}
120+
121+
122+
98123
predict_wrapper <- function(model, new_data, type, eval_time, subgrid = NULL) {
99124
if (is.null(subgrid)) {
100125
fn <- "predict.model_fit"

R/grid_performance.R

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ metrics_info <- function(x) {
8383
} else if (all(types == "class" | types == "prob")) {
8484
estimate_class_prob(dat, metric, param_names, outcome_name, case_weights, types, event_level)
8585
} else if (all(types == "time" | types == "survival")) {
86-
# estimate_surv(dat, metric, param_names, outcome_name, case_weights, types, eval_time)
86+
estimate_surv(dat, metric, param_names, outcome_name, case_weights, types)
8787
} else {
8888
rlang::abort("Metric type not yet supported by tune.")
8989
}
@@ -133,14 +133,22 @@ estimate_class_prob <- function(dat, metric, param_names, outcome_name,
133133
)
134134
}
135135

136-
estimate_surv <- function(dat, metric, param_names, outcome_name, case_weights, eval_time) {
137-
# IPCW should already be computed, un-nested and have .time
138-
types <- NULL
136+
estimate_surv <- function(dat, metric, param_names, outcome_name, case_weights, types) {
137+
# TODO mixed sets?
139138
if (any(types == "survival")) {
140-
139+
res <-
140+
dat %>%
141+
dplyr::group_by(!!!rlang::syms(param_names), eval_time) %>%
142+
metric(
143+
truth = surv,
144+
estimate = .pred_survival,
145+
censoring_weights = .weight_cens,
146+
case_weights = !!case_weights,
147+
eval_time = eval_time
148+
)
141149
} else {
142-
# pad with .time = NA
150+
# pad with .time = NA?
143151
}
144-
# metric(estimate = .pred, truth = !!sym(outcome_name), case_weights = !!case_weights)
152+
res
145153
}
146154

0 commit comments

Comments
 (0)