Skip to content

Commit 972ccc0

Browse files
committed
updates based on tidymodels/parsnip#937
1 parent d16c76a commit 972ccc0

File tree

2 files changed

+12
-22
lines changed

2 files changed

+12
-22
lines changed

R/grid_code_paths.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ tune_grid_loop_impl <- function(fn_tune_grid_loop_iter,
180180
metrics = metrics,
181181
control = control,
182182
eval_time = eval_time,
183-
seed = seed
183+
seed = seed,
184184
metrics_info = metrics_info,
185185
params = params
186186
)
@@ -221,7 +221,7 @@ tune_grid_loop_impl <- function(fn_tune_grid_loop_iter,
221221
metrics = metrics,
222222
control = control,
223223
eval_time = eval_time,
224-
seed = seed
224+
seed = seed,
225225
metrics_info = metrics_info,
226226
params = params
227227
)
@@ -421,7 +421,7 @@ tune_grid_loop_iter <- function(split,
421421
iter_msg_predictions <- paste(iter_msg_model, "(predictions)")
422422

423423
iter_predictions <- .catch_and_log(
424-
predict_model(split, workflow, iter_grid, metrics, iter_submodels,
424+
predict_model(split, workflow, iter_grid, metrics, iter_submodels,
425425
metrics_info = metrics_info, eval_time = eval_time),
426426
control,
427427
split,
@@ -480,7 +480,7 @@ tune_grid_loop_iter_safely <- function(fn_tune_grid_loop_iter,
480480
seed,
481481
metrics_info,
482482
params) {
483-
483+
484484
fn_tune_grid_loop_iter_wrapper <- super_safely(fn_tune_grid_loop_iter)
485485

486486
# Likely want to debug with `debugonce(tune_grid_loop_iter)`

R/grid_helpers.R

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11

2-
predict_model <- function(split, workflow, grid, metrics, submodels = NULL,
2+
predict_model <- function(split, workflow, grid, metrics, submodels = NULL,
33
metrics_info, eval_time = NULL) {
44

55
model <- extract_fit_parsnip(workflow)
@@ -102,31 +102,21 @@ predict_model <- function(split, workflow, grid, metrics, submodels = NULL,
102102
res <- dplyr::full_join(res, case_weights, by = ".row")
103103
}
104104
}
105+
106+
107+
res <- maybe_add_ipcw(res, model, types)
108+
105109
if (!tibble::is_tibble(res)) {
106110
res <- tibble::as_tibble(res)
107111
}
108-
# maybe_add_ipcw(res, model, eval_time, types)
109112
res
110113
}
111114

112-
# TODO do we need this?
113-
# res <- tibble::as_tibble(res)
114-
115-
}
116-
117-
maybe_add_ipcw <- function(.data, model, eval_time, types) {
115+
maybe_add_ipcw <- function(.data, model, types) {
118116
if (!any(types == "survival")) {
119117
return(.data)
120118
}
121-
res <-
122-
tidyr::unnest(.data, cols = .pred) %>%
123-
dplyr::rename(eval_time = .time) %>%
124-
dplyr::full_join(
125-
# TODO is the outcome name enforced or the original name?
126-
parsnip::.censoring_weights_graf(model, .data, eval_time = eval_time),
127-
by = c(".row", "eval_time")
128-
)
129-
res
119+
parsnip::.censoring_weights_graf(model, .data)
130120
}
131121

132122
predict_wrapper <- function(model, new_data, type, eval_time, subgrid = NULL) {
@@ -147,7 +137,7 @@ predict_wrapper <- function(model, new_data, type, eval_time, subgrid = NULL) {
147137
# Add in censored regression evaluation times (if needed)
148138
has_type <- type %in% c("survival", "hazard")
149139
if (model$spec$mode == "censored regression" & !is.null(eval_time) & has_type) {
150-
cl <- rlang::call_modify(cl, time = eval_time)
140+
cl <- rlang::call_modify(cl, eval_time = eval_time)
151141
}
152142

153143
# When there are sub-models:

0 commit comments

Comments
 (0)