Skip to content

Commit 915e70e

Browse files
authored
Be more verbose about prediction types in check_pred_type() (#899)
* be explicit about which survival-related type * style * `switch()` for speed
1 parent 2bbb111 commit 915e70e

File tree

1 file changed

+25
-10
lines changed

1 file changed

+25
-10
lines changed

R/predict.R

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,6 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
179179
res
180180
}
181181

182-
surv_types <- c("time", "survival", "hazard")
183-
184182
check_pred_type <- function(object, type, ...) {
185183
if (is.null(type)) {
186184
type <-
@@ -197,14 +195,31 @@ check_pred_type <- function(object, type, ...) {
197195
glue_collapse(pred_types, sep = ", ", last = " and ")
198196
)
199197
)
200-
if (type == "numeric" & object$spec$mode != "regression")
201-
rlang::abort("For numeric predictions, the object should be a regression model.")
202-
if (type == "class" & object$spec$mode != "classification")
203-
rlang::abort("For class predictions, the object should be a classification model.")
204-
if (type == "prob" & object$spec$mode != "classification")
205-
rlang::abort("For probability predictions, the object should be a classification model.")
206-
if (type %in% surv_types & object$spec$mode != "censored regression")
207-
rlang::abort("For event time predictions, the object should be a censored regression.")
198+
199+
switch(
200+
type,
201+
"numeric" = if (object$spec$mode != "regression") {
202+
rlang::abort("For numeric predictions, the object should be a regression model.")
203+
},
204+
"class" = if (object$spec$mode != "classification") {
205+
rlang::abort("For class predictions, the object should be a classification model.")
206+
},
207+
"prob" = if (object$spec$mode != "classification") {
208+
rlang::abort("For probability predictions, the object should be a classification model.")
209+
},
210+
"time" = if (object$spec$mode != "censored regression") {
211+
rlang::abort("For event time predictions, the object should be a censored regression.")
212+
},
213+
"survival" = if (object$spec$mode != "censored regression") {
214+
rlang::abort("For survival probability predictions, the object should be a censored regression.")
215+
},
216+
"hazard" = if (object$spec$mode != "censored regression") {
217+
rlang::abort("For hazard predictions, the object should be a censored regression.")
218+
},
219+
"linear_pred" = if (object$spec$mode != "censored regression") {
220+
rlang::abort("For the linear predictor, the object should be a censored regression.")
221+
}
222+
)
208223

209224
# TODO check for ... options when not the correct type
210225
type

0 commit comments

Comments
 (0)