Skip to content

Add engine specification field for predictor encodings #319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
May 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
c1dbac0
initial work on #290
topepo Apr 30, 2020
5494a32
merge origin/master
juliasilge May 18, 2020
eeac337
Fix more tests for tidyr and tibble
juliasilge May 18, 2020
16a0c85
Fine-tune documentation
juliasilge May 18, 2020
99137ac
Engine encoding for logistic_reg()
juliasilge May 18, 2020
27246a6
Do not want inner names
juliasilge May 18, 2020
972a796
Indicator variables for MARS
juliasilge May 18, 2020
cd4ed77
Look up predictor indicator; use in convert_form_to_xy_fit()
juliasilge May 18, 2020
64bbde0
Test indicators = FALSE compared to a model that does not create indi…
juliasilge May 18, 2020
e02573f
Set predictor indicators for xgboost (TRUE) and C5.0 (FALSE)
juliasilge May 18, 2020
ad6ac8c
Set predictor encodings for Spark (TRUE).
juliasilge May 18, 2020
2c52d2c
Add glue. Closes #296.
juliasilge May 19, 2020
c6d0d35
For null model, set predictor indicators to... FALSE? :thinking:
juliasilge May 19, 2020
d3ee6de
Also need engine to find the indicator encoding
juliasilge May 19, 2020
47544b4
Decision tree predictors = FALSE
juliasilge May 19, 2020
ccee213
Neural nets all TRUE for indicators
juliasilge May 19, 2020
3780454
Predictor indicators for kknn
juliasilge May 19, 2020
1779777
Predictor indicators for multinomial classification
juliasilge May 19, 2020
7d87e45
Random forest predictor indicators
juliasilge May 19, 2020
172541c
Survival models make indicators
juliasilge May 19, 2020
115292d
Change kernlab to use formula interface, add indicator encoding
juliasilge May 19, 2020
2e8d113
Change svm_rbf (kernlab) to formula interface, add indicator encodings
juliasilge May 19, 2020
a420749
Update tests for kernlab formula interface
juliasilge May 19, 2020
3a3c134
Merge IT ALL
juliasilge May 26, 2020
3c8481e
Spark *always* makes indicator variables, fix dependency for Spark + …
juliasilge May 26, 2020
1160e1e
Fix function used with Spark decision tree for regression
juliasilge May 26, 2020
534987e
Change to predictor_indicators = FALSE for MARS models
juliasilge May 27, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ export(fit_control)
export(fit_xy)
export(fit_xy.model_spec)
export(get_dependency)
export(get_encoding)
export(get_fit)
export(get_from_env)
export(get_model_env)
Expand Down Expand Up @@ -146,6 +147,7 @@ export(repair_call)
export(rpart_train)
export(set_args)
export(set_dependency)
export(set_encoding)
export(set_engine)
export(set_env_val)
export(set_fit)
Expand Down
79 changes: 79 additions & 0 deletions R/aaa_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,11 @@ check_interface_val <- function(x) {
#' below, depending on context.
#' @param pre,post Optional functions for pre- and post-processing of prediction
#' results.
#' @param options A list of options for engine-specific encodings. Currently,
#' the option implemented is `predictor_indicators` which tells `parsnip`
#' whether the pre-processing should make indicator/dummy variables from factor
#' predictors. This only affects cases when [fit.model_spec()] is used and the
#' underlying model has an x/y interface.
#' @param ... Optional arguments that should be passed into the `args` slot for
#' prediction objects.
#' @keywords internal
Expand Down Expand Up @@ -780,3 +785,77 @@ pred_value_template <- function(pre = NULL, post = NULL, func, ...) {
list(pre = pre, post = post, func = func, args = list(...))
}

# ------------------------------------------------------------------------------

# Validate the `options` list given to `set_encoding()`.
#
# `x` must be a list whose names are exactly the required encoding
# arguments (currently only `predictor_indicators`). Aborts with an
# informative message listing any missing or unexpected names;
# returns `x` invisibly when it is valid.
check_encodings <- function(x) {
  if (!is.list(x)) {
    rlang::abort("`values` should be a list.")
  }
  required_names <- "predictor_indicators"

  not_supplied <- setdiff(required_names, names(x))
  if (length(not_supplied) > 0) {
    missing_msg <- paste0("'", not_supplied, "'", collapse = ", ")
    rlang::abort(
      glue::glue(
        "The values passed to `set_encoding()` are missing arguments: ",
        missing_msg
      )
    )
  }

  unexpected <- setdiff(names(x), required_names)
  if (length(unexpected) > 0) {
    extra_msg <- paste0("'", unexpected, "'", collapse = ", ")
    rlang::abort(
      glue::glue(
        "The values passed to `set_encoding()` had extra arguments: ",
        extra_msg
      )
    )
  }
  invisible(x)
}

#' @export
#' @rdname set_new_model
#' @keywords internal
set_encoding <- function(model, mode, eng, options) {
  # Register engine-specific encoding options (currently only
  # `predictor_indicators`, which controls whether factor predictors are
  # expanded to indicator/dummy variables) for one model/engine/mode
  # combination in the parsnip model environment.
  #
  # Aborts if the model is unknown, the engine/mode/options are invalid,
  # or an identical encoding row is already registered. Returns `NULL`
  # invisibly.
  check_model_exists(model)
  check_eng_val(eng)
  check_mode_val(mode)
  check_encodings(options)

  keys <- tibble::tibble(model = model, engine = eng, mode = mode)
  options <- tibble::as_tibble(options)
  new_values <- dplyr::bind_cols(keys, options)

  nm <- paste(model, "encoding", sep = "_")
  if (nm %in% ls(envir = get_model_env())) {
    current <- get_from_env(nm)
    # NOTE(review): the duplicate check joins on the option column as well
    # as the keys, so re-registering the same model/engine/mode with a
    # *different* `predictor_indicators` value is not caught here — confirm
    # that is intended.
    dup_check <-
      current %>%
      dplyr::inner_join(
        new_values,
        by = c("model", "engine", "mode", "predictor_indicators")
      )
    if (nrow(dup_check) > 0) {
      rlang::abort(
        glue::glue("Engine '{eng}' and mode '{mode}' already have defined encodings.")
      )
    }
  } else {
    current <- NULL
  }

  # Append the new row(s) and store the table back in the model environment.
  db_values <- dplyr::bind_rows(current, new_values)
  set_env_val(nm, db_values)

  invisible(NULL)
}


#' @rdname set_new_model
#' @keywords internal
#' @export
get_encoding <- function(model) {
  # Retrieve the table of engine encodings registered for `model`
  # from the parsnip model environment.
  check_model_exists(model)
  object_name <- paste(model, "encoding", sep = "_")
  rlang::env_get(get_model_env(), object_name)
}

35 changes: 35 additions & 0 deletions R/boost_tree_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ set_fit(
)
)

