Skip to content

updated rlang messages in the predict functions to use cli. #1148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 28, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 73 additions & 48 deletions R/predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@
#' @export
predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...) {
if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
cli::cli_warn("Model fit failed; cannot make predictions.")
return(NULL)
}

Expand All @@ -156,7 +156,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)

type <- check_pred_type(object, type)
if (type != "raw" && length(opts) > 0) {
rlang::warn("`opts` is only used with `type = 'raw'` and was ignored.")
cli::cli_warn("{.arg opts} is only used with `type = 'raw'` and was ignored.")
}
check_pred_type_dots(object, type, ...)

Expand All @@ -173,7 +173,7 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
linear_pred = predict_linear_pred(object = object, new_data = new_data, ...),
hazard = predict_hazard(object = object, new_data = new_data, ...),
raw = predict_raw(object = object, new_data = new_data, opts = opts, ...),
rlang::abort(glue::glue("I don't know about type = '{type}'"))
cli::cli_abort("Unknown prediction {.arg type} '{type}'.")
)
if (!inherits(res, "tbl_spark")) {
res <- switch(
Expand All @@ -191,45 +191,69 @@ predict.model_fit <- function(object, new_data, type = NULL, opts = list(), ...)
res
}

check_pred_type <- function(object, type, ...) {
check_pred_type <- function(object, type, ..., call = rlang::caller_env()) {
if (is.null(type)) {
type <-
switch(object$spec$mode,
regression = "numeric",
classification = "class",
"censored regression" = "time",
rlang::abort("`type` should be 'regression', 'censored regression', or 'classification'."))
switch(
object$spec$mode,
regression = "numeric",
classification = "class",
"censored regression" = "time",
cli::cli_abort(
"{.arg type} should be 'regression', 'censored regression', or 'classification'.",
call = call
)
)
}
if (!(type %in% pred_types))
rlang::abort(
glue::glue(
"`type` should be one of: ",
glue_collapse(pred_types, sep = ", ", last = " and ")
)
cli::cli_abort(
"{.arg type} should be one of:{.arg {pred_types}}",
call = call
)

switch(
type,
"numeric" = if (object$spec$mode != "regression") {
rlang::abort("For numeric predictions, the object should be a regression model.")
cli::cli_abort(
"For numeric predictions, the object should be a regression model.",
call = call
)
},
"class" = if (object$spec$mode != "classification") {
rlang::abort("For class predictions, the object should be a classification model.")
cli::cli_abort(
"For class predictions, the object should be a classification model.",
call = call
)
},
"prob" = if (object$spec$mode != "classification") {
rlang::abort("For probability predictions, the object should be a classification model.")
cli::cli_abort(
"For probability predictions, the object should be a classification model.",
call = call
)
},
"time" = if (object$spec$mode != "censored regression") {
rlang::abort("For event time predictions, the object should be a censored regression.")
cli::cli_abort(
"For event time predictions, the object should be a censored regression.",
call = call
)
},
"survival" = if (object$spec$mode != "censored regression") {
rlang::abort("For survival probability predictions, the object should be a censored regression.")
cli::cli_abort(
"For survival probability predictions, the object should be a censored regression.",
call = call
)
},
"hazard" = if (object$spec$mode != "censored regression") {
rlang::abort("For hazard predictions, the object should be a censored regression.")
cli::cli_abort(
"For hazard predictions, the object should be a censored regression.",
call = call
)
},
"linear_pred" = if (object$spec$mode != "censored regression") {
rlang::abort("For the linear predictor, the object should be a censored regression.")
cli::cli_abort(
"For the linear predictor, the object should be a censored regression.",
call = call
)
}
)

Expand Down Expand Up @@ -349,56 +373,57 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())

other_args <- c("interval", "level", "std_error", "quantile",
"time", "eval_time", "increasing")

eval_time_types <- c("survival", "hazard")

is_pred_arg <- names(the_dots) %in% other_args
if (any(!is_pred_arg)) {
bad_args <- names(the_dots)[!is_pred_arg]
bad_args <- paste0("`", bad_args, "`", collapse = ", ")
rlang::abort(
glue::glue(
"The ellipses are not used to pass args to the model function's ",
"predict function. These arguments cannot be used: {bad_args}",
)
cli::cli_abort(
"The ellipses are not used to pass args to the model function's
predict function. These arguments cannot be used: {.val bad_args}",
call = call
)
}

# ----------------------------------------------------------------------------
# places where eval_time should not be given
if (any(nms == "eval_time") & !type %in% c("survival", "hazard")) {
rlang::abort(
paste(
"`eval_time` should only be passed to `predict()` when `type` is one of:",
paste0("'", c("survival", "hazard"), "'", collapse = ", ")
)
)
cli::cli_abort(
"{.arg eval_time} should only be passed to {.fn predict} when \\
{.arg type} is one of {.or {.val {eval_time_types}}}.",
call = call
)


}
if (any(nms == "time") & !type %in% c("survival", "hazard")) {
rlang::abort(
paste(
"'time' should only be passed to `predict()` when 'type' is one of:",
paste0("'", c("survival", "hazard"), "'", collapse = ", ")
)
cli::cli_abort(
"{.arg time} should only be passed to {.fn predict} when {.arg type} is
one of {.or {.val {eval_time_types}}}.",
call = call
)
}
# when eval_time should be passed
if (!any(nms %in% c("eval_time", "time")) & type %in% c("survival", "hazard")) {
rlang::abort(
paste(
"When using `type` values of 'survival' or 'hazard',",
"a numeric vector `eval_time` should also be given."
)
)
cli::cli_abort(
"When using {.arg type} values of {.or {.val {eval_time_types}}} a numeric
vector {.arg eval_time} should also be given.",
call = call
)
}

# `increasing` only applies to linear_pred for censored regression
if (any(nms == "increasing") &
!(type == "linear_pred" &
object$spec$mode == "censored regression")) {
rlang::abort(
paste(
"The 'increasing' argument only applies to predictions of",
"type 'linear_pred' for the mode censored regression."
)
cli::cli_abort(
"{.arg increasing} only applies to predictions of
type 'linear_pred' for the mode censored regression.",
call = call
)

}

invisible(TRUE)
Expand Down
Loading