Skip to content

removed min_grid generic, methods, and functions; moved to tune #207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,6 @@ S3method(fit_xy,model_spec)
S3method(has_multi_predict,default)
S3method(has_multi_predict,model_fit)
S3method(has_multi_predict,workflow)
S3method(min_grid,boost_tree)
S3method(min_grid,linear_reg)
S3method(min_grid,logistic_reg)
S3method(min_grid,mars)
S3method(min_grid,model_spec)
S3method(min_grid,multinom_reg)
S3method(min_grid,nearest_neighbor)
S3method(multi_predict,"_C5.0")
S3method(multi_predict,"_earth")
S3method(multi_predict,"_elnet")
Expand Down Expand Up @@ -114,13 +107,6 @@ export(linear_reg)
export(logistic_reg)
export(make_classes)
export(mars)
export(min_grid)
export(min_grid.boost_tree)
export(min_grid.linear_reg)
export(min_grid.logistic_reg)
export(min_grid.mars)
export(min_grid.multinom_reg)
export(min_grid.nearest_neighbor)
export(mlp)
export(model_printer)
export(multi_predict)
Expand Down
87 changes: 0 additions & 87 deletions R/aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,93 +24,6 @@ convert_stan_interval <- function(x, level = 0.95, lower = TRUE) {
res
}

# ------------------------------------------------------------------------------
# min_grid generic - put here so that the generic shows up first in the man file

#' Determine the minimum set of model fits
#'
#' `min_grid()` works out which models actually have to be fit so that the
#' entire set of tuning parameter combinations can be evaluated. It is meant
#' for internal use only and the API may change in the near future.
#' @param x A model specification.
#' @param grid A tibble with tuning parameter combinations.
#' @param ... Not currently used.
#' @return A tibble with the minimum tuning parameters to fit and an additional
#' list column with the parameter combinations used for prediction.
#' @keywords internal
#' @export
min_grid <- function(x, grid, ...) {
  # S3 generic: dispatch on `x` (a parsnip `model_spec`). `grid` is a tibble
  # whose column names match the tuning parameter names.
  UseMethod("min_grid", x)
}

# As an example, if we fit a boosted tree model and tune over
# trees = 1:20 and min_n = c(20, 30)
# we should only have to fit two models:
#
# trees = 20 & min_n = 20
# trees = 20 & min_n = 30
#
# The logic related to how this "mini grid" gets made is model-specific.
#
# To get the full set of predictions, we need to know, for each of these two
# models, what values of trees to give to the multi_predict() function.
#
# The current idea is to have a list column of the extra models for prediction.
# For the example above:
#
# # A tibble: 2 x 3
# trees min_n .submodels
# <dbl> <dbl> <list>
# 1 20 20 <named list [1]>
# 2 20 30 <named list [1]>
#
# and the .submodels would both be
#
# list(trees = 1:19)
#
# There are a lot of other things to consider in future versions like grids
# where there are multiple columns with the same name (maybe the results of
# a recipe) and so on.

# ------------------------------------------------------------------------------
# helper functions

# Template for model results that do not have the sub-model feature.
#
# `grid` is a tibble of tuning parameter combinations. The result is `grid`
# with an extra `.submodels` list column that holds one empty list per row.
blank_submodels <- function(grid) {
  # seq_len() keeps a zero-row grid valid; 1:nrow(grid) would give c(1, 0)
  # and make mutate() fail on a column length mismatch.
  grid %>%
    dplyr::mutate(.submodels = map(seq_len(nrow(grid)), ~ list()))
}

# Names of the parameters that do not have the sub-model feature.
#
# `info` needs a `name` element and a parallel logical `has_submodel` element
# (e.g. the table produced by get_submodel_info()). Returns the entries of
# `name` whose `has_submodel` flag is FALSE; these are the columns the
# min_grid methods group over.
get_fixed_args <- function(info) {
  # Return the value directly so it is visible to the caller instead of
  # ending the function on an assignment to an unused local.
  info$name[!info$has_submodel]
}

# Build a table with one row per grid column, giving its `name` and whether
# it `has_submodel` for the engine set on `spec`.
get_submodel_info <- function(spec, grid) {
  # Argument table for this model type, restricted to the current engine.
  args_object <- paste0(class(spec)[1], "_args")
  engine_args <- dplyr::filter(get_from_env(args_object), engine == spec$engine)
  param_info <- dplyr::select(engine_args, name = parsnip, has_submodel)

  # Grid columns that are not model parameters (e.g. from a recipe or other
  # activity) are appended and flagged as having no submodels.
  grid_names <- names(grid)
  extra_names <- grid_names[!(grid_names %in% param_info$name)]
  if (length(extra_names) > 0) {
    extra_info <- tibble::tibble(name = extra_names, has_submodel = FALSE)
    param_info <- dplyr::bind_rows(param_info, extra_info)
  }

  # Only report parameters that are actually present in the grid.
  dplyr::filter(param_info, name %in% grid_names)
}


