Skip to content

Commit 7623e54

Browse files
committed
removed min_grid generic, methods, and functions; moved to tune
1 parent 7d79db6 commit 7623e54

17 files changed

+5
-524
lines changed

NAMESPACE

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,6 @@ S3method(fit_xy,model_spec)
55
S3method(has_multi_predict,default)
66
S3method(has_multi_predict,model_fit)
77
S3method(has_multi_predict,workflow)
8-
S3method(min_grid,boost_tree)
9-
S3method(min_grid,linear_reg)
10-
S3method(min_grid,logistic_reg)
11-
S3method(min_grid,mars)
12-
S3method(min_grid,model_spec)
13-
S3method(min_grid,multinom_reg)
14-
S3method(min_grid,nearest_neighbor)
158
S3method(multi_predict,"_C5.0")
169
S3method(multi_predict,"_earth")
1710
S3method(multi_predict,"_elnet")
@@ -114,13 +107,6 @@ export(linear_reg)
114107
export(logistic_reg)
115108
export(make_classes)
116109
export(mars)
117-
export(min_grid)
118-
export(min_grid.boost_tree)
119-
export(min_grid.linear_reg)
120-
export(min_grid.logistic_reg)
121-
export(min_grid.mars)
122-
export(min_grid.multinom_reg)
123-
export(min_grid.nearest_neighbor)
124110
export(mlp)
125111
export(model_printer)
126112
export(multi_predict)

R/aaa.R

Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -24,93 +24,6 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
2424
res
2525
}
2626

27-
# ------------------------------------------------------------------------------
28-
# min_grid generic - put here so that the generic shows up first in the man file
29-
30-
#' Determine the minimum set of model fits
31-
#'
32-
#' `min_grid` determines exactly what models should be fit in order to
33-
#' evaluate the entire set of tuning parameter combinations. This is for
34-
#' internal use only and the API may change in the near future.
35-
#' @param x A model specification.
36-
#' @param grid A tibble with tuning parameter combinations.
37-
#' @param ... Not currently used.
38-
#' @return A tibble with the minimum tuning parameters to fit and an additional
39-
#' list column with the parameter combinations used for prediction.
40-
#' @keywords internal
41-
#' @export
42-
min_grid <- function(x, grid, ...) {
43-
# x is a `model_spec` object from parsnip
44-
# grid is a tibble of tuning parameter values with names
45-
# matching the parameter names.
46-
UseMethod("min_grid")
47-
}
48-
49-
# As an example, if we fit a boosted tree model and tune over
50-
# trees = 1:20 and min_n = c(20, 30)
51-
# we should only have to fit two models:
52-
#
53-
# trees = 20 & min_n = 20
54-
# trees = 20 & min_n = 30
55-
#
56-
# The logic related to how this "mini grid" gets made is model-specific.
57-
#
58-
# To get the full set of predictions, we need to know, for each of these two
59-
# models, what values of num_terms to give to the multi_predict() function.
60-
#
61-
# The current idea is to have a list column of the extra models for prediction.
62-
# For the example above:
63-
#
64-
# # A tibble: 2 x 3
65-
# trees min_n .submodels
66-
# <dbl> <dbl> <list>
67-
# 1 20 20 <named list [1]>
68-
# 2 20 30 <named list [1]>
69-
#
70-
# and the .submodels would both be
71-
#
72-
# list(trees = 1:19)
73-
#
74-
# There are a lot of other things to consider in future versions like grids
75-
# where there are multiple columns with the same name (maybe the results of
76-
# a recipe) and so on.
77-
78-
# ------------------------------------------------------------------------------
79-
# helper functions
80-
81-
# Template for model results that do no have the sub-model feature
82-
blank_submodels <- function(grid) {
83-
grid %>%
84-
dplyr::mutate(.submodels = map(1:nrow(grid), ~ list()))
85-
}
86-
87-
get_fixed_args <- function(info) {
88-
# Get non-sub-model columns to iterate over
89-
fixed_args <- info$name[!info$has_submodel]
90-
}
91-
92-
get_submodel_info <- function(spec, grid) {
93-
param_info <-
94-
get_from_env(paste0(class(spec)[1], "_args")) %>%
95-
dplyr::filter(engine == spec$engine) %>%
96-
dplyr::select(name = parsnip, has_submodel)
97-
98-
# In case a recipe or other activity has grid parameter columns,
99-
# add those to the results
100-
grid_names <- names(grid)
101-
is_mod_param <- grid_names %in% param_info$name
102-
if (any(!is_mod_param)) {
103-
param_info <-
104-
param_info %>%
105-
dplyr::bind_rows(
106-
tibble::tibble(name = grid_names[!is_mod_param],
107-
has_submodel = FALSE)
108-
)
109-
}
110-
param_info %>% dplyr::filter(name %in% grid_names)
111-
}
112-
113-
11427
# ------------------------------------------------------------------------------
11528
# nocov
11629

