Add argument for one hot encoding to parsnip #332

Merged · 31 commits · Jul 2, 2020
Commits (31)
8a3b4b7
Add one hot option to encoding options
juliasilge Jun 18, 2020
3a3743e
one_hot = FALSE for almost all models, one_hot = TRUE for glmnet models
juliasilge Jun 18, 2020
1f4c40b
changed one_hot to logical; less confusing
topepo Jun 24, 2020
4c6641c
revert glmnet encodings to one_hot
topepo Jun 24, 2020
e169f8f
Switch from logical to none/traditional/one_hot
juliasilge Jun 26, 2020
51c18ad
Update predictor_indicators in model infrastructure
juliasilge Jun 26, 2020
3420156
change objective function name for xgboost regression
topepo Jun 30, 2020
2430518
more encoding updates related to intercepts
topepo Jun 30, 2020
00e3180
set defaults for parsnip objects with no encoding information
topepo Jun 30, 2020
91cc98d
"one-hot" not "one_hot"
topepo Jun 30, 2020
cb68875
apply encoding changes to form_xy and xy_form paths
topepo Jun 30, 2020
3ebd066
fully export contrast function
topepo Jun 30, 2020
c30a50a
"one_hot" not "one-hot"
topepo Jun 30, 2020
9c7df98
fixed a few bugs
topepo Jul 1, 2020
164c4d3
revert xgboost change (in another PR)
topepo Jul 1, 2020
ac2aa17
updated news
topepo Jul 1, 2020
856c829
two more global variable false positives
topepo Jul 1, 2020
8503ae1
updates for how many engines handle dummy variables (if at all)
topepo Jul 1, 2020
c76ec17
details on encoding options
topepo Jul 1, 2020
d7eee45
one_hot documentation
topepo Jul 1, 2020
a2308d9
Update R/aaa_models.R
topepo Jul 1, 2020
7318d7f
Update R/aaa_models.R
topepo Jul 1, 2020
334f01c
Update R/aaa_models.R
topepo Jul 1, 2020
110ca67
Update R/aaa_models.R
topepo Jul 1, 2020
d70e414
Update R/aaa_models.R
topepo Jul 1, 2020
ea3ec8c
Update R/contr_one_hot.R
topepo Jul 1, 2020
fc4f165
Update man/rmd/one-hot.Rmd
topepo Jul 1, 2020
9a11306
Update man/rmd/one-hot.Rmd
topepo Jul 1, 2020
ccd52bb
documentation updates for one-hot
topepo Jul 1, 2020
d04b892
Update man/rmd/one-hot.Rmd
topepo Jul 2, 2020
f164247
Update man/rmd/one-hot.Rmd
topepo Jul 2, 2020
1 change: 1 addition & 0 deletions NAMESPACE
Expand Up @@ -101,6 +101,7 @@ export(add_rowindex)
export(boost_tree)
export(check_empty_ellipse)
export(check_final_param)
export(contr_one_hot)
export(control_parsnip)
export(convert_stan_interval)
export(decision_tree)
Expand Down
10 changes: 8 additions & 2 deletions NEWS.md
@@ -1,14 +1,20 @@
# parsnip (development version)

## Breaking Changes

* `parsnip` now has options to set specific types of predictor encodings for different models. For example, `ranger` models run using `parsnip` and `workflows` do the same thing by _not_ creating indicator variables. These encodings can be overridden using the `blueprint` options in `workflows`. As a consequence, it is possible to get a different model fit than previous versions of `parsnip`. More details about specific encoding changes are below. (#326)

## Other Changes

* `tidyr` >= 1.0.0 is now required.

* SVM models produced by `kernlab` now use the formula method. This change was due to how `ksvm()` made indicator variables for factor predictors (with one-hot encodings). Since the ordinary formula method did not do this, the data are passed as-is to `ksvm()` so that the results are closer to what one would get if `ksvm()` were called directly.
* SVM models produced by `kernlab` now use the formula method (see breaking change notice above). This change was due to how `ksvm()` made indicator variables for factor predictors (with one-hot encodings). Since the ordinary formula method did not do this, the data are passed as-is to `ksvm()` so that the results are closer to what one would get if `ksvm()` were called directly.

* MARS models produced by `earth` now use the formula method.

* Under-the-hood changes were made so that non-standard data arguments in the modeling packages can be accomodated. (#315)
* For `xgboost`, a one-hot encoding is used when indicator variables are created.

* Under-the-hood changes were made so that non-standard data arguments in the modeling packages can be accommodated. (#315)

## New Features

Expand Down
3 changes: 2 additions & 1 deletion R/aaa.R
Expand Up @@ -39,7 +39,8 @@ utils::globalVariables(
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
"max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees",
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators")
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators",
"compute_intercept", "remove_intercept")
)

# nocov end
65 changes: 55 additions & 10 deletions R/aaa_models.R
Expand Up @@ -323,11 +323,8 @@ check_interface_val <- function(x) {
#' below, depending on context.
#' @param pre,post Optional functions for pre- and post-processing of prediction
#' results.
#' @param options A list of options for engine-specific encodings. Currently,
#' the option implemented is `predictor_indicators` which tells `parsnip`
#' whether the pre-processing should make indicator/dummy variables from factor
#' predictors. This only affects cases when [fit.model_spec()] is used and the
#' underlying model has an x/y interface.
#' @param options A list of options for engine-specific preprocessing encodings.
#' See Details below.
#' @param ... Optional arguments that should be passed into the `args` slot for
#' prediction objects.
#' @keywords internal
Expand All @@ -347,6 +344,36 @@ check_interface_val <- function(x) {
#' already been registered. `check_model_doesnt_exist()` checks the model value
#' and also checks to see if it is novel in the environment.
#'
#' The options for engine-specific encodings dictate how the predictors should be
#' handled. These options ensure that the data that `parsnip` gives to the
#' underlying model allow for a model fit that is as similar as possible to
#' what would have been produced had the engine been called directly.
#'
#' For example, if `fit()` is used to fit a model that does not have
#' a formula interface, typically some predictor preprocessing must
#' be conducted. `glmnet` is a good example of this.
#'
#' There are three options that can be used for the encodings:
#'
#' `predictor_indicators` describes whether and how to create indicator/dummy
#' variables from factor predictors. There are three options: `"none"` (do not
#' expand factor predictors), `"traditional"` (apply the standard
#' `model.matrix()` encodings), and `"one_hot"` (create the complete set
#' including the baseline level for all factors). This encoding only affects
#' cases when [fit.model_spec()] is used and the underlying model has an x/y
#' interface.
#'
#' Another option is `compute_intercept`; this controls whether `model.matrix()`
#' should include the intercept in its formula. This affects more than the
#' inclusion of an intercept column. With an intercept, `model.matrix()`
#' computes dummy variables for all but one factor level. Without an
#' intercept, `model.matrix()` computes a full set of indicators for the
#' _first_ factor variable, but an incomplete set for the remainder.
#'
#' Finally, the option `remove_intercept` will remove the intercept column
#' _after_ `model.matrix()` is finished. This can be useful if the model
#' function (e.g. `lm()`) automatically generates an intercept.
#'
#' @references "Making a parsnip model from scratch"
#' \url{https://tidymodels.github.io/parsnip/articles/articles/Scratch.html}
#' @examples
Expand Down Expand Up @@ -791,7 +818,9 @@ check_encodings <- function(x) {
if (!is.list(x)) {
rlang::abort("`values` should be a list.")
}
req_args <- list(predictor_indicators = TRUE)
req_args <- list(predictor_indicators = rlang::na_chr,
compute_intercept = rlang::na_lgl,
remove_intercept = rlang::na_lgl)

missing_args <- setdiff(names(req_args), names(x))
if (length(missing_args) > 0) {
Expand Down Expand Up @@ -834,9 +863,12 @@ set_encoding <- function(model, mode, eng, options) {
current <- get_from_env(nm)
dup_check <-
current %>%
dplyr::inner_join(new_values, by = c("model", "engine", "mode", "predictor_indicators"))
dplyr::inner_join(
new_values,
by = c("model", "engine", "mode", "predictor_indicators")
)
if (nrow(dup_check)) {
rlang::abort(glue::glue("Engine '{eng}' and mode '{mode}' already have defined encodings."))
rlang::abort(glue::glue("Engine '{eng}' and mode '{mode}' already have defined encodings for model '{model}'."))
}

} else {
Expand All @@ -856,6 +888,19 @@ set_encoding <- function(model, mode, eng, options) {
get_encoding <- function(model) {
check_model_exists(model)
nm <- paste0(model, "_encoding")
rlang::env_get(get_model_env(), nm)
res <- try(get_from_env(nm), silent = TRUE)
if (inherits(res, "try-error")) {
# for objects made before encodings were specified in parsnip
res <-
get_from_env(model) %>%
dplyr::mutate(
model = model,
predictor_indicators = "traditional",
compute_intercept = TRUE,
remove_intercept = TRUE
) %>%
dplyr::select(model, engine, mode, predictor_indicators,
compute_intercept, remove_intercept)
}
res
}

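The `compute_intercept` behavior documented above is plain base-R `model.matrix()` behavior and can be checked without parsnip. A minimal sketch (the data frame and factor names below are illustrative only, not from the PR):

```r
# With an intercept, model.matrix() encodes k - 1 dummy columns per factor;
# without one, only the *first* factor gets a full set of k indicator columns
# while later factors still get k - 1.
df <- data.frame(
  f1 = factor(c("a", "b", "c")),
  f2 = factor(c("x", "y", "x"))
)

with_intercept    <- model.matrix(~ f1 + f2, df)
without_intercept <- model.matrix(~ f1 + f2 - 1, df)

colnames(with_intercept)     # "(Intercept)" "f1b" "f1c" "f2y"
colnames(without_intercept)  # "f1a" "f1b" "f1c" "f2y"
```

This is why `remove_intercept` is a separate option: dropping the `"(Intercept)"` column after the fact is not the same as fitting `model.matrix()` without one.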
30 changes: 25 additions & 5 deletions R/boost_tree_data.R
Expand Up @@ -91,7 +91,11 @@ set_encoding(
model = "boost_tree",
eng = "xgboost",
mode = "regression",
options = list(predictor_indicators = TRUE)
options = list(
predictor_indicators = "one_hot",
compute_intercept = FALSE,
remove_intercept = TRUE
)
)

set_pred(
Expand Down Expand Up @@ -136,7 +140,11 @@ set_encoding(
model = "boost_tree",
eng = "xgboost",
mode = "classification",
options = list(predictor_indicators = TRUE)
options = list(
predictor_indicators = "one_hot",
compute_intercept = FALSE,
remove_intercept = TRUE
)
)

set_pred(
Expand Down Expand Up @@ -239,7 +247,11 @@ set_encoding(
model = "boost_tree",
eng = "C5.0",
mode = "classification",
options = list(predictor_indicators = FALSE)
options = list(
predictor_indicators = "none",
compute_intercept = FALSE,
remove_intercept = FALSE
)
)

set_pred(
Expand Down Expand Up @@ -369,7 +381,11 @@ set_encoding(
model = "boost_tree",
eng = "spark",
mode = "regression",
options = list(predictor_indicators = TRUE)
options = list(
predictor_indicators = "none",
compute_intercept = FALSE,
remove_intercept = FALSE
)
)

set_fit(
Expand All @@ -389,7 +405,11 @@ set_encoding(
model = "boost_tree",
eng = "spark",
mode = "classification",
options = list(predictor_indicators = TRUE)
options = list(
predictor_indicators = "none",
compute_intercept = FALSE,
remove_intercept = FALSE
)
)

set_pred(
Expand Down
47 changes: 47 additions & 0 deletions R/contr_one_hot.R
@@ -0,0 +1,47 @@
#' Contrast function for one-hot encodings
#'
#' This contrast function produces a model matrix with indicator columns for
#' each level of each factor.
#'
#' @param n A vector of character factor levels or the number of unique levels.
#' @param contrasts This argument is for backwards compatibility and only the
#' default of `TRUE` is supported.
#' @param sparse This argument is for backwards compatibility and only the
#' default of `FALSE` is supported.
#'
#' @includeRmd man/rmd/one-hot.Rmd details
#'
#' @return A diagonal matrix that is `n`-by-`n`.
#'
#' @export
contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {
if (sparse) {
rlang::warn("`sparse = TRUE` not implemented for `contr_one_hot()`.")
}

if (!contrasts) {
rlang::warn("`contrasts = FALSE` not implemented for `contr_one_hot()`.")
}

if (is.character(n)) {
names <- n
n <- length(names)
} else if (is.numeric(n)) {
n <- as.integer(n)

if (length(n) != 1L) {
rlang::abort("`n` must have length 1 when an integer is provided.")
}

names <- as.character(seq_len(n))
} else {
rlang::abort("`n` must be a character vector or an integer of size 1.")
}

out <- diag(n)

rownames(out) <- names
colnames(out) <- names

out
}
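A quick sanity check of the new contrast function: `contr_one_hot()` is reproduced inline here, exactly as added in this diff, so the snippet runs without installing the development version of parsnip (the `rlang` calls are only reached on the backwards-compatibility branches):

```r
# contr_one_hot(), copied from R/contr_one_hot.R above:
contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {
  if (sparse) {
    rlang::warn("`sparse = TRUE` not implemented for `contr_one_hot()`.")
  }
  if (!contrasts) {
    rlang::warn("`contrasts = FALSE` not implemented for `contr_one_hot()`.")
  }
  if (is.character(n)) {
    names <- n
    n <- length(names)
  } else if (is.numeric(n)) {
    n <- as.integer(n)
    if (length(n) != 1L) {
      rlang::abort("`n` must have length 1 when an integer is provided.")
    }
    names <- as.character(seq_len(n))
  } else {
    rlang::abort("`n` must be a character vector or an integer of size 1.")
  }
  out <- diag(n)
  rownames(out) <- names
  colnames(out) <- names
  out
}

# An identity matrix, one indicator column per level -- no baseline is dropped:
m <- contr_one_hot(c("low", "med", "high"))
m
#       low med high
# low     1   0    0
# med     0   1    0
# high    0   0    1
```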
44 changes: 34 additions & 10 deletions R/convert_data.R
Expand Up @@ -20,8 +20,9 @@ convert_form_to_xy_fit <- function(
data,
...,
na.action = na.omit,
indicators = TRUE,
composition = "data.frame"
indicators = "traditional",
composition = "data.frame",
remove_intercept = TRUE
) {
if (!(composition %in% c("data.frame", "matrix")))
rlang::abort("`composition` should be either 'data.frame' or 'matrix'.")
Expand Down Expand Up @@ -72,8 +73,16 @@ convert_form_to_xy_fit <- function(
)
}

if (indicators) {
if (indicators != "none") {
if (indicators == "one_hot") {
old_contr <- options("contrasts")$contrasts
on.exit(options(contrasts = old_contr))
new_contr <- old_contr
new_contr["unordered"] <- "contr_one_hot"
options(contrasts = new_contr)
}
x <- model.matrix(mod_terms, mod_frame, contrasts)

} else {
# this still ignores -vars in formula
x <- model.frame(mod_terms, data)
Expand All @@ -82,14 +91,15 @@ convert_form_to_xy_fit <- function(
x <- x[,-y_cols, drop = FALSE]
}

## TODO maybe an option not to do this?
x <- x[, colnames(x) != "(Intercept)", drop = FALSE]

if (remove_intercept) {
x <- x[, colnames(x) != "(Intercept)", drop = FALSE]
}
options <-
list(
indicators = indicators,
composition = composition,
contrasts = contrasts
contrasts = contrasts,
remove_intercept = remove_intercept
)

if (composition == "data.frame") {
Expand Down Expand Up @@ -165,12 +175,21 @@ convert_form_to_xy_new <- function(object, new_data, na.action = na.pass,
if (!is.null(cl))
.checkMFClasses(cl, new_data)

if(object$options$indicators) {
if(object$options$indicators != "none") {
if (object$options$indicators == "one_hot") {
old_contr <- options("contrasts")$contrasts
on.exit(options(contrasts = old_contr))
new_contr <- old_contr
new_contr["unordered"] <- "contr_one_hot"
options(contrasts = new_contr)
}
new_data <-
model.matrix(mod_terms, new_data, contrasts.arg = object$contrasts)
}

new_data <- new_data[, colnames(new_data) != "(Intercept)", drop = FALSE]
if(object$options$remove_intercept) {
new_data <- new_data[, colnames(new_data) != "(Intercept)", drop = FALSE]
}

if (composition == "data.frame")
new_data <- as.data.frame(new_data)
Expand All @@ -188,10 +207,15 @@ convert_form_to_xy_new <- function(object, new_data, na.action = na.pass,

#' @importFrom dplyr bind_cols
# TODO slots for other roles
convert_xy_to_form_fit <- function(x, y, weights = NULL, y_name = "..y") {
convert_xy_to_form_fit <- function(x, y, weights = NULL, y_name = "..y",
remove_intercept = TRUE) {
if (is.vector(x))
rlang::abort("`x` cannot be a vector.")

if(remove_intercept) {
x <- x[, colnames(x) != "(Intercept)", drop = FALSE]
}

rn <- rownames(x)

if (!is.data.frame(x))
Expand Down
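The `options(contrasts = ...)` swap that `convert_form_to_xy_fit()` performs above can be exercised directly. This is a sketch, not parsnip's code path: it inlines a minimal stand-in for `contr_one_hot()` and restores the option explicitly rather than via `on.exit()`, which only works inside a function:

```r
# Minimal stand-in for parsnip's contr_one_hot(), so the snippet is standalone:
contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {
  if (is.character(n)) {
    nms <- n
    n <- length(nms)
  } else {
    n <- as.integer(n)
    nms <- as.character(seq_len(n))
  }
  out <- diag(n)
  dimnames(out) <- list(nms, nms)
  out
}

df <- data.frame(y = 1:3, f = factor(c("a", "b", "c")))

# Temporarily point the "unordered" contrast at contr_one_hot(), mirroring the
# swap in convert_form_to_xy_fit():
old_contr <- options("contrasts")$contrasts
new_contr <- old_contr
new_contr["unordered"] <- "contr_one_hot"
options(contrasts = new_contr)

x <- model.matrix(y ~ f, df)

options(contrasts = old_contr)  # restore, as on.exit() does in the PR

colnames(x)  # "(Intercept)" "fa" "fb" "fc" -- a full one-hot set of indicators
```

With `remove_intercept = TRUE`, the fit path would then drop the `"(Intercept)"` column before handing `x` to the engine.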