Skip to content

Commit 6f5ed7e

Browse files
authored
Merge pull request #503 from tidymodels/engine-mode-errors
Harmonize errors for `set_mode()` and `set_engine()`
2 parents 46a2018 + 2ad7ec3 commit 6f5ed7e

15 files changed

+44
-42
lines changed

NEWS.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@
1313

1414
* The xgboost engine for boosted trees was translating `mtry` to xgboost's `colsample_bytree`. We now map `mtry` to `colsample_bynode` since that is more consistent with how random forest works. `colsample_bytree` can still be optimized by passing it in as an engine argument. `colsample_bynode` was added to xgboost after the `parsnip` package code was written. (#495)
1515

16-
* For xgboost boosting, `mtry` and `colsample_bytree` can be passed as integer counts or proportions while `subsample` and `validation` should be proportions. `xgb_train()` now has a new option `counts` for state what scale `mtry` and `colsample_bytree` are being used. (#461)
16+
* For xgboost, `mtry` and `colsample_bytree` can be passed as integer counts or proportions, while `subsample` and `validation` should always be proportions. `xgb_train()` now has a new option `counts` (`TRUE` or `FALSE`) that states which scale for `mtry` and `colsample_bytree` is being used. (#461)
1717

1818
## Other Changes
1919

2020
* Re-licensed package from GPL-2 to MIT. See [consent from copyright holders here](https://github.com/tidymodels/parsnip/issues/462).
2121

22-
* `set_mode()` now checks if `mode` is compatible with the model class, similar to `new_model_spec()` (@jtlandis, #467).
22+
* `set_mode()` now checks if `mode` is compatible with the model class, similar to `new_model_spec()` (@jtlandis, #467). Both `set_mode()` and `set_engine()` now error for `NULL` or missing arguments (#503).
2323

2424
* Re-organized model documentation for `update` methods (#479).
2525

R/aaa_models.R

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,21 +133,23 @@ check_mode_val <- function(mode) {
133133
invisible(NULL)
134134
}
135135

136+
137+
stop_incompatible_mode <- function(spec_modes) {
138+
msg <- glue::glue(
139+
"Available modes are: ",
140+
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
141+
)
142+
rlang::abort(msg)
143+
}
144+
136145
# check if class and mode are compatible
137146
check_spec_mode_val <- function(cls, mode) {
138147
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
139-
compatible_modes <-
140-
glue::glue(
141-
"`mode` should be one of: ",
142-
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
143-
)
144-
145-
if (is.null(mode)) {
146-
rlang::abort(compatible_modes)
148+
if (is.null(mode) || length(mode) > 1) {
149+
stop_incompatible_mode(spec_modes)
147150
} else if (!(mode %in% spec_modes)) {
148-
rlang::abort(compatible_modes)
151+
stop_incompatible_mode(spec_modes)
149152
}
150-
151153
invisible(NULL)
152154
}
153155

R/arguments.R

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,12 @@ set_args <- function(object, ...) {
7676
#' @rdname set_args
7777
#' @export
7878
set_mode <- function(object, mode) {
79+
cls <- class(object)[1]
7980
if (rlang::is_missing(mode)) {
80-
mode <- NULL
81+
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
82+
stop_incompatible_mode(spec_modes)
8183
}
82-
mode <- mode[1]
83-
check_spec_mode_val(class(object)[1], mode)
84+
check_spec_mode_val(cls, mode)
8485
object$mode <- mode
8586
object
8687
}

R/engines.R

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,21 @@ possible_engines <- function(object, ...) {
1010
unique(engs$engine)
1111
}
1212

13+
stop_incompatible_engine <- function(avail_eng) {
14+
msg <- glue::glue(
15+
"Available engines are: ",
16+
glue::glue_collapse(glue::glue("'{avail_eng}'"), sep = ", ")
17+
)
18+
rlang::abort(msg)
19+
}
20+
1321
check_engine <- function(object) {
1422
avail_eng <- possible_engines(object)
15-
if (is.null(object$engine)) {
16-
object$engine <- avail_eng[1]
17-
rlang::warn(glue::glue("`engine` was NULL and updated to be `{object$engine}`"))
18-
} else {
19-
if (!is.character(object$engine) | length(object$engine) != 1) {
20-
rlang::abort("`engine` should be a single character value.")
21-
}
22-
}
23-
if (!(object$engine %in% avail_eng)) {
24-
rlang::abort(
25-
glue::glue(
26-
"Engine '{object$engine}' is not available. Please use one of: ",
27-
glue::glue_collapse(glue::glue("'{avail_eng}'"), sep = ", ")
28-
)
29-
)
23+
eng <- object$engine
24+
if (is.null(eng) || length(eng) > 1) {
25+
stop_incompatible_engine(avail_eng)
26+
} else if (!(eng %in% avail_eng)) {
27+
stop_incompatible_engine(avail_eng)
3028
}
3129
object
3230
}
@@ -97,7 +95,8 @@ set_engine <- function(object, engine, ...) {
9795
}
9896

9997
if (rlang::is_missing(engine)) {
100-
engine <- NULL
98+
avail_eng <- possible_engines(object)
99+
stop_incompatible_engine(avail_eng)
101100
}
102101
object$engine <- engine
103102
object <- check_engine(object)

tests/testthat/test_args_and_modes.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,6 @@ test_that('pipe engine', {
4949
test_that("can't set a mode that isn't allowed by the model spec", {
5050
expect_error(
5151
set_mode(linear_reg(), "classification"),
52-
"`mode` should be one of"
52+
"Available modes are:"
5353
)
5454
})

tests/testthat/test_mars.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ test_that('updating', {
111111
})
112112

113113
test_that('bad input', {
114-
expect_warning(translate(mars(mode = "regression") %>% set_engine()))
114+
expect_error(translate(mars(mode = "regression") %>% set_engine()))
115115
expect_error(translate(mars() %>% set_engine("wat?")))
116116
expect_error(translate(mars(formula = y ~ x)))
117117
})

tests/testthat/test_multinom_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,6 @@ test_that('updating', {
123123
test_that('bad input', {
124124
expect_error(multinom_reg(mode = "regression"))
125125
expect_error(translate(multinom_reg(penalty = 0.1) %>% set_engine("wat?")))
126-
expect_warning(multinom_reg(penalty = 0.1) %>% set_engine())
126+
expect_error(multinom_reg(penalty = 0.1) %>% set_engine())
127127
expect_warning(translate(multinom_reg(penalty = 0.1) %>% set_engine("glmnet", x = hpc[,1:3], y = hpc$class)))
128128
})

tests/testthat/test_nearest_neighbor.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,5 +122,5 @@ test_that('updating', {
122122

123123
test_that('bad input', {
124124
expect_error(nearest_neighbor(mode = "reallyunknown"))
125-
expect_warning(nearest_neighbor() %>% set_engine( NULL))
125+
expect_error(nearest_neighbor() %>% set_engine( NULL))
126126
})

tests/testthat/test_nullmodel.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ test_that('engine arguments', {
3232
})
3333

3434
test_that('bad input', {
35-
expect_warning(translate(null_model(mode = "regression") %>% set_engine()))
35+
expect_error(translate(null_model(mode = "regression") %>% set_engine()))
3636
expect_error(translate(null_model() %>% set_engine("wat?")))
3737
expect_error(translate(null_model(formula = y ~ x)))
3838
expect_warning(

tests/testthat/test_rand_forest.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ test_that('updating', {
192192
})
193193

194194
test_that('bad input', {
195-
expect_warning(translate(rand_forest(mode = "classification") %>% set_engine(NULL)))
195+
expect_error(translate(rand_forest(mode = "classification") %>% set_engine(NULL)))
196196
expect_error(rand_forest(mode = "time series"))
197197
expect_error(translate(rand_forest(mode = "classification") %>% set_engine("wat?")))
198198
expect_error(translate(rand_forest(mode = "classification", ytest = 2)))

tests/testthat/test_surv_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ test_that('bad input', {
8585

8686
expect_error(surv_reg(mode = ", classification"))
8787
expect_error(translate(surv_reg() %>% set_engine("wat")))
88-
expect_warning(translate(surv_reg() %>% set_engine(NULL)))
88+
expect_error(translate(surv_reg() %>% set_engine(NULL)))
8989
})
9090

9191
test_that("deprecation warning", {

tests/testthat/test_svm_linear.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ test_that('updating', {
104104
})
105105

106106
test_that('bad input', {
107-
expect_warning(translate(svm_linear(mode = "regression") %>% set_engine( NULL)))
107+
expect_error(translate(svm_linear(mode = "regression") %>% set_engine( NULL)))
108108
expect_error(svm_linear(mode = "reallyunknown"))
109109
expect_error(translate(svm_linear(mode = "regression") %>% set_engine("LiblineaR", type = 3)))
110110
expect_error(translate(svm_linear(mode = "classification") %>% set_engine("LiblineaR", type = 11)))

tests/testthat/test_svm_liquidsvm.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,5 +77,5 @@ test_that('updating', {
7777

7878
test_that('bad input', {
7979
expect_error(svm_rbf(mode = "reallyunknown"))
80-
expect_warning(svm_rbf() %>% set_engine( NULL))
80+
expect_error(svm_rbf() %>% set_engine( NULL))
8181
})

tests/testthat/test_svm_poly.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ test_that('updating', {
106106

107107
test_that('bad input', {
108108
expect_error(svm_poly(mode = "reallyunknown"))
109-
expect_warning(svm_poly() %>% set_engine(NULL))
109+
expect_error(svm_poly() %>% set_engine(NULL))
110110
})
111111

112112
# ------------------------------------------------------------------------------

tests/testthat/test_svm_rbf.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ test_that('updating', {
8787

8888
test_that('bad input', {
8989
expect_error(svm_rbf(mode = "reallyunknown"))
90-
expect_warning(translate(svm_rbf(mode = "regression") %>% set_engine( NULL)))
90+
expect_error(translate(svm_rbf(mode = "regression") %>% set_engine( NULL)))
9191
})
9292

9393
# ------------------------------------------------------------------------------

0 commit comments

Comments
 (0)