Skip to content

Use cli instead of rlang for abort and warn, Fixes #1141 #1154

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions R/predict_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,23 @@
#' @export predict_class.model_fit
#' @export
predict_class.model_fit <- function(object, new_data, ...) {
if (object$spec$mode != "classification")
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
if (object$spec$mode != "classification") {
cli::cli_abort("{.fun predict.model_fit} is for predicting factor outcomes.")
}

check_spec_pred_type(object, "class")

if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
cli::cli_warn("Model fit failed; cannot make predictions.")
return(NULL)
}

new_data <- prepare_data(object, new_data)

# preprocess data
if (!is.null(object$spec$method$pred$class$pre))
if (!is.null(object$spec$method$pred$class$pre)) {
new_data <- object$spec$method$pred$class$pre(new_data, object)
}

# create prediction call
pred_call <- make_pred_call(object$spec$method$pred$class)
Expand Down Expand Up @@ -56,6 +58,6 @@ predict_class.model_fit <- function(object, new_data, ...) {
# @keywords internal
# @rdname other_predict
# @inheritParams predict.model_fit
predict_class <- function(object, ...)
predict_class <- function(object, ...) {
UseMethod("predict_class")

}
35 changes: 20 additions & 15 deletions R/predict_classprob.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,24 @@
#' @export predict_classprob.model_fit
#' @export
predict_classprob.model_fit <- function(object, new_data, ...) {
if (object$spec$mode != "classification")
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
if (object$spec$mode != "classification") {
cli::cli_abort("{.fun predict.model_fit()} is for predicting factor outcomes.")
}

check_spec_pred_type(object, "prob")
check_spec_levels(object)

if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
cli::cli_warn("Model fit failed; cannot make predictions.")
return(NULL)
}

new_data <- prepare_data(object, new_data)

# preprocess data
if (!is.null(object$spec$method$pred$prob$pre))
if (!is.null(object$spec$method$pred$prob$pre)) {
new_data <- object$spec$method$pred$prob$pre(new_data, object)
}

# create prediction call
pred_call <- make_pred_call(object$spec$method$pred$prob)
Expand All @@ -33,11 +35,13 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
}

# check and sort names
if (!is.data.frame(res) & !inherits(res, "tbl_spark"))
rlang::abort("The was a problem with the probability predictions.")
if (!is.data.frame(res) & !inherits(res, "tbl_spark")) {
cli::cli_abort("The was a problem with the probability predictions.")
}

if (!is_tibble(res) & !inherits(res, "tbl_spark"))
if (!is_tibble(res) & !inherits(res, "tbl_spark")) {
res <- as_tibble(res)
}

res
}
Expand All @@ -46,18 +50,19 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
# @keywords internal
# @rdname other_predict
# @inheritParams predict.model_fit
predict_classprob <- function(object, ...)
predict_classprob <- function(object, ...) {
UseMethod("predict_classprob")
}

check_spec_levels <- function(spec) {
if ("class" %in% spec$lvl) {
rlang::abort(
glue::glue(
"The outcome variable `{spec$preproc$y_var}` has a level called 'class'. ",
"This value is reserved for parsnip's classification internals; please ",
"change the levels, perhaps with `forcats::fct_relevel()`."
),
call = NULL
cli::cli_abort(
c(
"The outcome variable {.var {spec$preproc$y_var}} has a level called {.val class}.",
"i" = "This value is reserved for parsnip's classification internals; please
change the levels, perhaps with {.fn forcats::fct_relevel}.",
call = NULL
)
)
}
}
26 changes: 17 additions & 9 deletions R/predict_numeric.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,35 @@
#' @export predict_numeric.model_fit
#' @export
predict_numeric.model_fit <- function(object, new_data, ...) {
if (object$spec$mode != "regression")
rlang::abort(glue::glue("`predict_numeric()` is for predicting numeric outcomes. ",
"Use `predict_class()` or `predict_classprob()` for ",
"classification models."))
if (object$spec$mode != "regression") {
cli::cli_abort(
c(
"{.fun predict_numeric} is for predicting numeric outcomes.",
"i" = "Use {.fun predict_class} or {.fun predict_classprob} for
classification models."
)
)
}

check_spec_pred_type(object, "numeric")

if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
cli::cli_warn("Model fit failed; cannot make predictions.")
return(NULL)
}

new_data <- prepare_data(object, new_data)

# preprocess data
if (!is.null(object$spec$method$pred$numeric$pre))
if (!is.null(object$spec$method$pred$numeric$pre)) {
new_data <- object$spec$method$pred$numeric$pre(new_data, object)
}

# create prediction call
pred_call <- make_pred_call(object$spec$method$pred$numeric)

res <- eval_tidy(pred_call)

# post-process the predictions
if (!is.null(object$spec$method$pred$numeric$post)) {
res <- object$spec$method$pred$numeric$post(res, object)
Expand All @@ -36,8 +42,9 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
if (is.vector(res)) {
res <- unname(res)
} else {
if (!inherits(res, "tbl_spark"))
if (!inherits(res, "tbl_spark")) {
res <- as.data.frame(res)
}
}
res
}
Expand All @@ -47,5 +54,6 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
#' @keywords internal
#' @rdname other_predict
#' @inheritParams predict_numeric.model_fit
predict_numeric <- function(object, ...)
predict_numeric <- function(object, ...) {
UseMethod("predict_numeric")
}
23 changes: 15 additions & 8 deletions R/predict_time.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,35 @@
#' @export predict_time.model_fit
#' @export
predict_time.model_fit <- function(object, new_data, ...) {
if (object$spec$mode != "censored regression")
rlang::abort(glue::glue("`predict_time()` is for predicting time outcomes. ",
"Use `predict_class()` or `predict_classprob()` for ",
"classification models."))
if (object$spec$mode != "censored regression") {
cli::cli_abort(
c(
"{.fun predict_time} is for predicting time outcomes.",
"i" = "Use {.fun predict_class} or {.fun predict_classprob} for
classification models."
)
)
}

check_spec_pred_type(object, "time")

if (inherits(object$fit, "try-error")) {
rlang::warn("Model fit failed; cannot make predictions.")
cli::cli_warn("Model fit failed; cannot make predictions.")
return(NULL)
}

new_data <- prepare_data(object, new_data)

# preprocess data
if (!is.null(object$spec$method$pred$time$pre))
if (!is.null(object$spec$method$pred$time$pre)) {
new_data <- object$spec$method$pred$time$pre(new_data, object)
}

# create prediction call
pred_call <- make_pred_call(object$spec$method$pred$time)

res <- eval_tidy(pred_call)

# post-process the predictions
if (!is.null(object$spec$method$pred$time$post)) {
res <- object$spec$method$pred$time$post(res, object)
Expand All @@ -45,5 +51,6 @@ predict_time.model_fit <- function(object, new_data, ...) {
#' @keywords internal
#' @rdname other_predict
#' @inheritParams predict_time.model_fit
predict_time <- function(object, ...)
predict_time <- function(object, ...) {
UseMethod("predict_time")
}
2 changes: 1 addition & 1 deletion tests/testthat/test-predict_formats.R
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ test_that('predict(type = "prob") with level "class" (see #720)', {
)

expect_error(
regexp = "variable `boop` has a level called 'class'",
regexp = 'variable `boop` has a level called "class"',
predict(mod, type = "prob", new_data = x)
)
})
Expand Down
Loading