# xgboost (regression): expand factor predictors into indicator variables
# when a formula interface is used with this x/y-based engine.
set_encoding(
  model = "boost_tree",
  eng = "xgboost",
  mode = "regression",
  options = list(predictor_indicators = TRUE)
)

set_pred(
model = "boost_tree",
eng = "xgboost",
Expand Down Expand Up @@ -125,6 +132,13 @@ set_fit(
)
)

# xgboost (classification): expand factor predictors into indicator
# variables, matching the regression-mode registration above.
set_encoding(
  model = "boost_tree",
  eng = "xgboost",
  mode = "classification",
  options = list(predictor_indicators = TRUE)
)

set_pred(
model = "boost_tree",
eng = "xgboost",
Expand Down Expand Up @@ -221,6 +235,13 @@ set_fit(
)
)

# C5.0 (classification): keep factor predictors as-is; no indicator
# variables are created during pre-processing.
set_encoding(
  model = "boost_tree",
  eng = "C5.0",
  mode = "classification",
  options = list(predictor_indicators = FALSE)
)

set_pred(
model = "boost_tree",
eng = "C5.0",
Expand Down Expand Up @@ -344,6 +365,13 @@ set_fit(
)
)

# spark (regression): indicator variables are made from factor predictors
# (per this PR, Spark always creates indicator variables).
set_encoding(
  model = "boost_tree",
  eng = "spark",
  mode = "regression",
  options = list(predictor_indicators = TRUE)
)

