Skip to content

Commit b6db676

Browse files
authored
Use check_outcome for all fit paths (#625)
* Use `check_outcome` for all fit paths * Update NEWS
1 parent b19f0e7 commit b6db676

File tree

3 files changed

+15
-17
lines changed

3 files changed

+15
-17
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
* The list column produced when creating survival probability predictions is now always called `.pred` (with `.pred_survival` being used inside of the list column).
1414

15+
* Fixed outcome type checking affecting a subset of regression models (#625).
16+
1517
## Other Changes
1618

1719
* When the xy interface is used and the underlying model expects to use a matrix, a better warning is issued when predictors contain non-numeric columns (including dates).

R/fit_helpers.R

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,10 @@
66
form_form <-
77
function(object, control, env, ...) {
88

9+
check_outcome(eval_tidy(env$formula[[2]], env$data), object)
10+
911
# prob rewrite this as simple subset/levels
1012
y_levels <- levels_from_formula(env$formula, env$data)
11-
12-
if (object$mode == "classification") {
13-
if (!inherits(env$data, "tbl_spark") && is.null(y_levels))
14-
rlang::abort("For a classification model, the outcome should be a factor.")
15-
} else if (object$mode == "regression") {
16-
if (!inherits(env$data, "tbl_spark") && !is.null(y_levels))
17-
rlang::abort("For a regression model, the outcome should be numeric.")
18-
}
19-
2013
object <- check_mode(object, y_levels)
2114

2215
# if descriptors are needed, update descr_env with the calculated values
@@ -150,14 +143,7 @@ form_xy <- function(object, control, env,
150143
env$x <- data_obj$x
151144
env$y <- data_obj$y
152145

153-
res <- list(lvl = levels_from_formula(env$formula, env$data), spec = object)
154-
if (object$mode == "classification") {
155-
if (is.null(res$lvl))
156-
rlang::abort("For a classification model, the outcome should be a factor.")
157-
} else if (object$mode == "regression") {
158-
if (!is.null(res$lvl))
159-
rlang::abort("For a regression model, the outcome should be numeric.")
160-
}
146+
check_outcome(env$y, object)
161147

162148
res <- xy_xy(
163149
object = object,

tests/testthat/test_linear_reg.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,16 @@ test_that('lm execution', {
247247
regexp = "For a regression model"
248248
)
249249

250+
expect_error(
251+
res <- fit_xy(
252+
hpc_basic,
253+
x = hpc[, num_pred],
254+
y = as.character(hpc$class),
255+
control = ctrl
256+
),
257+
regexp = "For a regression model"
258+
)
259+
250260
expect_error(
251261
res <- fit(
252262
hpc_basic,

0 commit comments

Comments
 (0)