R/boost_tree.R

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -514,41 +514,4 @@ C50_by_tree <- function(tree, object, new_data, type, ...) {
514514
pred[, c(".row", "trees", nms)]
515515
}
516516

517-
# ------------------------------------------------------------------------------
518-
519-
#' @export
520-
#' @export min_grid.boost_tree
521-
#' @rdname min_grid
522-
min_grid.boost_tree <- function(x, grid, ...) {
523-
grid_names <- names(grid)
524-
param_info <- get_submodel_info(x, grid)
525-
526-
# No ability to do submodels? Finish here:
527-
if (!any(param_info$has_submodel)) {
528-
return(blank_submodels(grid))
529-
}
530-
531-
fixed_args <- get_fixed_args(param_info)
532-
533-
# For boosted trees, fit the model with the most trees (conditional on the
534-
# other parameters) so that you can do predictions on the smaller models.
535-
fit_only <-
536-
grid %>%
537-
dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
538-
dplyr::summarize(trees = max(trees, na.rm = TRUE)) %>%
539-
dplyr::ungroup()
540-
541-
# Add a column .submodels that is a list with what should be predicted
542-
# by `multi_predict()` (assuming `predict()` has already been executed
543-
# on the original value of 'trees')
544-
min_grid_df <-
545-
dplyr::full_join(fit_only %>% rename(max_tree = trees), grid, by = fixed_args) %>%
546-
dplyr::filter(trees != max_tree) %>%
547-
dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
548-
dplyr::summarize(.submodels = list(list(trees = trees))) %>%
549-
dplyr::ungroup() %>%
550-
dplyr::full_join(fit_only, grid, by = fixed_args)
551-
552-
min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
553-
}
554517

R/linear_reg.R

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -359,37 +359,3 @@ multi_predict._elnet <-
359359
names(pred) <- NULL
360360
tibble(.pred = pred)
361361
}
362-
363-
364-
# ------------------------------------------------------------------------------
365-
366-
#' @export
367-
#' @export min_grid.linear_reg
368-
#' @rdname min_grid
369-
min_grid.linear_reg <- function(x, grid, ...) {
370-
371-
grid_names <- names(grid)
372-
param_info <- get_submodel_info(x, grid)
373-
374-
if (!any(param_info$has_submodel)) {
375-
return(blank_submodels(grid))
376-
}
377-
378-
fixed_args <- get_fixed_args(param_info)
379-
380-
fit_only <-
381-
grid %>%
382-
dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
383-
dplyr::summarize(penalty = max(penalty, na.rm = TRUE)) %>%
384-
dplyr::ungroup()
385-
386-
min_grid_df <-
387-
dplyr::full_join(fit_only %>% rename(max_penalty = penalty), grid, by = fixed_args) %>%
388-
dplyr::filter(penalty != max_penalty) %>%
389-
dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
390-
dplyr::summarize(.submodels = list(list(penalty = penalty))) %>%
391-
dplyr::ungroup() %>%
392-
dplyr::full_join(fit_only, grid, by = fixed_args)
393-
394-
min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
395-
}

R/logistic_reg.R

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -371,11 +371,3 @@ predict_raw._lognet <- function(object, new_data, opts = list(), ...) {
371371
object$spec <- eval_args(object$spec)
372372
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
373373
}
374-
375-
376-
# ------------------------------------------------------------------------------
377-
378-
#' @export
379-
#' @export min_grid.logistic_reg
380-
#' @rdname min_grid
381-
min_grid.logistic_reg <- min_grid.linear_reg

R/mars.R

Lines changed: 5 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -232,10 +232,11 @@ multi_predict._earth <-
232232
paste("Please use `keepxy = TRUE` as an option to enable submodel",
233233
"predictions with `earth`.")
234234
if (any(names(object$fit$call) == "keepxy")) {
235-
if(!isTRUE(object$fit$call$keepxy))
236-
stop (msg, call. = FALSE)
237-
} else
238-
stop (msg, call. = FALSE)
235+
if (!isTRUE(object$fit$call$keepxy))
236+
stop(msg, call. = FALSE)
237+
} else {
238+
stop(msg, call. = FALSE)
239+
}
239240

