Skip to content

Commit 22c87a8

Browse files
authored
speed up tunable.model_spec() (#921)
1 parent 65b7201 commit 22c87a8

File tree

1 file changed

+27
-18
lines changed

1 file changed

+27
-18
lines changed

R/tunable.R

Lines changed: 27 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
#' @export
66
tunable.model_spec <- function(x, ...) {
7-
mod_env <- rlang::ns_env("parsnip")$parsnip
7+
8+
mod_env <- get_model_env()
89

910
if (is.null(x$engine)) {
1011
stop("Please declare an engine first using `set_engine()`.", call. = FALSE)
@@ -17,27 +18,35 @@ tunable.model_spec <- function(x, ...) {
1718
sep = "", call. = FALSE)
1819
}
1920

20-
arg_vals <-
21-
mod_env[[arg_name]] %>%
22-
dplyr::filter(engine == x$engine) %>%
23-
dplyr::select(name = parsnip, call_info = func) %>%
24-
dplyr::full_join(
25-
tibble::tibble(name = c(names(x$args), names(x$eng_args))),
26-
by = "name"
27-
) %>%
28-
dplyr::mutate(
29-
source = "model_spec",
30-
component = mod_type(x),
31-
component_id = dplyr::if_else(name %in% names(x$args), "main", "engine")
21+
arg_vals <- mod_env[[arg_name]]
22+
arg_vals <- arg_vals[arg_vals$engine == x$engine, c("parsnip", "func")]
23+
names(arg_vals)[names(arg_vals) == "parsnip"] <- "name"
24+
names(arg_vals)[names(arg_vals) == "func"] <- "call_info"
25+
26+
extra_args <- c(names(x$args), names(x$eng_args))
27+
extra_args <- extra_args[!extra_args %in% arg_vals$name]
28+
29+
extra_args_tbl <-
30+
tibble::new_tibble(
31+
list(name = extra_args, call_info = vector("list", vctrs::vec_size(extra_args))),
32+
nrow = vctrs::vec_size(extra_args)
3233
)
3334

34-
if (nrow(arg_vals) > 0) {
35-
has_info <- purrr::map_lgl(arg_vals$call_info, is.null)
36-
rm_list <- !(has_info & (arg_vals$component_id == "main"))
35+
res <- vctrs::vec_rbind(arg_vals, extra_args_tbl)
3736

38-
arg_vals <- arg_vals[rm_list,]
37+
res$source <- "model_spec"
38+
res$component <- mod_type(x)
39+
res$component_id <- "main"
40+
res$component_id[!res$name %in% names(x$args)] <- "engine"
41+
42+
if (nrow(res) > 0) {
43+
has_info <- purrr::map_lgl(res$call_info, is.null)
44+
rm_list <- !(has_info & (res$component_id == "main"))
45+
46+
res <- res[rm_list, ]
3947
}
40-
arg_vals %>% dplyr::select(name, call_info, source, component, component_id)
48+
49+
res[, c("name", "call_info", "source", "component", "component_id")]
4150
}
4251

4352
mod_type <- function(.mod) class(.mod)[class(.mod) != "model_spec"][1]

0 commit comments

Comments
 (0)