Skip to content

Commit 3c88e50

Browse files
authored
Merge pull request #606 from tidymodels/tunable
move methods for `tunable()` from tune to parsnip
2 parents 3178f35 + 8ef2694 commit 3c88e50

File tree

3 files changed

+251
-1
lines changed

3 files changed

+251
-1
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 0.1.7.9002
3+
Version: 0.1.7.9003
44
Authors@R: c(
55
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "[email protected]", role = "aut"),

R/tunable.R

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Lazily registered in .onLoad()
2+
tunable_model_spec <- function(x, ...) {
3+
mod_env <- rlang::ns_env("parsnip")$parsnip
4+
5+
if (is.null(x$engine)) {
6+
stop("Please declare an engine first using `set_engine()`.", call. = FALSE)
7+
}
8+
9+
arg_name <- paste0(mod_type(x), "_args")
10+
if (!(any(arg_name == names(mod_env)))) {
11+
stop("The `parsnip` model database doesn't know about the arguments for ",
12+
"model `", mod_type(x), "`. Was it registered?",
13+
sep = "", call. = FALSE)
14+
}
15+
16+
arg_vals <-
17+
mod_env[[arg_name]] %>%
18+
dplyr::filter(engine == x$engine) %>%
19+
dplyr::select(name = parsnip, call_info = func) %>%
20+
dplyr::full_join(
21+
tibble::tibble(name = c(names(x$args), names(x$eng_args))),
22+
by = "name"
23+
) %>%
24+
dplyr::mutate(
25+
source = "model_spec",
26+
component = mod_type(x),
27+
component_id = dplyr::if_else(name %in% names(x$args), "main", "engine")
28+
)
29+
30+
if (nrow(arg_vals) > 0) {
31+
has_info <- purrr::map_lgl(arg_vals$call_info, is.null)
32+
rm_list <- !(has_info & (arg_vals$component_id == "main"))
33+
34+
arg_vals <- arg_vals[rm_list,]
35+
}
36+
arg_vals %>% dplyr::select(name, call_info, source, component, component_id)
37+
}
38+
39+
mod_type <- function(.mod) class(.mod)[class(.mod) != "model_spec"][1]
40+
41+
# ------------------------------------------------------------------------------
42+
43+
add_engine_parameters <- function(pset, engines) {
44+
is_engine_param <- pset$name %in% engines$name
45+
if (any(is_engine_param)) {
46+
engine_names <- pset$name[is_engine_param]
47+
pset <- pset[!is_engine_param,]
48+
pset <-
49+
dplyr::bind_rows(pset, engines %>% dplyr::filter(name %in% engines$name))
50+
}
51+
pset
52+
}
53+
54+
c5_tree_engine_args <-
55+
tibble::tibble(
56+
name = c(
57+
"CF",
58+
"noGlobalPruning",
59+
"winnow",
60+
"fuzzyThreshold",
61+
"bands"
62+
),
63+
call_info = list(
64+
list(pkg = "dials", fun = "confidence_factor"),
65+
list(pkg = "dials", fun = "no_global_pruning"),
66+
list(pkg = "dials", fun = "predictor_winnowing"),
67+
list(pkg = "dials", fun = "fuzzy_thresholding"),
68+
list(pkg = "dials", fun = "rule_bands")
69+
),
70+
source = "model_spec",
71+
component = "decision_tree",
72+
component_id = "engine"
73+
)
74+
75+
c5_boost_engine_args <- c5_tree_engine_args
76+
c5_boost_engine_args$component <- "boost_tree"
77+
78+
xgboost_engine_args <-
79+
tibble::tibble(
80+
name = c(
81+
"alpha",
82+
"lambda",
83+
"scale_pos_weight"
84+
),
85+
call_info = list(
86+
list(pkg = "dials", fun = "penalty_L1"),
87+
list(pkg = "dials", fun = "penalty_L2"),
88+
list(pkg = "dials", fun = "scale_pos_weight")
89+
),
90+
source = "model_spec",
91+
component = "boost_tree",
92+
component_id = "engine"
93+
)
94+
95+
ranger_engine_args <-
96+
tibble::tibble(
97+
name = c(
98+
"regularization.factor",
99+
"regularization.usedepth",
100+
"alpha",
101+
"minprop",
102+
"splitrule",
103+
"num.random.splits"
104+
),
105+
call_info = list(
106+
list(pkg = "dials", fun = "regularization_factor"),
107+
list(pkg = "dials", fun = "regularize_depth"),
108+
list(pkg = "dials", fun = "significance_threshold"),
109+
list(pkg = "dials", fun = "lower_quantile"),
110+
list(pkg = "dials", fun = "splitting_rule"),
111+
list(pkg = "dials", fun = "num_random_splits")
112+
),
113+
source = "model_spec",
114+
component = "rand_forest",
115+
component_id = "engine"
116+
)
117+
118+
randomForest_engine_args <-
119+
tibble::tibble(
120+
name = c("maxnodes"),
121+
call_info = list(
122+
list(pkg = "dials", fun = "max_nodes")
123+
),
124+
source = "model_spec",
125+
component = "rand_forest",
126+
component_id = "engine"
127+
)
128+
129+
earth_engine_args <-
130+
tibble::tibble(
131+
name = c("nk"),
132+
call_info = list(
133+
list(pkg = "dials", fun = "max_num_terms")
134+
),
135+
source = "model_spec",
136+
component = "mars",
137+
component_id = "engine"
138+
)
139+
140+
# ------------------------------------------------------------------------------
141+
142+
# Lazily registered in .onLoad()
143+
tunable_linear_reg <- function(x, ...) {
144+
res <- NextMethod()
145+
if (x$engine == "glmnet") {
146+
res$call_info[res$name == "mixture"] <-
147+
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
148+
}
149+
res
150+
}
151+
152+
# Lazily registered in .onLoad()
153+
tunable_logistic_reg <- function(x, ...) {
154+
res <- NextMethod()
155+
if (x$engine == "glmnet") {
156+
res$call_info[res$name == "mixture"] <-
157+
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
158+
}
159+
res
160+
}
161+
162+
# Lazily registered in .onLoad()
163+
tunable_multinomial_reg <- function(x, ...) {
164+
res <- NextMethod()
165+
if (x$engine == "glmnet") {
166+
res$call_info[res$name == "mixture"] <-
167+
list(list(pkg = "dials", fun = "mixture", range = c(0.05, 1.00)))
168+
}
169+
res
170+
}
171+
172+
# Lazily registered in .onLoad()
173+
tunable_boost_tree <- function(x, ...) {
174+
res <- NextMethod()
175+
if (x$engine == "xgboost") {
176+
res <- add_engine_parameters(res, xgboost_engine_args)
177+
res$call_info[res$name == "sample_size"] <-
178+
list(list(pkg = "dials", fun = "sample_prop"))
179+
} else {
180+
if (x$engine == "C5.0") {
181+
res <- add_engine_parameters(res, c5_boost_engine_args)
182+
res$call_info[res$name == "trees"] <-
183+
list(list(pkg = "dials", fun = "trees", range = c(1, 100)))
184+
res$call_info[res$name == "sample_size"] <-
185+
list(list(pkg = "dials", fun = "sample_prop"))
186+
}
187+
}
188+
res
189+
}
190+
191+
# Lazily registered in .onLoad()
192+
tunable_rand_forest <- function(x, ...) {
193+
res <- NextMethod()
194+
if (x$engine == "ranger") {
195+
res <- add_engine_parameters(res, ranger_engine_args)
196+
}
197+
if (x$engine == "randomForest") {
198+
res <- add_engine_parameters(res, randomForest_engine_args)
199+
}
200+
res
201+
}
202+
203+
# Lazily registered in .onLoad()
204+
tunable_mars <- function(x, ...) {
205+
res <- NextMethod()
206+
if (x$engine == "earth") {
207+
res <- add_engine_parameters(res, earth_engine_args)
208+
}
209+
res
210+
}
211+
212+
# Lazily registered in .onLoad()
213+
tunable_decision_tree <- function(x, ...) {
214+
res <- NextMethod()
215+
if (x$engine == "C5.0") {
216+
res <- add_engine_parameters(res, c5_tree_engine_args)
217+
}
218+
res
219+
}
220+
221+
# Lazily registered in .onLoad()
222+
tunable_svm_poly <- function(x, ...) {
223+
res <- NextMethod()
224+
if (x$engine == "kernlab") {
225+
res$call_info[res$name == "degree"] <-
226+
list(list(pkg = "dials", fun = "prod_degree", range = c(1L, 3L)))
227+
}
228+
res
229+
}