240241
if (is.null(type)) {
241242
if (object$spec$mode == "classification")
@@ -261,36 +262,3 @@ earth_by_terms <- function(num_terms, object, new_data, type, ...) {
261262
pred[[".row"]] <- 1:nrow(new_data)
262263
pred[, c(".row", "num_terms", nms)]
263264
}
264-
265-
# ------------------------------------------------------------------------------
266-
267-
#' @export
268-
#' @export min_grid.mars
269-
#' @rdname min_grid
270-
min_grid.mars <- function(x, grid, ...) {
271-
272-
grid_names <- names(grid)
273-
param_info <- get_submodel_info(x, grid)
274-
275-
if (!any(param_info$has_submodel)) {
276-
return(blank_submodels(grid))
277-
}
278-
279-
fixed_args <- get_fixed_args(param_info)
280-
281-
fit_only <-
282-
grid %>%
283-
dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
284-
dplyr::summarize(num_terms = max(num_terms, na.rm = TRUE)) %>%
285-
dplyr::ungroup()
286-
287-
min_grid_df <-
288-
dplyr::full_join(fit_only %>% rename(max_terms = num_terms), grid, by = fixed_args) %>%
289-
dplyr::filter(num_terms != max_terms) %>%
290-
dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
291-
dplyr::summarize(.submodels = list(list(num_terms = num_terms))) %>%
292-
dplyr::ungroup() %>%
293-
dplyr::full_join(fit_only, grid, by = fixed_args)
294-
295-
min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
296-
}

R/misc.R

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -232,13 +232,3 @@ terms_y <- function(x) {
232232
y_expr <- att$predvars[[resp_ind + 1]]
233233
all.vars(y_expr)
234234
}
235-
236-
237-
# ------------------------------------------------------------------------------
238-
239-
#'@export
240-
#'@rdname min_grid
241-
min_grid.model_spec <- function(x, grid, ...) {
242-
blank_submodels(grid)
243-
}
244-

R/multinom_reg.R

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -336,11 +336,3 @@ check_glmnet_lambda <- function(dat, object) {
336336
)
337337
dat
338338
}
339-
340-
341-
# ------------------------------------------------------------------------------
342-
343-
#' @export
344-
#' @export min_grid.multinom_reg
345-
#' @rdname min_grid
346-
min_grid.multinom_reg <- min_grid.linear_reg

R/nearest_neighbor.R

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -219,37 +219,3 @@ knn_by_k <- function(k, object, new_data, type, ...) {
219219
dplyr::mutate(neighbors = k, .row = dplyr::row_number()) %>%
220220
dplyr::select(.row, neighbors, dplyr::starts_with(".pred"))
221221
}
222-
223-
# ------------------------------------------------------------------------------
224-
225-
#' @export
226-
#' @export min_grid.nearest_neighbor
227-
#' @rdname min_grid
228-
min_grid.nearest_neighbor <- function(x, grid, ...) {
229-
230-
grid_names <- names(grid)
231-
param_info <- get_submodel_info(x, grid)
232-
233-
if (!any(param_info$has_submodel)) {
234-
return(blank_submodels(grid))
235-
}
236-
237-
fixed_args <- get_fixed_args(param_info)
238-
239-
fit_only <-
240-
grid %>%
241-
dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
242-
dplyr::summarize(neighbors = max(neighbors, na.rm = TRUE)) %>%
243-
dplyr::ungroup()
244-
245-
min_grid_df <-
246-
dplyr::full_join(fit_only %>% rename(max_neighbor = neighbors), grid, by = fixed_args) %>%
247-
dplyr::filter(neighbors != max_neighbor) %>%
248-
dplyr::rename(sub_neighbors = neighbors, neighbors = max_neighbor) %>%
249-
dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
250-
dplyr::summarize(.submodels = list(list(neighbors = sub_neighbors))) %>%
251-
dplyr::ungroup() %>%
252-
dplyr::full_join(fit_only, grid, by = fixed_args)
253-
254-
min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
255-
}

man/min_grid.Rd

Lines changed: 0 additions & 47 deletions
This file was deleted.

0 commit comments

Comments
 (0)