Skip to content

Commit 3b52dd8

Browse files
committed
restrict fit_xy() from working for censored regression models or surv_reg models
1 parent 5865c53 commit 3b52dd8

File tree

3 files changed

+10
-6
lines changed

3 files changed

+10
-6
lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
* Column names for `x` are now required when `fit_xy()` is used. (#398)
1818

19+
* Censored regression models cannot use `fit_xy()` (use `fit()`). (#442)
20+
1921

2022
# parsnip 0.1.4
2123

R/fit.R

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,14 @@ fit_xy.model_spec <-
192192
control = control_parsnip(),
193193
...
194194
) {
195+
if (object$mode == "censored regression") {
196+
rlang::abort("Models for censored regression must use the formula interface.")
197+
}
198+
199+
if (inherits(object, "surv_reg")) {
200+
rlang::abort("Survival models must use the formula interface.")
201+
}
202+
195203
if (!identical(class(control), class(control_parsnip()))) {
196204
rlang::abort("The 'control' argument should have class 'control_parsnip'.")
197205
}
@@ -388,10 +396,6 @@ check_xy_interface <- function(x, y, cl, model) {
388396

389397
df_interface <- !is.null(x) & !is.null(y) && is.data.frame(x)
390398

391-
if (inherits(model, "surv_reg") && (matrix_interface | df_interface)) {
392-
rlang::abort("Survival models must use the formula interface.")
393-
}
394-
395399
if (matrix_interface) {
396400
return("matrix")
397401
}

tests/testthat/test_fit_interfaces.R

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ hpc <- hpc_data[1:150, c(2:5, 8)]
88

99
f <- y ~ x
1010

11-
smod <- surv_reg()
1211
rmod <- linear_reg()
1312

1413
sprk <- 1:10
@@ -36,7 +35,6 @@ test_that('good args', {
3635
#
3736
test_that('wrong args', {
3837
expect_error(tester_xy(NULL, x = sprk, y = hpc, model = rmod))
39-
expect_error(tester_xy(NULL, x = hpc, y = hpc$compounds, model = smod))
4038
expect_error(tester(NULL, f, data = as.matrix(hpc[, 1:4])))
4139
})
4240

0 commit comments

Comments
 (0)