R/zzz.R

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,27 @@
2323
# `tune_args.model_spec()` moved from tune to parsnip
2424
vctrs::s3_register("generics::tune_args", "model_spec", tune_args_model_spec)
2525
}
26+
27+
# - If tune isn't installed, register the method (`packageVersion()` will error here)
28+
# - If tune >= 0.1.6.9002 is installed, register the method
29+
should_register_tunable_method <- tryCatch(
30+
expr = utils::packageVersion("tune") >= "0.1.6.9002",
31+
error = function(cnd) TRUE
32+
)
33+
34+
if (should_register_tunable_method) {
35+
# `tunable.model_spec()` and friends moved from tune to parsnip
36+
vctrs::s3_register("generics::tunable", "model_spec", tunable_model_spec)
37+
vctrs::s3_register("generics::tunable", "linear_reg", tunable_linear_reg)
38+
vctrs::s3_register("generics::tunable", "logistic_reg", tunable_logistic_reg)
39+
vctrs::s3_register("generics::tunable", "multinomial_reg", tunable_multinomial_reg)
40+
vctrs::s3_register("generics::tunable", "boost_tree", tunable_boost_tree)
41+
vctrs::s3_register("generics::tunable", "rand_forest", tunable_rand_forest)
42+
vctrs::s3_register("generics::tunable", "mars", tunable_mars)
43+
vctrs::s3_register("generics::tunable", "decision_tree", tunable_decision_tree)
44+
vctrs::s3_register("generics::tunable", "svm_poly", tunable_svm_poly)
45+
}
46+
2647
}
2748

2849

0 commit comments

Comments
 (0)