Skip to content

Commit 5dc131d

Browse files
committed
added some replacement functions for new model structure
1 parent a4c0955 commit 5dc131d

File tree

1 file changed

+52
-3
lines changed

1 file changed

+52
-3
lines changed

R/translate.R

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@
1919
#' This function can be useful when you need to understand how
2020
#' `parsnip` goes from a generic model specific to a model fitting
2121
#' function.
22-
#'
22+
#'
2323
#' **Note**: this function is used internally and users should only use it
2424
#' to understand what the underlying syntax would be. It should not be used
25-
#' to modify the model specification.
26-
#'
25+
#' to modify the model specification.
26+
#'
2727
#' @examples
2828
#' lm_spec <- linear_reg(penalty = 0.01)
2929
#'
@@ -144,3 +144,52 @@ check_mode <- function(object, lvl) {
144144
}
145145
object
146146
}
147+
148+
# ------------------------------------------------------------------------------
149+
# new code for revised model data structures
150+
151+
get_model_spec <- function(model, mode, engine) {
152+
m_env <- get_model_env()
153+
env_obj <- env_names(m_env)
154+
env_obj <- grep(model, env_obj, value = TRUE)
155+
156+
res <- list()
157+
res$libs <-
158+
env_get(m_env, paste0(model, "_pkgs")) %>%
159+
purrr::pluck("pkg") %>%
160+
purrr::pluck(1)
161+
162+
res$fit <-
163+
env_get(m_env, paste0(model, "_fit")) %>%
164+
dplyr::filter(mode == !!mode & engine == !!engine) %>%
165+
dplyr::pull(value) %>%
166+
purrr:::pluck(1)
167+
168+
pred_code <-
169+
env_get(m_env, paste0(model, "_predict")) %>%
170+
dplyr::filter(mode == !!mode & engine == !!engine) %>%
171+
dplyr::select(-engine, -mode)
172+
173+
res$pred <- pred_code[["value"]]
174+
names(res$pred) <- pred_code$type
175+
176+
res
177+
}
178+
179+
get_args <- function(model, engine) {
180+
m_env <- get_model_env()
181+
env_get(m_env, paste0(model, "_args")) %>%
182+
dplyr::select(-engine)
183+
}
184+
185+
# to replace harmonize
186+
unionize <- function(args, key) {
187+
parsn <- tibble(parsnip = names(args), order = seq_along(args))
188+
merged <-
189+
dplyr::left_join(parsn, key, by = "parsnip") %>%
190+
dplyr::arrange(order)
191+
# TODO correct for bad merge?
192+
193+
names(args) <- merged$original
194+
args[!is.na(merged$original)]
195+
}

0 commit comments

Comments
 (0)