Skip to content

Commit a546656

Browse files
authored
Merge pull request #488 from tidymodels/missing-engine-mode
Improve errors for `set_mode()` and `set_engine()`
2 parents fc21c9e + ba99fb2 commit a546656

13 files changed

+40
-35
lines changed

R/aaa_models.R

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,21 +127,27 @@ check_model_doesnt_exist <- function(model) {
127127
}
128128

129129
check_mode_val <- function(mode) {
130-
if (rlang::is_missing(mode) || length(mode) != 1 || !is.character(mode))
130+
if (rlang::is_missing(mode) || length(mode) != 1 || !is.character(mode)) {
131131
rlang::abort("Please supply a character string for a mode (e.g. `'regression'`).")
132+
}
132133
invisible(NULL)
133134
}
134135

135136
# check if class and mode are compatible
136137
check_spec_mode_val <- function(cls, mode) {
137138
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
138-
if (!(mode %in% spec_modes))
139-
rlang::abort(
140-
glue::glue(
141-
"`mode` should be one of: ",
142-
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
143-
)
139+
compatible_modes <-
140+
glue::glue(
141+
"`mode` should be one of: ",
142+
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
144143
)
144+
145+
if (is.null(mode)) {
146+
rlang::abort(compatible_modes)
147+
} else if (!(mode %in% spec_modes)) {
148+
rlang::abort(compatible_modes)
149+
}
150+
145151
invisible(NULL)
146152
}
147153

R/arguments.R

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -76,17 +76,10 @@ set_args <- function(object, ...) {
7676
#' @rdname set_args
7777
#' @export
7878
set_mode <- function(object, mode) {
79-
if (is.null(mode))
80-
return(object)
81-
mode <- mode[1]
82-
if (!(any(all_modes == mode))) {
83-
rlang::abort(
84-
glue::glue(
85-
"`mode` should be one of ",
86-
glue::glue_collapse(glue::glue("'{all_modes}'"), sep = ", ")
87-
)
88-
)
79+
if (rlang::is_missing(mode)) {
80+
mode <- NULL
8981
}
82+
mode <- mode[1]
9083
check_spec_mode_val(class(object)[1], mode)
9184
object$mode <- mode
9285
object

R/engines.R

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ check_engine <- function(object) {
1515
if (is.null(object$engine)) {
1616
object$engine <- avail_eng[1]
1717
rlang::warn(glue::glue("`engine` was NULL and updated to be `{object$engine}`"))
18+
} else {
19+
if (!is.character(object$engine) | length(object$engine) != 1) {
20+
rlang::abort("`engine` should be a single character value.")
21+
}
1822
}
1923
if (!(object$engine %in% avail_eng)) {
2024
rlang::abort(
@@ -91,18 +95,20 @@ set_engine <- function(object, engine, ...) {
9195
if (!inherits(object, "model_spec")) {
9296
rlang::abort("`object` should have class 'model_spec'.")
9397
}
94-
if (!is.character(engine) | length(engine) != 1)
95-
rlang::abort("`engine` should be a single character value.")
96-
if (engine == "liquidSVM") {
98+
99+
if (rlang::is_missing(engine)) {
100+
engine <- NULL
101+
}
102+
object$engine <- engine
103+
object <- check_engine(object)
104+
105+
if (object$engine == "liquidSVM") {
97106
lifecycle::deprecate_soft(
98107
"0.1.6",
99108
"set_engine(engine = 'cannot be liquidSVM')",
100109
details = "The liquidSVM package is no longer available on CRAN.")
101110
}
102111

103-
object$engine <- engine
104-
object <- check_engine(object)
105-
106112
new_model_spec(
107113
cls = class(object)[1],
108114
args = object$args,

tests/testthat/test_mars.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,8 @@ test_that('updating', {
111111
})
112112

113113
test_that('bad input', {
114+
expect_warning(translate(mars(mode = "regression") %>% set_engine()))
114115
expect_error(translate(mars() %>% set_engine("wat?")))
115-
expect_error(translate(mars(mode = "regression") %>% set_engine()))
116116
expect_error(translate(mars(formula = y ~ x)))
117117
})
118118

tests/testthat/test_multinom_reg.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ test_that('updating', {
122122

123123
test_that('bad input', {
124124
expect_error(multinom_reg(mode = "regression"))
125-
expect_error(translate(multinom_reg() %>% set_engine("wat?")))
126-
expect_error(translate(multinom_reg() %>% set_engine()))
127-
expect_warning(translate(multinom_reg(penalty = 0.01) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class)))
125+
expect_error(translate(multinom_reg(penalty = 0.1) %>% set_engine("wat?")))
126+
expect_warning(multinom_reg(penalty = 0.1) %>% set_engine())
127+
expect_warning(translate(multinom_reg(penalty = 0.1) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class)))
128128
})

tests/testthat/test_nearest_neighbor.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,5 +122,5 @@ test_that('updating', {
122122

123123
test_that('bad input', {
124124
expect_error(nearest_neighbor(mode = "reallyunknown"))
125-
expect_error(translate(nearest_neighbor() %>% set_engine( NULL)))
125+
expect_warning(nearest_neighbor() %>% set_engine( NULL))
126126
})

tests/testthat/test_nullmodel.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ test_that('engine arguments', {
3232
})
3333

3434
test_that('bad input', {
35+
expect_warning(translate(null_model(mode = "regression") %>% set_engine()))
3536
expect_error(translate(null_model() %>% set_engine("wat?")))
36-
expect_error(translate(null_model(mode = "regression") %>% set_engine()))
3737
expect_error(translate(null_model(formula = y ~ x)))
3838
expect_warning(
3939
translate(

tests/testthat/test_rand_forest.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,9 +192,9 @@ test_that('updating', {
192192
})
193193

194194
test_that('bad input', {
195+
expect_warning(translate(rand_forest(mode = "classification") %>% set_engine(NULL)))
195196
expect_error(rand_forest(mode = "time series"))
196197
expect_error(translate(rand_forest(mode = "classification") %>% set_engine("wat?")))
197-
expect_error(translate(rand_forest(mode = "classification") %>% set_engine(NULL)))
198198
expect_error(translate(rand_forest(mode = "classification", ytest = 2)))
199199
})
200200

tests/testthat/test_surv_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ test_that('bad input', {
8585

8686
expect_error(surv_reg(mode = ", classification"))
8787
expect_error(translate(surv_reg() %>% set_engine("wat")))
88-
expect_error(translate(surv_reg() %>% set_engine(NULL)))
88+
expect_warning(translate(surv_reg() %>% set_engine(NULL)))
8989
})
9090

9191
test_that("deprecation warning", {

tests/testthat/test_svm_linear.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,8 @@ test_that('updating', {
104104
})
105105

106106
test_that('bad input', {
107+
expect_warning(translate(svm_linear(mode = "regression") %>% set_engine( NULL)))
107108
expect_error(svm_linear(mode = "reallyunknown"))
108-
expect_error(translate(svm_linear(mode = "regression") %>% set_engine( NULL)))
109109
expect_error(translate(svm_linear(mode = "regression") %>% set_engine("LiblineaR", type = 3)))
110110
expect_error(translate(svm_linear(mode = "classification") %>% set_engine("LiblineaR", type = 11)))
111111
})

tests/testthat/test_svm_liquidsvm.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,5 @@ test_that('updating', {
7777

7878
test_that('bad input', {
7979
expect_error(svm_rbf(mode = "reallyunknown"))
80-
expect_error(translate(svm_rbf() %>% set_engine( NULL)))
80+
expect_warning(svm_rbf() %>% set_engine( NULL))
8181
})

tests/testthat/test_svm_poly.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ test_that('updating', {
106106

107107
test_that('bad input', {
108108
expect_error(svm_poly(mode = "reallyunknown"))
109-
expect_error(translate(svm_poly() %>% set_engine( NULL)))
109+
expect_warning(svm_poly() %>% set_engine(NULL))
110110
})
111111

112112
# ------------------------------------------------------------------------------

tests/testthat/test_svm_rbf.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ test_that('updating', {
8787

8888
test_that('bad input', {
8989
expect_error(svm_rbf(mode = "reallyunknown"))
90-
expect_error(translate(svm_rbf(mode = "regression") %>% set_engine( NULL)))
90+
expect_warning(translate(svm_rbf(mode = "regression") %>% set_engine( NULL)))
9191
})
9292

9393
# ------------------------------------------------------------------------------

0 commit comments

Comments
 (0)