Skip to content

Commit 1b6f8b9

Browse files
authored
Merge pull request #571 from tidymodels/method-tune_args
move method for `tune_args()` from tune to here
2 parents 51734c4 + e2771c3 commit 1b6f8b9

File tree

3 files changed

+190
-1
lines changed

3 files changed

+190
-1
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ Depends:
2828
R (>= 2.10)
2929
Imports:
3030
dplyr (>= 0.8.0.1),
31-
generics (>= 0.1.0),
31+
generics (>= 0.1.0.9000),
3232
globals,
3333
glue,
3434
hardhat (>= 0.1.5.9000),

R/tune_args.R

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
2+
# Lazily registered in .onLoad()
3+
tune_args_model_spec <- function(object, full = FALSE, ...) {
4+
5+
# use the model_spec top level class as the id
6+
model_type <- class(object)[1]
7+
8+
if (length(object$args) == 0L & length(object$eng_args) == 0L) {
9+
return(tune_tbl())
10+
}
11+
12+
# Locate tunable args in spec args and engine specific args
13+
object$args <- purrr::map(object$args, convert_args)
14+
object$eng_args <- purrr::map(object$eng_args, convert_args)
15+
16+
arg_id <- purrr::map_chr(object$args, find_tune_id)
17+
eng_arg_id <- purrr::map_chr(object$eng_args, find_tune_id)
18+
res <- c(arg_id, eng_arg_id)
19+
res <- ifelse(res == "", names(res), res)
20+
21+
tune_tbl(
22+
name = names(res),
23+
tunable = unname(!is.na(res)),
24+
id = res,
25+
source = "model_spec",
26+
component = model_type,
27+
component_id = NA_character_,
28+
full = full
29+
)
30+
}
31+
32+
33+
34+
# helpers for tune_args() methods -----------------------------------------
35+
# they also exist in recipes for the `tune_args()` methods there
36+
37+
38+
# If we map over a list or arguments and some are quosures, we get the message
39+
# that "Subsetting quosures with `[[` is deprecated as of rlang 0.4.0"
40+
41+
convert_args <- function(x) {
42+
if (rlang::is_quosure(x)) {
43+
x <- rlang::quo_get_expr(x)
44+
}
45+
x
46+
}
47+
48+
49+
# useful for standardization and for creating a 0 row tunable tbl
50+
# (i.e. for when there are no steps in a recipe)
51+
tune_tbl <- function(name = character(),
52+
tunable = logical(),
53+
id = character(),
54+
source = character(),
55+
component = character(),
56+
component_id = character(),
57+
full = FALSE) {
58+
59+
60+
complete_id <- id[!is.na(id)]
61+
dups <- duplicated(complete_id)
62+
if (any(dups)) {
63+
stop("There are duplicate `id` values listed in [tune()]: ",
64+
paste0("'", unique(complete_id[dups]), "'", collapse = ", "),
65+
".", sep = "", call. = FALSE)
66+
}
67+
68+
vry_tbl <- tibble::tibble(
69+
name = as.character(name),
70+
tunable = as.logical(tunable),
71+
id = as.character(id),
72+
source = as.character(source),
73+
component = as.character(component),
74+
component_id = as.character(component_id)
75+
)
76+
77+
if (!full) {
78+
vry_tbl <- vry_tbl[vry_tbl$tunable,]
79+
}
80+
81+
vry_tbl
82+
}
83+
84+
# Return the `id` arg in tune(); if not specified, then returns "" or if not
85+
# a tunable arg then returns NA_character_
86+
tune_id <- function(x) {
87+
if (is.null(x)) {
88+
return(NA_character_)
89+
} else {
90+
if (rlang::is_quosures(x)) {
91+
# Try to evaluate to catch things in the global envir.
92+
.x <- try(purrr::map(x, rlang::eval_tidy), silent = TRUE)
93+
if (inherits(.x, "try-error")) {
94+
x <- purrr::map(x, rlang::quo_get_expr)
95+
} else {
96+
x <- .x
97+
}
98+
if (is.null(x)) {
99+
return(NA_character_)
100+
}
101+
}
102+
103+
# [tune()] will always return a call object
104+
if (is.call(x)) {
105+
if (rlang::call_name(x) == "tune") {
106+
# If an id was specified:
107+
if (length(x) > 1) {
108+
return(x[[2]])
109+
} else {
110+
# no id
111+
return("")
112+
}
113+
return(x$id)
114+
} else {
115+
return(NA_character_)
116+
}
117+
}
118+
}
119+
NA_character_
120+
}
121+
122+
find_tune_id <- function(x) {
123+
124+
# STEP 1 - Early exits
125+
126+
# Early exit for empty elements (like list())
127+
if (length(x) == 0L) {
128+
return(NA_character_)
129+
}
130+
131+
# turn quosures into expressions before continuing
132+
if (rlang::is_quosures(x)) {
133+
# Try to evaluate to catch things in the global envir. If it is a dplyr
134+
# selector, it will fail to evaluate.
135+
.x <- try(purrr::map(x, rlang::eval_tidy), silent = TRUE)
136+
if (inherits(.x, "try-error")) {
137+
x <- purrr::map(x, rlang::quo_get_expr)
138+
} else {
139+
x <- .x
140+
}
141+
}
142+
143+
id <- tune_id(x)
144+
if (!is.na(id)) {
145+
return(id)
146+
}
147+
148+
if (is.atomic(x) | is.name(x) | length(x) == 1) {
149+
return(NA_character_)
150+
}
151+
152+
# STEP 2 - Recursion
153+
154+
# tunable_elems <- purrr::map_lgl(x, find_tune)
155+
tunable_elems <- vector("character", length = length(x))
156+
157+
# use purrr::map_lgl
158+
for (i in seq_along(x)) {
159+
tunable_elems[i] <- find_tune_id(x[[i]])
160+
}
161+
162+
tunable_elems <- tunable_elems[!is.na(tunable_elems)]
163+
if (length(tunable_elems) == 0) {
164+
tunable_elems <- NA_character_
165+
}
166+
167+
if (sum(tunable_elems == "", na.rm = TRUE) > 1) {
168+
stop(
169+
"Only one tunable value is currently allowed per argument. ",
170+
"The current argument has: `",
171+
paste0(deparse(x), collapse = ""),
172+
"`.",
173+
call. = FALSE)
174+
}
175+
176+
return(tunable_elems)
177+
}

R/zzz.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,18 @@
1111
s3_register("generics::augment", "model_fit")
1212
s3_register("generics::required_pkgs", "model_fit")
1313
s3_register("generics::required_pkgs", "model_spec")
14+
15+
# - If tune isn't installed, register the method (`packageVersion()` will error here)
16+
# - If tune >= 0.1.6.9001 is installed, register the method
17+
should_register_tune_args_method <- tryCatch(
18+
expr = utils::packageVersion("tune") >= "0.1.6.9001",
19+
error = function(cnd) TRUE
20+
)
21+
22+
if (should_register_tune_args_method) {
23+
# `tune_args.model_spec()` moved from tune to parsnip
24+
vctrs::s3_register("generics::tune_args", "model_spec", tune_args_model_spec)
25+
}
1426
}
1527

1628

0 commit comments

Comments
 (0)