# ------------------------------------------------------------------------------
# nocov

Expand Down
37 changes: 0 additions & 37 deletions R/boost_tree.R
Original file line number Diff line number Diff line change
Expand Up @@ -514,41 +514,4 @@ C50_by_tree <- function(tree, object, new_data, type, ...) {
pred[, c(".row", "trees", nms)]
}

# ------------------------------------------------------------------------------

#' @export
#' @export min_grid.boost_tree
#' @rdname min_grid
min_grid.boost_tree <- function(x, grid, ...) {
  grid_names <- names(grid)
  param_info <- get_submodel_info(x, grid)

  # No ability to do submodels? Finish here:
  if (!any(param_info$has_submodel)) {
    return(blank_submodels(grid))
  }

  fixed_args <- get_fixed_args(param_info)

  # For boosted trees, fit the model with the most trees (conditional on the
  # other parameters) so that you can do predictions on the smaller models.
  fit_only <-
    grid %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(trees = max(trees, na.rm = TRUE)) %>%
    dplyr::ungroup()

  # Add a column .submodels that is a list with what should be predicted
  # by `multi_predict()` (assuming `predict()` has already been executed
  # on the original value of 'trees').
  #
  # The last join only needs `fit_only`: in a pipe, any further positional
  # argument would be matched to `copy` of `full_join()`, not used as data.
  min_grid_df <-
    fit_only %>%
    dplyr::rename(max_tree = trees) %>%
    dplyr::full_join(grid, by = fixed_args) %>%
    dplyr::filter(trees != max_tree) %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(.submodels = list(list(trees = trees))) %>%
    dplyr::ungroup() %>%
    dplyr::full_join(fit_only, by = fixed_args)

  min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
}

34 changes: 0 additions & 34 deletions R/linear_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -359,37 +359,3 @@ multi_predict._elnet <-
names(pred) <- NULL
tibble(.pred = pred)
}


# ------------------------------------------------------------------------------

#' @export
#' @export min_grid.linear_reg
#' @rdname min_grid
min_grid.linear_reg <- function(x, grid, ...) {
  grid_names <- names(grid)
  param_info <- get_submodel_info(x, grid)

  # No parameter with the sub-model feature: every grid row is its own fit.
  if (!any(param_info$has_submodel)) {
    return(blank_submodels(grid))
  }

  fixed_args <- get_fixed_args(param_info)

  # Keep only the largest penalty for each combination of the other
  # parameters; these are the models that get fit.
  fit_only <-
    grid %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(penalty = max(penalty, na.rm = TRUE)) %>%
    dplyr::ungroup()

  # Collect the remaining penalty values of each group into `.submodels`.
  #
  # The last join only needs `fit_only`: in a pipe, any further positional
  # argument would be matched to `copy` of `full_join()`, not used as data.
  min_grid_df <-
    fit_only %>%
    dplyr::rename(max_penalty = penalty) %>%
    dplyr::full_join(grid, by = fixed_args) %>%
    dplyr::filter(penalty != max_penalty) %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(.submodels = list(list(penalty = penalty))) %>%
    dplyr::ungroup() %>%
    dplyr::full_join(fit_only, by = fixed_args)

  min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
}
8 changes: 0 additions & 8 deletions R/logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,3 @@ predict_raw._lognet <- function(object, new_data, opts = list(), ...) {
object$spec <- eval_args(object$spec)
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
}


# ------------------------------------------------------------------------------

# Logistic regression reuses the linear regression method unchanged, so the
# same `penalty`-based submodel logic applies.
#' @export
#' @export min_grid.logistic_reg
#' @rdname min_grid
min_grid.logistic_reg <- min_grid.linear_reg
42 changes: 5 additions & 37 deletions R/mars.R
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,11 @@ multi_predict._earth <-
paste("Please use `keepxy = TRUE` as an option to enable submodel",
"predictions with `earth`.")
if (any(names(object$fit$call) == "keepxy")) {
if(!isTRUE(object$fit$call$keepxy))
stop (msg, call. = FALSE)
} else
stop (msg, call. = FALSE)
if (!isTRUE(object$fit$call$keepxy))
stop(msg, call. = FALSE)
} else {
stop(msg, call. = FALSE)
}

