Skip to content

warn with >2 levels on logistic_reg() fit #916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
* Fixed bug with prediction from a boosted tree model fitted with `"xgboost"` using a custom objective function (#875).

* Several internal functions (to help work with `Surv` objects) were added as a standalone file that can be used in other packages via `usethis::use_standalone("tidymodels/parsnip")`.

* `logistic_reg()` will now warn at `fit()` when the outcome has more than two levels (#545).

# parsnip 1.0.4

Expand Down
11 changes: 11 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,17 @@ check_outcome <- function(y, spec) {
if (!outcome_is_factor) {
rlang::abort("For a classification model, the outcome should be a factor.")
}

if (inherits(spec, "logistic_reg") && is.atomic(y) && length(levels(y)) > 2) {
# warn rather than error since some engines handle this case by binning
# all but the first level as the non-event, so this may be intended
cli::cli_warn(c(
"!" = "Logistic regression is intended for modeling binary outcomes, \\
but there are {length(levels(y))} levels in the outcome.",
"i" = "If this is unintended, adjust outcome levels accordingly or \\
see the {.fn multinom_reg} function."
))
}
}

if (spec$mode == "censored regression") {
Expand Down
14 changes: 14 additions & 0 deletions tests/testthat/_snaps/logistic_reg.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,17 @@
Computational engine: glmnet


# bad input

Code
res <- mtcars %>% dplyr::mutate(cyl = as.factor(cyl)) %>% fit(logistic_reg(),
cyl ~ mpg, data = .)
Condition
Warning:
! Logistic regression is intended for modeling binary outcomes, but there are 3 levels in the outcome.
i If this is unintended, adjust outcome levels accordingly or see the `multinom_reg()` function.
Warning:
glm.fit: algorithm did not converge
Warning:
glm.fit: fitted probabilities numerically 0 or 1 occurred

7 changes: 7 additions & 0 deletions tests/testthat/test_logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ test_that('bad input', {
expect_error(translate(logistic_reg(x = hpc[,1:3], y = hpc$class) %>% set_engine(engine = "glmnet")))
expect_error(translate(logistic_reg(formula = y ~ x) %>% set_engine(engine = "glm")))
expect_error(translate(logistic_reg(mixture = 0.5) %>% set_engine(engine = "LiblineaR")))

expect_snapshot(
res <-
mtcars %>%
dplyr::mutate(cyl = as.factor(cyl)) %>%
fit(logistic_reg(), cyl ~ mpg, data = .)
)
})

# ------------------------------------------------------------------------------
Expand Down