Skip to content

Commit 146bd6b

Browse files
shum461simonpcouch
andauthored
transition to cli from rlang in predict() source (#1148)
------ Co-authored-by: Simon P. Couch <[email protected]>
1 parent d68b765 commit 146bd6b

File tree

1 file changed

+73
-48
lines changed

1 file changed

+73
-48
lines changed

R/predict.R

Lines changed: 73 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@
147147
#' @export
148148
predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) {
149149
if (inherits(object$fit, "try-error")) {
150-
rlang::warn("Model fit failed; cannot make predictions.")
150+
cli::cli_warn("Model fit failed; cannot make predictions.")
151151
return(NULL)
152152
}
153153

@@ -156,7 +156,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
156156

157157
type <- check_pred_type(object, type)
158158
if (type != "raw" && length(opts) > 0) {
159-
rlang::warn("`opts` is only used with `type = 'raw'` and was ignored.")
159+
cli::cli_warn("{.arg opts} is only used with `type = 'raw'` and was ignored.")
160160
}
161161
check_pred_type_dots(object, type, ...)
162162

@@ -173,7 +173,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
173173
linear_pred = predict_linear_pred(object = object, new_data = new_data, ...),
174174
hazard = predict_hazard(object = object, new_data = new_data, ...),
175175
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
176-
rlang::abort(glue::glue("I don't know about type = '{type}'"))
176+
cli::cli_abort("Unknown prediction {.arg type} '{type}'.")
177177
)
178178
if (!inherits(res, "tbl_spark")) {
179179
res <- switch(
@@ -191,45 +191,69 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
191191
res
192192
}
193193

194-
check_pred_type <- function(object, type, ...) {
194+
check_pred_type <- function(object, type, ..., call = rlang::caller_env()) {
195195
if (is.null(type)) {
196196
type <-
197-
switch(object$spec$mode,
198-
regression = "numeric",
199-
classification = "class",
200-
"censored regression" = "time",
201-
rlang::abort("`type` should be 'regression', 'censored regression', or 'classification'."))
197+
switch(
198+
object$spec$mode,
199+
regression = "numeric",
200+
classification = "class",
201+
"censored regression" = "time",
202+
cli::cli_abort(
203+
"{.arg type} should be 'regression', 'censored regression', or 'classification'.",
204+
call = call
205+
)
206+
)
202207
}
203208
if (!(type %in% pred_types))
204-
rlang::abort(
205-
glue::glue(
206-
"`type` should be one of: ",
207-
glue_collapse(pred_types, sep = ", ", last = " and ")
208-
)
209+
cli::cli_abort(
210+
"{.arg type} should be one of:{.arg {pred_types}}",
211+
call = call
209212
)
210213

211214
switch(
212215
type,
213216
"numeric" = if (object$spec$mode != "regression") {
214-
rlang::abort("For numeric predictions, the object should be a regression model.")
217+
cli::cli_abort(
218+
"For numeric predictions, the object should be a regression model.",
219+
call = call
220+
)
215221
},
216222
"class" = if (object$spec$mode != "classification") {
217-
rlang::abort("For class predictions, the object should be a classification model.")
223+
cli::cli_abort(
224+
"For class predictions, the object should be a classification model.",
225+
call = call
226+
)
218227
},
219228
"prob" = if (object$spec$mode != "classification") {
220-
rlang::abort("For probability predictions, the object should be a classification model.")
229+
cli::cli_abort(
230+
"For probability predictions, the object should be a classification model.",
231+
call = call
232+
)
221233
},
222234
"time" = if (object$spec$mode != "censored regression") {
223-
rlang::abort("For event time predictions, the object should be a censored regression.")
235+
cli::cli_abort(
236+
"For event time predictions, the object should be a censored regression.",
237+
call = call
238+
)
224239
},
225240
"survival" = if (object$spec$mode != "censored regression") {
226-
rlang::abort("For survival probability predictions, the object should be a censored regression.")
241+
cli::cli_abort(
242+
"For survival probability predictions, the object should be a censored regression.",
243+
call = call
244+
)
227245
},
228246
"hazard" = if (object$spec$mode != "censored regression") {
229-
rlang::abort("For hazard predictions, the object should be a censored regression.")
247+
cli::cli_abort(
248+
"For hazard predictions, the object should be a censored regression.",
249+
call = call
250+
)
230251
},
231252
"linear_pred" = if (object$spec$mode != "censored regression") {
232-
rlang::abort("For the linear predictor, the object should be a censored regression.")
253+
cli::cli_abort(
254+
"For the linear predictor, the object should be a censored regression.",
255+
call = call
256+
)
233257
}
234258
)
235259

@@ -349,56 +373,57 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())
349373

350374
other_args <- c("interval", "level", "std_error", "quantile",
351375
"time", "eval_time", "increasing")
376+
377+
eval_time_types <- c("survival", "hazard")
378+
352379
is_pred_arg <- names(the_dots) %in% other_args
353380
if (any(!is_pred_arg)) {
354381
bad_args <- names(the_dots)[!is_pred_arg]
355382
bad_args <- paste0("`", bad_args, "`", collapse = ", ")
356-
rlang::abort(
357-
glue::glue(
358-
"The ellipses are not used to pass args to the model function's ",
359-
"predict function. These arguments cannot be used: {bad_args}",
360-
)
383+
cli::cli_abort(
384+
"The ellipses are not used to pass args to the model function's
385+
predict function. These arguments cannot be used: {.val bad_args}",
386+
call = call
361387
)
362388
}
363389

364390
# ----------------------------------------------------------------------------
365391
# places where eval_time should not be given
366392
if (any(nms == "eval_time") & !type %in% c("survival", "hazard")) {
367-
rlang::abort(
368-
paste(
369-
"`eval_time` should only be passed to `predict()` when `type` is one of:",
370-
paste0("'", c("survival", "hazard"), "'", collapse = ", ")
371-
)
372-
)
393+
cli::cli_abort(
394+
"{.arg eval_time} should only be passed to {.fn predict} when \\
395+
{.arg type} is one of {.or {.val {eval_time_types}}}.",
396+
call = call
397+
)
398+
399+
373400
}
374401
if (any(nms == "time") & !type %in% c("survival", "hazard")) {
375-
rlang::abort(
376-
paste(
377-
"'time' should only be passed to `predict()` when 'type' is one of:",
378-
paste0("'", c("survival", "hazard"), "'", collapse = ", ")
379-
)
402+
cli::cli_abort(
403+
"{.arg time} should only be passed to {.fn predict} when {.arg type} is
404+
one of {.or {.val {eval_time_types}}}.",
405+
call = call
380406
)
381407
}
382408
# when eval_time should be passed
383409
if (!any(nms %in% c("eval_time", "time")) & type %in% c("survival", "hazard")) {
384-
rlang::abort(
385-
paste(
386-
"When using `type` values of 'survival' or 'hazard',",
387-
"a numeric vector `eval_time` should also be given."
388-
)
389-
)
410+
cli::cli_abort(
411+
"When using {.arg type} values of {.or {.val {eval_time_types}}} a numeric
412+
vector {.arg eval_time} should also be given.",
413+
call = call
414+
)
390415
}
391416

392417
# `increasing` only applies to linear_pred for censored regression
393418
if (any(nms == "increasing") &
394419
!(type == "linear_pred" &
395420
object$spec$mode == "censored regression")) {
396-
rlang::abort(
397-
paste(
398-
"The 'increasing' argument only applies to predictions of",
399-
"type 'linear_pred' for the mode censored regression."
400-
)
421+
cli::cli_abort(
422+
"{.arg increasing} only applies to predictions of
423+
type 'linear_pred' for the mode censored regression.",
424+
call = call
401425
)
426+
402427
}
403428

404429
invisible(TRUE)

0 commit comments

Comments
 (0)