Skip to content

Commit 110257f

Browse files
authored
Merge pull request #445 from tidymodels/restrict-fit_xy
restrict `fit_xy()` from working for censored regression models
2 parents 0598922 + 4a8316f commit 110257f

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
* New mode "censored regression" and new prediction types "linear_pred", "time", "survival", "hazard". (#396)
2222

23+
* Censored regression models cannot use `fit_xy()` (use `fit()`). (#442)
2324

2425
# parsnip 0.1.4
2526

R/fit.R

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,14 @@ fit_xy.model_spec <-
192192
control = control_parsnip(),
193193
...
194194
) {
195+
if (object$mode == "censored regression") {
196+
rlang::abort("Models for censored regression must use the formula interface.")
197+
}
198+
199+
if (inherits(object, "surv_reg")) {
200+
rlang::abort("Survival models must use the formula interface.")
201+
}
202+
195203
if (!identical(class(control), class(control_parsnip()))) {
196204
rlang::abort("The 'control' argument should have class 'control_parsnip'.")
197205
}
@@ -388,10 +396,6 @@ check_xy_interface <- function(x, y, cl, model) {
388396

389397
df_interface <- !is.null(x) & !is.null(y) && is.data.frame(x)
390398

391-
if (inherits(model, "surv_reg") && (matrix_interface | df_interface)) {
392-
rlang::abort("Survival models must use the formula interface.")
393-
}
394-
395399
if (matrix_interface) {
396400
return("matrix")
397401
}

tests/testthat/test_fit_interfaces.R

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ hpc <- hpc_data[1:150, c(2:5, 8)]
88

99
f <- y ~ x
1010

11-
smod <- surv_reg()
1211
rmod <- linear_reg()
1312

1413
sprk <- 1:10
@@ -36,7 +35,6 @@ test_that('good args', {
3635
#
3736
test_that('wrong args', {
3837
expect_error(tester_xy(NULL, x = sprk, y = hpc, model = rmod))
39-
expect_error(tester_xy(NULL, x = hpc, y = hpc$compounds, model = smod))
4038
expect_error(tester(NULL, f, data = as.matrix(hpc[, 1:4])))
4139
})
4240

0 commit comments

Comments
 (0)