Skip to content

Commit 9c2bdbb

Browse files
authored
merge pr #723: error on predict(type = "prob") with outcome level named "class"
2 parents 3871fd8 + 093470b commit 9c2bdbb

File tree

3 files changed

+41
-1
lines changed

3 files changed

+41
-1
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010

1111
* An inconsistency for probability type predictions for two-class GAM models was fixed (#708)
1212

13+
* `predict(type = "prob")` will now provide an error if the outcome variable has a level called `"class"` (#720).
14+
1315
# parsnip 0.2.1
1416

1517
* Fixed a major bug in spark models induced in the previous version (#671).

R/predict_classprob.R

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
99
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
1010

1111
check_spec_pred_type(object, "prob")
12-
12+
check_spec_levels(object)
1313

1414
if (inherits(object$fit, "try-error")) {
1515
rlang::warn("Model fit failed; cannot make predictions.")
@@ -48,3 +48,16 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
4848
# @inheritParams predict.model_fit
4949
predict_classprob <- function(object, ...)
5050
UseMethod("predict_classprob")
51+
52+
check_spec_levels <- function(spec) {
53+
if ("class" %in% spec$lvl) {
54+
rlang::abort(
55+
glue::glue(
56+
"The outcome variable `{spec$preproc$y_var}` has a level called 'class'. ",
57+
"This value is reserved for parsnip's classification internals; please ",
58+
"change the levels, perhaps with `forcats::fct_relevel()`."
59+
),
60+
call = NULL
61+
)
62+
}
63+
}

tests/testthat/test_predict_formats.R

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,31 @@ test_that('non-standard levels', {
5656
c("2low", "high+values"))
5757
})
5858

59+
test_that('predict(type = "prob") with level "class" (see #720)', {
60+
x <- tibble::tibble(
61+
boop = factor(sample(c("class", "class_1"), 100, replace = TRUE)),
62+
bop = rnorm(100),
63+
beep = rnorm(100)
64+
)
65+
66+
expect_error(
67+
regexp = NA,
68+
mod <- logistic_reg() %>%
69+
set_mode(mode = "classification") %>%
70+
fit(boop ~ bop + beep, data = x)
71+
)
72+
73+
expect_error(
74+
regexp = NA,
75+
predict(mod, type = "class", new_data = x)
76+
)
77+
78+
expect_error(
79+
regexp = "variable `boop` has a level called 'class'",
80+
predict(mod, type = "prob", new_data = x)
81+
)
82+
})
83+
5984

6085
test_that('non-factor classification', {
6186
skip_if(run_glmnet)

0 commit comments

Comments
 (0)