|
19 | 19 | #' This function can be useful when you need to understand how
|
20 | 20 | #' `parsnip` goes from a generic model specific to a model fitting
|
21 | 21 | #' function.
|
22 |
| -#' |
| 22 | +#' |
23 | 23 | #' **Note**: this function is used internally and users should only use it
|
24 | 24 | #' to understand what the underlying syntax would be. It should not be used
|
25 |
| -#' to modify the model specification. |
26 |
| -#' |
| 25 | +#' to modify the model specification. |
| 26 | +#' |
27 | 27 | #' @examples
|
28 | 28 | #' lm_spec <- linear_reg(penalty = 0.01)
|
29 | 29 | #'
|
@@ -144,3 +144,52 @@ check_mode <- function(object, lvl) {
|
144 | 144 | }
|
145 | 145 | object
|
146 | 146 | }
|
| 147 | + |
| 148 | +# ------------------------------------------------------------------------------ |
| 149 | +# new code for revised model data structures |
| 150 | + |
| 151 | +get_model_spec <- function(model, mode, engine) { |
| 152 | + m_env <- get_model_env() |
| 153 | + env_obj <- env_names(m_env) |
| 154 | + env_obj <- grep(model, env_obj, value = TRUE) |
| 155 | + |
| 156 | + res <- list() |
| 157 | + res$libs <- |
| 158 | + env_get(m_env, paste0(model, "_pkgs")) %>% |
| 159 | + purrr::pluck("pkg") %>% |
| 160 | + purrr::pluck(1) |
| 161 | + |
| 162 | + res$fit <- |
| 163 | + env_get(m_env, paste0(model, "_fit")) %>% |
| 164 | + dplyr::filter(mode == !!mode & engine == !!engine) %>% |
| 165 | + dplyr::pull(value) %>% |
| 166 | + purrr:::pluck(1) |
| 167 | + |
| 168 | + pred_code <- |
| 169 | + env_get(m_env, paste0(model, "_predict")) %>% |
| 170 | + dplyr::filter(mode == !!mode & engine == !!engine) %>% |
| 171 | + dplyr::select(-engine, -mode) |
| 172 | + |
| 173 | + res$pred <- pred_code[["value"]] |
| 174 | + names(res$pred) <- pred_code$type |
| 175 | + |
| 176 | + res |
| 177 | +} |
| 178 | + |
| 179 | +get_args <- function(model, engine) { |
| 180 | + m_env <- get_model_env() |
| 181 | + env_get(m_env, paste0(model, "_args")) %>% |
| 182 | + dplyr::select(-engine) |
| 183 | +} |
| 184 | + |
| 185 | +# to replace harmonize |
| 186 | +unionize <- function(args, key) { |
| 187 | + parsn <- tibble(parsnip = names(args), order = seq_along(args)) |
| 188 | + merged <- |
| 189 | + dplyr::left_join(parsn, key, by = "parsnip") %>% |
| 190 | + dplyr::arrange(order) |
| 191 | + # TODO correct for bad merge? |
| 192 | + |
| 193 | + names(args) <- merged$original |
| 194 | + args[!is.na(merged$original)] |
| 195 | +} |
0 commit comments