Skip to content

Commit 82cc19a

Browse files
committed
updates based on tidymodels/parsnip#937
1 parent 972ccc0 commit 82cc19a

File tree

2 files changed

+10
-19
lines changed

2 files changed

+10
-19
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Imports:
2727
GPfit,
2828
hardhat (>= 1.2.0),
2929
lifecycle (>= 1.0.0),
30-
parsnip (>= 1.0.4.9004),
30+
parsnip (>= 1.0.4.9006),
3131
purrr (>= 1.0.0),
3232
recipes (>= 1.0.4),
3333
rlang (>= 1.0.6.9000),

R/grid_performance.R

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ metrics_info <- function(x) {
6161
# append_metrics(). <many times>
6262
# .estimate_metrics()
6363

64-
# predictions made in predict_model()
65-
64+
# predictions made in predict_model()
65+
6666
if (inherits(dat, "try-error")) {
6767
return(NULL)
6868
}
@@ -139,21 +139,12 @@ estimate_class_prob <- function(dat, metric, param_names, outcome_name,
139139
}
140140

141141
estimate_surv <- function(dat, metric, param_names, outcome_name, case_weights, types) {
142-
# TODO mixed sets?
143-
if (any(types == "survival")) {
144-
res <-
145-
dat %>%
146-
dplyr::group_by(!!!rlang::syms(param_names), eval_time) %>%
147-
metric(
148-
truth = surv,
149-
estimate = .pred_survival,
150-
censoring_weights = .weight_cens,
151-
case_weights = !!case_weights,
152-
eval_time = eval_time
153-
)
154-
} else {
155-
# pad with .time = NA?
156-
}
157-
res
142+
dat %>%
143+
dplyr::group_by(!!!rlang::syms(param_names)) %>%
144+
metric(
145+
truth = surv,
146+
estimate = .pred,
147+
case_weights = !!case_weights
148+
)
158149
}
159150

0 commit comments

Comments
 (0)