if (is.null(type)) {
if (object$spec$mode == "classification")
Expand All @@ -261,36 +262,3 @@ earth_by_terms <- function(num_terms, object, new_data, type, ...) {
pred[[".row"]] <- 1:nrow(new_data)
pred[, c(".row", "num_terms", nms)]
}

# ------------------------------------------------------------------------------

#' @export
#' @export min_grid.mars
#' @rdname min_grid
min_grid.mars <- function(x, grid, ...) {
  grid_names <- names(grid)
  param_info <- get_submodel_info(x, grid)

  # No parameter with the sub-model feature: every grid row is its own fit.
  if (!any(param_info$has_submodel)) {
    return(blank_submodels(grid))
  }

  fixed_args <- get_fixed_args(param_info)

  # Keep only the largest num_terms for each combination of the other
  # parameters; these are the models that get fit.
  fit_only <-
    grid %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(num_terms = max(num_terms, na.rm = TRUE)) %>%
    dplyr::ungroup()

  # Collect the remaining num_terms values of each group into `.submodels`.
  #
  # The last join only needs `fit_only`: in a pipe, any further positional
  # argument would be matched to `copy` of `full_join()`, not used as data.
  min_grid_df <-
    fit_only %>%
    dplyr::rename(max_terms = num_terms) %>%
    dplyr::full_join(grid, by = fixed_args) %>%
    dplyr::filter(num_terms != max_terms) %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(.submodels = list(list(num_terms = num_terms))) %>%
    dplyr::ungroup() %>%
    dplyr::full_join(fit_only, by = fixed_args)

  min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
}
10 changes: 0 additions & 10 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -232,13 +232,3 @@ terms_y <- function(x) {
y_expr <- att$predvars[[resp_ind + 1]]
all.vars(y_expr)
}


# ------------------------------------------------------------------------------

#' @export
#' @rdname min_grid
min_grid.model_spec <- function(x, grid, ...) {
  # Fallback for any model specification: return the grid with an empty
  # `.submodels` entry for each row.
  no_submodels <- blank_submodels(grid)
  no_submodels
}

8 changes: 0 additions & 8 deletions R/multinom_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -336,11 +336,3 @@ check_glmnet_lambda <- function(dat, object) {
)
dat
}


# ------------------------------------------------------------------------------

# Multinomial regression reuses the linear regression method unchanged, so
# the same `penalty`-based submodel logic applies.
#' @export
#' @export min_grid.multinom_reg
#' @rdname min_grid
min_grid.multinom_reg <- min_grid.linear_reg
34 changes: 0 additions & 34 deletions R/nearest_neighbor.R
Original file line number Diff line number Diff line change
Expand Up @@ -219,37 +219,3 @@ knn_by_k <- function(k, object, new_data, type, ...) {
dplyr::mutate(neighbors = k, .row = dplyr::row_number()) %>%
dplyr::select(.row, neighbors, dplyr::starts_with(".pred"))
}

# ------------------------------------------------------------------------------

#' @export
#' @export min_grid.nearest_neighbor
#' @rdname min_grid
min_grid.nearest_neighbor <- function(x, grid, ...) {
  grid_names <- names(grid)
  param_info <- get_submodel_info(x, grid)

  # No parameter with the sub-model feature: every grid row is its own fit.
  if (!any(param_info$has_submodel)) {
    return(blank_submodels(grid))
  }

  fixed_args <- get_fixed_args(param_info)

  # Keep only the largest number of neighbors for each combination of the
  # other parameters; these are the models that get fit.
  fit_only <-
    grid %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(neighbors = max(neighbors, na.rm = TRUE)) %>%
    dplyr::ungroup()

  # Collect the remaining neighbors values of each group into `.submodels`.
  #
  # The last join only needs `fit_only`: in a pipe, any further positional
  # argument would be matched to `copy` of `full_join()`, not used as data.
  min_grid_df <-
    fit_only %>%
    dplyr::rename(max_neighbor = neighbors) %>%
    dplyr::full_join(grid, by = fixed_args) %>%
    dplyr::filter(neighbors != max_neighbor) %>%
    dplyr::rename(sub_neighbors = neighbors, neighbors = max_neighbor) %>%
    dplyr::group_by(!!!rlang::syms(fixed_args)) %>%
    dplyr::summarize(.submodels = list(list(neighbors = sub_neighbors))) %>%
    dplyr::ungroup() %>%
    dplyr::full_join(fit_only, by = fixed_args)

  min_grid_df %>% dplyr::select(dplyr::one_of(grid_names), .submodels)
}
47 changes: 0 additions & 47 deletions man/min_grid.Rd

This file was deleted.

Loading