Skip to content

Commit b1cb0bd

Browse files
authored
Check valid spec mode (#470)
* New function to check model_spec and mode compatibility * reformat to style guidelines * refactoring - new function check_spec_mode_val is a drop and replace for this chunk * include new function in `set_mode` * adding tests to confirm expected `set_modes` works with the base model_spec objects of parsnip. Includes at least one expect_error per model_spec. * small cleanup on null_model and surv_reg * adding description of changes to NEWS.md * commiting changes suggested by DavisVaughan
1 parent 639fe1a commit b1cb0bd

File tree

5 files changed

+26
-8
lines changed

5 files changed

+26
-8
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212

1313
* Re-licensed package from GPL-2 to MIT. See [consent from copyright holders here](https://github.com/tidymodels/parsnip/issues/462).
1414

15+
* `set_mode()` now checks if `mode` is compatible with the model class, similar to `new_model_spec()` (@jtlandis, #467).
16+
1517
* Re-organized model documentation for `update` methods (#479).
1618

19+
1720
# parsnip 0.1.5
1821

1922
* An RStudio add-in is available that makes writing multiple `parsnip` model specifications to the source window. It can be accessed via the IDE addin menus or by calling `parsnip_addin()`.

R/aaa_models.R

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,19 @@ check_mode_val <- function(mode) {
132132
invisible(NULL)
133133
}
134134

135+
# check if class and mode are compatible
136+
check_spec_mode_val <- function(cls, mode) {
137+
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
138+
if (!(mode %in% spec_modes))
139+
rlang::abort(
140+
glue::glue(
141+
"`mode` should be one of: ",
142+
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
143+
)
144+
)
145+
invisible(NULL)
146+
}
147+
135148
check_engine_val <- function(eng) {
136149
if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng))
137150
rlang::abort("Please supply a character string for an engine (e.g. `'lm'`).")

R/arguments.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ set_mode <- function(object, mode) {
8787
)
8888
)
8989
}
90+
check_spec_mode_val(class(object)[1], mode)
9091
object$mode <- mode
9192
object
9293
}

R/misc.R

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -191,14 +191,8 @@ update_dot_check <- function(...) {
191191
#' @keywords internal
192192
#' @rdname add_on_exports
193193
new_model_spec <- function(cls, args, eng_args, mode, method, engine) {
194-
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
195-
if (!(mode %in% spec_modes))
196-
rlang::abort(
197-
glue::glue(
198-
"`mode` should be one of: ",
199-
glue::glue_collapse(glue::glue("'{spec_modes}'"), sep = ", ")
200-
)
201-
)
194+
195+
check_spec_mode_val(cls, mode)
202196

203197
out <- list(args = args, eng_args = eng_args,
204198
mode = mode, method = method, engine = engine)

tests/testthat/test_args_and_modes.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,3 +45,10 @@ test_that('pipe engine', {
4545
expect_error(rand_forest() %>% set_mode(2))
4646
expect_error(rand_forest() %>% set_mode("haberdashery"))
4747
})
48+
49+
test_that("can't set a mode that isn't allowed by the model spec", {
50+
expect_error(
51+
set_mode(linear_reg(), "classification"),
52+
"`mode` should be one of"
53+
)
54+
})

0 commit comments

Comments
 (0)