Skip to content

Commit 9dcfcf0

Browse files
authored
warn with >2 levels on logistic_reg() fit (#916)
1 parent 332de18 commit 9dcfcf0

File tree

4 files changed

+33
-1
lines changed

4 files changed

+33
-1
lines changed

NEWS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* Fixed bug with prediction from a boosted tree model fitted with `"xgboost"` using a custom objective function (#875).
1616

1717
* Several internal functions (to help work with `Surv` objects) were added as a standalone file that can be used in other packages via `usethis::use_standalone("tidymodels/parsnip")`.
18-
18+
* `logistic_reg()` will now warn at `fit()` when the outcome has more than two levels (#545).
1919

2020
# parsnip 1.0.4
2121

R/misc.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,17 @@ check_outcome <- function(y, spec) {
345345
if (!outcome_is_factor) {
346346
rlang::abort("For a classification model, the outcome should be a factor.")
347347
}
348+
349+
if (inherits(spec, "logistic_reg") && is.atomic(y) && length(levels(y)) > 2) {
350+
# warn rather than error since some engines handle this case by binning
351+
# all but the first level as the non-event, so this may be intended
352+
cli::cli_warn(c(
353+
"!" = "Logistic regression is intended for modeling binary outcomes, \\
354+
but there are {length(levels(y))} levels in the outcome.",
355+
"i" = "If this is unintended, adjust outcome levels accordingly or \\
356+
see the {.fn multinom_reg} function."
357+
))
358+
}
348359
}
349360

350361
if (spec$mode == "censored regression") {

tests/testthat/_snaps/logistic_reg.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,17 @@
1515
Computational engine: glmnet
1616
1717

18+
# bad input
19+
20+
Code
21+
res <- mtcars %>% dplyr::mutate(cyl = as.factor(cyl)) %>% fit(logistic_reg(),
22+
cyl ~ mpg, data = .)
23+
Condition
24+
Warning:
25+
! Logistic regression is intended for modeling binary outcomes, but there are 3 levels in the outcome.
26+
i If this is unintended, adjust outcome levels accordingly or see the `multinom_reg()` function.
27+
Warning:
28+
glm.fit: algorithm did not converge
29+
Warning:
30+
glm.fit: fitted probabilities numerically 0 or 1 occurred
31+

tests/testthat/test_logistic_reg.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@ test_that('bad input', {
1616
expect_error(translate(logistic_reg(x = hpc[,1:3], y = hpc$class) %>% set_engine(engine = "glmnet")))
1717
expect_error(translate(logistic_reg(formula = y ~ x) %>% set_engine(engine = "glm")))
1818
expect_error(translate(logistic_reg(mixture = 0.5) %>% set_engine(engine = "LiblineaR")))
19+
20+
expect_snapshot(
21+
res <-
22+
mtcars %>%
23+
dplyr::mutate(cyl = as.factor(cyl)) %>%
24+
fit(logistic_reg(), cyl ~ mpg, data = .)
25+
)
1926
})
2027

2128
# ------------------------------------------------------------------------------

0 commit comments

Comments
 (0)