set_fit(
model = "boost_tree",
eng = "spark",
Expand All @@ -357,6 +385,13 @@ set_fit(
)
)

# spark (classification): indicator variables are made from factor
# predictors, same as the regression-mode registration.
set_encoding(
  model = "boost_tree",
  eng = "spark",
  mode = "classification",
  options = list(predictor_indicators = TRUE)
)

set_pred(
model = "boost_tree",
eng = "spark",
Expand Down
2 changes: 1 addition & 1 deletion R/convert_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#' @importFrom stats .checkMFClasses .getXlevels delete.response
#' @importFrom stats model.offset model.weights na.omit na.pass

convert_form_to_xy_fit <-function(
convert_form_to_xy_fit <- function(
formula,
data,
...,
Expand Down
39 changes: 37 additions & 2 deletions R/decision_tree_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ set_fit(
)
)

# rpart (regression): factor predictors are passed through unchanged;
# no indicator variables are created.
set_encoding(
  model = "decision_tree",
  eng = "rpart",
  mode = "regression",
  options = list(predictor_indicators = FALSE)
)

set_fit(
model = "decision_tree",
eng = "rpart",
Expand All @@ -60,6 +67,13 @@ set_fit(
)
)

# rpart (classification): factor predictors are passed through unchanged,
# matching the regression-mode registration above.
set_encoding(
  model = "decision_tree",
  eng = "rpart",
  mode = "classification",
  options = list(predictor_indicators = FALSE)
)

set_pred(
model = "decision_tree",
eng = "rpart",
Expand Down Expand Up @@ -158,6 +172,13 @@ set_fit(
)
)

# C5.0 (classification): handles factor predictors natively, so no
# indicator variables are created.
set_encoding(
  model = "decision_tree",
  eng = "C5.0",
  mode = "classification",
  options = list(predictor_indicators = FALSE)
)

set_pred(
model = "decision_tree",
eng = "C5.0",
Expand Down Expand Up @@ -211,7 +232,7 @@ set_pred(

set_model_engine("decision_tree", "classification", "spark")
set_model_engine("decision_tree", "regression", "spark")
set_dependency("decision_tree", "spark", "spark")
set_dependency("decision_tree", "spark", "sparklyr")

set_model_arg(
model = "decision_tree",
Expand Down Expand Up @@ -239,12 +260,19 @@ set_fit(
interface = "formula",
data = c(formula = "formula", data = "x"),
protect = c("x", "formula"),
func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"),
func = c(pkg = "sparklyr", fun = "ml_decision_tree_regressor"),
defaults =
list(seed = expr(sample.int(10 ^ 5, 1)))
)
)

# spark (regression): indicator variables are made from factor predictors
# (per this PR, Spark always creates indicator variables).
set_encoding(
  model = "decision_tree",
  eng = "spark",
  mode = "regression",
  options = list(predictor_indicators = TRUE)
)

set_fit(
model = "decision_tree",
eng = "spark",
Expand All @@ -259,6 +287,13 @@ set_fit(
)
)

# spark (classification): indicator variables are made from factor
# predictors, same as the regression-mode registration.
set_encoding(
  model = "decision_tree",
  eng = "spark",
  mode = "classification",
  options = list(predictor_indicators = TRUE)
)

set_pred(
model = "decision_tree",
eng = "spark",
Expand Down
2 changes: 1 addition & 1 deletion R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ fit.model_spec <-
eng_vals <- possible_engines(object)
object$engine <- eng_vals[1]
if (control$verbosity > 0) {
rlang::warn("Engine set to `{object$engine}`.")
rlang::warn(glue::glue("Engine set to `{object$engine}`."))
}
}

Expand Down
9 changes: 7 additions & 2 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,17 @@ xy_xy <- function(object, env, control, target = "none", ...) {
form_xy <- function(object, control, env,
target = "none", ...) {

indicators <- get_encoding(class(object)[1]) %>%
dplyr::filter(mode == object$mode,
engine == object$engine) %>%
dplyr::pull(predictor_indicators)

data_obj <- convert_form_to_xy_fit(
formula = env$formula,
data = env$data,
...,
composition = target
# indicators
composition = target,
indicators = indicators
)
env$x <- data_obj$x
env$y <- data_obj$y
Expand Down
Loading