Skip to content

Commit 9b4c0b8

Browse files
topepojuliasilge
andauthored
Name translations (#735)
* function for translating names * update argument name * some unit tests * global var false positive * add missing pkgdown entry * Fix typos and refine docs * Less confusing wording Co-authored-by: Julia Silge <[email protected]>
1 parent f68376a commit 9b4c0b8

File tree

8 files changed

+120
-2
lines changed

8 files changed

+120
-2
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ export(.convert_xy_to_form_new)
156156
export(.dat)
157157
export(.facts)
158158
export(.lvls)
159+
export(.model_param_name_key)
159160
export(.obs)
160161
export(.organize_glmnet_pred)
161162
export(.preds)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919

2020
* `predict(type = "prob")` will now provide an error if the outcome variable has a level called `"class"` (#720).
2121

22+
* Added a developer function, `.model_param_name_key` that translates names of tuning parameters.
23+
2224
* Model type functions will now message informatively if a needed parsnip extension package is not loaded (#731).
2325

2426

R/parsnip-package.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ utils::globalVariables(
4141
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators",
4242
"compute_intercept", "remove_intercept", "estimate", "term",
4343
"call_info", "component", "component_id", "func", "tunable", "label",
44-
"pkg", ".order", "item", "tunable", "has_ext", "weights", "has_wts", "protect"
44+
"pkg", ".order", "item", "tunable", "has_ext", "id", "weights", "has_wts", "protect"
4545
)
4646
)
4747

R/translate.R

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,3 +179,62 @@ add_methods <- function(x, engine) {
179179
x$method <- get_model_spec(specific_model(x), x$mode, x$engine)
180180
x
181181
}
182+
183+
184+
#' Translate names of model tuning parameters
185+
#'
186+
#' This function creates a key that connects the identifiers users make for
187+
#' tuning parameter names, the standardized parsnip parameter names, and the
188+
#' argument names to the underlying fit function for the engine.
189+
#'
190+
#' @param object A workflow or parsnip model specification.
191+
#' @param as_tibble A logical. Should the results be in a tibble (the default)
192+
#' or in a list that can facilitate renaming grid objects?
193+
#' @return A tibble with columns `user`, `parsnip`, and `engine`, or a list
194+
#' with named character vectors `user_to_parsnip` and `parsnip_to_engine`.
195+
#' @examples
196+
#' mod <-
197+
#' linear_reg(penalty = tune("regularization"), mixture = tune()) %>%
198+
#' set_engine("glmnet")
199+
#'
200+
#' mod %>% .model_param_name_key()
201+
#'
202+
#' rn <- mod %>% .model_param_name_key(as_tibble = FALSE)
203+
#' rn
204+
#'
205+
#' grid <- tidyr::crossing(regularization = c(0, 1), mixture = (0:3) / 3)
206+
#'
207+
#' grid %>%
208+
#' dplyr::rename(!!!rn$user_to_parsnip)
209+
#'
210+
#' grid %>%
211+
#' dplyr::rename(!!!rn$user_to_parsnip) %>%
212+
#' dplyr::rename(!!!rn$parsnip_to_engine)
213+
#' @export
214+
.model_param_name_key <- function(object, as_tibble = TRUE) {
215+
if (!inherits(object, c("model_spec", "workflow"))) {
216+
rlang::abort("'object' should be a model specification or workflow.")
217+
}
218+
if (inherits(object, "workflow")) {
219+
object <- hardhat::extract_spec_parsnip(object)
220+
}
221+
222+
# To translate from given names/ids in grid to parsnip names:
223+
params <- object %>% hardhat::extract_parameter_set_dials()
224+
params <- tibble::as_tibble(params) %>%
225+
dplyr::select(user = id, parsnip = name)
226+
# Go from parsnip names to engine names
227+
arg_key <- get_from_env(paste0(class(object)[1], "_args")) %>%
228+
dplyr::filter(engine == object$engine) %>%
229+
dplyr::select(engine = original, parsnip)
230+
231+
res <- dplyr::left_join(params, arg_key, by = "parsnip")
232+
if (!as_tibble) {
233+
res0 <- list(user_to_parsnip = res$user, parsnip_to_engine = res$parsnip)
234+
names(res0$user_to_parsnip) <- res$parsnip
235+
names(res0$parsnip_to_engine) <- res$engine
236+
res <- res0
237+
}
238+
res
239+
}
240+

_pkgdown.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,4 @@ reference:
101101
- required_pkgs
102102
- required_pkgs.model_spec
103103
- req_pkgs
104+
- .model_param_name_key

man/dot-model_param_name_key.Rd

Lines changed: 42 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

parsnip.Rproj

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,5 @@ StripTrailingWhitespace: Yes
1717

1818
BuildType: Package
1919
PackageUseDevtools: Yes
20-
PackageCleanBeforeInstall: Yes
2120
PackageInstallArgs: --no-multiarch --with-keep.source
2221
PackageRoxygenize: rd,collate,namespace

tests/testthat/test_translate.R

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,3 +269,17 @@ test_that("arguments (svm_rbf)", {
269269
expect_snapshot(translate_args(rbf_sigma %>% set_engine("kernlab")))
270270
})
271271

272+
# ------------------------------------------------------------------------------
273+
274+
test_that("translate tuning paramter names", {
275+
276+
mod <- boost_tree(trees = tune("number of trees"), min_n = tune(), tree_depth = 3)
277+
278+
expect_snapshot(.model_param_name_key(mod))
279+
expect_snapshot(.model_param_name_key(mod, as_tibble = FALSE))
280+
expect_snapshot(.model_param_name_key(linear_reg()))
281+
expect_snapshot(.model_param_name_key(linear_reg(), as_tibble = FALSE))
282+
expect_snapshot_error(.model_param_name_key(1))
283+
})
284+
285+

0 commit comments

Comments
 (0)