Skip to content

Commit aab7e0c

Browse files
juliasilgetopepo
andauthored
Add argument for one hot encoding to parsnip (#332)
* Add one hot option to encoding options * one_hot = FALSE for almost all models, one_hot = TRUE for glmnet models * changed one_hot to logical; less confusing * revert glmnet encodings to one_hot * Switch from logical to none/traditional/one_hot * Update predictor_indicators in model infrastructure * change objective function name for xgboost regression * more encoding updates related to intercepts * set defaults for parsnip objects with no encoding information * "one-hot" not "one_hot" * apply encoding changes to form_xy and xy_form paths * fully export contrast function * "one_hot" not "one-hot" * fixed a few bugs * revert xgboost change (in another PR) * updated news * two more global variable false positives * updates for how many engines handle dummy variables (if at all) * details on encoding options * one_hot documentation * Update R/aaa_models.R Co-authored-by: Julia Silge <[email protected]> * Update R/aaa_models.R Co-authored-by: Julia Silge <[email protected]> * Update R/aaa_models.R Co-authored-by: Julia Silge <[email protected]> * Update R/aaa_models.R Co-authored-by: Julia Silge <[email protected]> * Update R/aaa_models.R Co-authored-by: Julia Silge <[email protected]> * Update R/contr_one_hot.R Co-authored-by: Julia Silge <[email protected]> * Update man/rmd/one-hot.Rmd Co-authored-by: Julia Silge <[email protected]> * Update man/rmd/one-hot.Rmd Co-authored-by: Julia Silge <[email protected]> * documentation updates for one-hot * Update man/rmd/one-hot.Rmd Co-authored-by: Julia Silge <[email protected]> * Update man/rmd/one-hot.Rmd Co-authored-by: Julia Silge <[email protected]> Co-authored-by: Max Kuhn <[email protected]>
1 parent 259749f commit aab7e0c

39 files changed

+889
-211
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ export(add_rowindex)
101101
export(boost_tree)
102102
export(check_empty_ellipse)
103103
export(check_final_param)
104+
export(contr_one_hot)
104105
export(control_parsnip)
105106
export(convert_stan_interval)
106107
export(decision_tree)

NEWS.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
# parsnip (development version)
22

3+
## Breaking Changes
4+
5+
* `parsnip` now has options to set specific types of predictor encodings for different models. For example, `ranger` models run using `parsnip` and `workflows` do the same thing by _not_ creating indicator variables. These encodings can be overridden using the `blueprint` options in `workflows`. As a consequence, it is possible to get a different model fit that previous versions of `parsnip`. More details about specific encoding changes are below. (#326)
6+
37
## Other Changes
48

59
* `tidyr` >= 1.0.0 is now required.
610

7-
* SVM models produced by `kernlab` now use the formula method. This change was due to how `ksvm()` made indicator variables for factor predictors (with one-hot encodings). Since the ordinary formula method did not do this, the data are passed as-is to `ksvm()` so that the results are closer to what one would get if `ksmv()` were called directly.
11+
* SVM models produced by `kernlab` now use the formula method (see breaking change notice above). This change was due to how `ksvm()` made indicator variables for factor predictors (with one-hot encodings). Since the ordinary formula method did not do this, the data are passed as-is to `ksvm()` so that the results are closer to what one would get if `ksmv()` were called directly.
812

913
* MARS models produced by `earth` now use the formula method.
1014

11-
* Under-the-hood changes were made so that non-standard data arguments in the modeling packages can be accomodated. (#315)
15+
* For `xgboost`, a one-hot encoding is used when indicator variables are created.
16+
17+
* Under-the-hood changes were made so that non-standard data arguments in the modeling packages can be accommodated. (#315)
1218

1319
## New Features
1420

R/aaa.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ utils::globalVariables(
3939
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
4040
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
4141
"max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees",
42-
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators")
42+
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators",
43+
"compute_intercept", "remove_intercept")
4344
)
4445

4546
# nocov end

R/aaa_models.R

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,8 @@ check_interface_val <- function(x) {
323323
#' below, depending on context.
324324
#' @param pre,post Optional functions for pre- and post-processing of prediction
325325
#' results.
326-
#' @param options A list of options for engine-specific encodings. Currently,
327-
#' the option implemented is `predictor_indicators` which tells `parsnip`
328-
#' whether the pre-processing should make indicator/dummy variables from factor
329-
#' predictors. This only affects cases when [fit.model_spec()] is used and the
330-
#' underlying model has an x/y interface.
326+
#' @param options A list of options for engine-specific preprocessing encodings.
327+
#' See Details below.
331328
#' @param ... Optional arguments that should be passed into the `args` slot for
332329
#' prediction objects.
333330
#' @keywords internal
@@ -347,6 +344,36 @@ check_interface_val <- function(x) {
347344
#' already been registered. `check_model_doesnt_exist()` checks the model value
348345
#' and also checks to see if it is novel in the environment.
349346
#'
347+
#' The options for engine-specific encodings dictate how the predictors should be
348+
#' handled. These options ensure that the data
349+
#' that `parsnip` gives to the underlying model allows for a model fit that is
350+
#' as similar as possible to what it would have produced directly.
351+
#'
352+
#' For example, if `fit()` is used to fit a model that does not have
353+
#' a formula interface, typically some predictor preprocessing must
354+
#' be conducted. `glmnet` is a good example of this.
355+
#'
356+
#' There are three options that can be used for the encodings:
357+
#'
358+
#' `predictor_indicators` describes whether and how to create indicator/dummy
359+
#' variables from factor predictors. There are three options: `"none"` (do not
360+
#' expand factor predictors), `"traditional"` (apply the standard
361+
#' `model.matrix()` encodings), and `"one_hot"` (create the complete set
362+
#' including the baseline level for all factors). This encoding only affects
363+
#' cases when [fit.model_spec()] is used and the underlying model has an x/y
364+
#' interface.
365+
#'
366+
#' Another option is `compute_intercept`; this controls whether `model.matrix()`
367+
#' should include the intercept in its formula. This affects more than the
368+
#' inclusion of an intercept column. With an intercept, `model.matrix()`
369+
#' computes dummy variables for all but one factor levels. Without an
370+
#' intercept, `model.matrix()` computes a full set of indicators for the
371+
#' _first_ factor variable, but an incomplete set for the remainder.
372+
#'
373+
#' Finally, the option `remove_intercept` will remove the intercept column
374+
#' _after_ `model.matrix()` is finished. This can be useful if the model
375+
#' function (e.g. `lm()`) automatically generates an intercept.
376+
#'
350377
#' @references "Making a parsnip model from scratch"
351378
#' \url{https://tidymodels.github.io/parsnip/articles/articles/Scratch.html}
352379
#' @examples
@@ -791,7 +818,9 @@ check_encodings <- function(x) {
791818
if (!is.list(x)) {
792819
rlang::abort("`values` should be a list.")
793820
}
794-
req_args <- list(predictor_indicators = TRUE)
821+
req_args <- list(predictor_indicators = rlang::na_chr,
822+
compute_intercept = rlang::na_lgl,
823+
remove_intercept = rlang::na_lgl)
795824

796825
missing_args <- setdiff(names(req_args), names(x))
797826
if (length(missing_args) > 0) {
@@ -834,9 +863,12 @@ set_encoding <- function(model, mode, eng, options) {
834863
current <- get_from_env(nm)
835864
dup_check <-
836865
current %>%
837-
dplyr::inner_join(new_values, by = c("model", "engine", "mode", "predictor_indicators"))
866+
dplyr::inner_join(
867+
new_values,
868+
by = c("model", "engine", "mode", "predictor_indicators")
869+
)
838870
if (nrow(dup_check)) {
839-
rlang::abort(glue::glue("Engine '{eng}' and mode '{mode}' already have defined encodings."))
871+
rlang::abort(glue::glue("Engine '{eng}' and mode '{mode}' already have defined encodings for model '{model}'."))
840872
}
841873

842874
} else {
@@ -856,6 +888,19 @@ set_encoding <- function(model, mode, eng, options) {
856888
get_encoding <- function(model) {
857889
check_model_exists(model)
858890
nm <- paste0(model, "_encoding")
859-
rlang::env_get(get_model_env(), nm)
891+
res <- try(get_from_env(nm), silent = TRUE)
892+
if (inherits(res, "try-error")) {
893+
# for objects made before encodings were specified in parsnip
894+
res <-
895+
get_from_env(model) %>%
896+
dplyr::mutate(
897+
model = model,
898+
predictor_indicators = "traditional",
899+
compute_intercept = TRUE,
900+
remove_intercept = TRUE
901+
) %>%
902+
dplyr::select(model, engine, mode, predictor_indicators,
903+
compute_intercept, remove_intercept)
904+
}
905+
res
860906
}
861-

R/boost_tree_data.R

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,11 @@ set_encoding(
9191
model = "boost_tree",
9292
eng = "xgboost",
9393
mode = "regression",
94-
options = list(predictor_indicators = TRUE)
94+
options = list(
95+
predictor_indicators = "one_hot",
96+
compute_intercept = FALSE,
97+
remove_intercept = TRUE
98+
)
9599
)
96100

97101
set_pred(
@@ -136,7 +140,11 @@ set_encoding(
136140
model = "boost_tree",
137141
eng = "xgboost",
138142
mode = "classification",
139-
options = list(predictor_indicators = TRUE)
143+
options = list(
144+
predictor_indicators = "one_hot",
145+
compute_intercept = FALSE,
146+
remove_intercept = TRUE
147+
)
140148
)
141149

142150
set_pred(
@@ -239,7 +247,11 @@ set_encoding(
239247
model = "boost_tree",
240248
eng = "C5.0",
241249
mode = "classification",
242-
options = list(predictor_indicators = FALSE)
250+
options = list(
251+
predictor_indicators = "none",
252+
compute_intercept = FALSE,
253+
remove_intercept = FALSE
254+
)
243255
)
244256

245257
set_pred(
@@ -369,7 +381,11 @@ set_encoding(
369381
model = "boost_tree",
370382
eng = "spark",
371383
mode = "regression",
372-
options = list(predictor_indicators = TRUE)
384+
options = list(
385+
predictor_indicators = "none",
386+
compute_intercept = FALSE,
387+
remove_intercept = FALSE
388+
)
373389
)
374390

375391
set_fit(
@@ -389,7 +405,11 @@ set_encoding(
389405
model = "boost_tree",
390406
eng = "spark",
391407
mode = "classification",
392-
options = list(predictor_indicators = TRUE)
408+
options = list(
409+
predictor_indicators = "none",
410+
compute_intercept = FALSE,
411+
remove_intercept = FALSE
412+
)
393413
)
394414

395415
set_pred(

R/contr_one_hot.R

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#' Contrast function for one-hot encodings
2+
#'
3+
#' This contrast function produces a model matrix with indicator columns for
4+
#' each level of each factor.
5+
#'
6+
#' @param n A vector of character factor levels or the number of unique levels.
7+
#' @param contrasts This argument is for backwards compatibility and only the
8+
#' default of `TRUE` is supported.
9+
#' @param sparse This argument is for backwards compatibility and only the
10+
#' default of `FALSE` is supported.
11+
#'
12+
#' @includeRmd man/rmd/one-hot.Rmd details
13+
#'
14+
#' @return A diagonal matrix that is `n`-by-`n`.
15+
#'
16+
#' @export
17+
contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {
18+
if (sparse) {
19+
rlang::warn("`sparse = TRUE` not implemented for `contr_one_hot()`.")
20+
}
21+
22+
if (!contrasts) {
23+
rlang::warn("`contrasts = FALSE` not implemented for `contr_one_hot()`.")
24+
}
25+
26+
if (is.character(n)) {
27+
names <- n
28+
n <- length(names)
29+
} else if (is.numeric(n)) {
30+
n <- as.integer(n)
31+
32+
if (length(n) != 1L) {
33+
rlang::abort("`n` must have length 1 when an integer is provided.")
34+
}
35+
36+
names <- as.character(seq_len(n))
37+
} else {
38+
rlang::abort("`n` must be a character vector or an integer of size 1.")
39+
}
40+
41+
out <- diag(n)
42+
43+
rownames(out) <- names
44+
colnames(out) <- names
45+
46+
out
47+
}

R/convert_data.R

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ convert_form_to_xy_fit <- function(
2020
data,
2121
...,
2222
na.action = na.omit,
23-
indicators = TRUE,
24-
composition = "data.frame"
23+
indicators = "traditional",
24+
composition = "data.frame",
25+
remove_intercept = TRUE
2526
) {
2627
if (!(composition %in% c("data.frame", "matrix")))
2728
rlang::abort("`composition` should be either 'data.frame' or 'matrix'.")
@@ -72,8 +73,16 @@ convert_form_to_xy_fit <- function(
7273
)
7374
}
7475

75-
if (indicators) {
76+
if (indicators != "none") {
77+
if (indicators == "one_hot") {
78+
old_contr <- options("contrasts")$contrasts
79+
on.exit(options(contrasts = old_contr))
80+
new_contr <- old_contr
81+
new_contr["unordered"] <- "contr_one_hot"
82+
options(contrasts = new_contr)
83+
}
7684
x <- model.matrix(mod_terms, mod_frame, contrasts)
85+
7786
} else {
7887
# this still ignores -vars in formula
7988
x <- model.frame(mod_terms, data)
@@ -82,14 +91,15 @@ convert_form_to_xy_fit <- function(
8291
x <- x[,-y_cols, drop = FALSE]
8392
}
8493

85-
## TODO maybe an option not to do this?
86-
x <- x[, colnames(x) != "(Intercept)", drop = FALSE]
87-
94+
if (remove_intercept) {
95+
x <- x[, colnames(x) != "(Intercept)", drop = FALSE]
96+
}
8897
options <-
8998
list(
9099
indicators = indicators,
91100
composition = composition,
92-
contrasts = contrasts
101+
contrasts = contrasts,
102+
remove_intercept = remove_intercept
93103
)
94104

95105
if (composition == "data.frame") {
@@ -165,12 +175,21 @@ convert_form_to_xy_new <- function(object, new_data, na.action = na.pass,
165175
if (!is.null(cl))
166176
.checkMFClasses(cl, new_data)
167177

168-
if(object$options$indicators) {
178+
if(object$options$indicators != "none") {
179+
if (object$options$indicators == "one_hot") {
180+
old_contr <- options("contrasts")$contrasts
181+
on.exit(options(contrasts = old_contr))
182+
new_contr <- old_contr
183+
new_contr["unordered"] <- "contr_one_hot"
184+
options(contrasts = new_contr)
185+
}
169186
new_data <-
170187
model.matrix(mod_terms, new_data, contrasts.arg = object$contrasts)
171188
}
172189

173-
new_data <- new_data[, colnames(new_data) != "(Intercept)", drop = FALSE]
190+
if(object$options$remove_intercept) {
191+
new_data <- new_data[, colnames(new_data) != "(Intercept)", drop = FALSE]
192+
}
174193

175194
if (composition == "data.frame")
176195
new_data <- as.data.frame(new_data)
@@ -188,10 +207,15 @@ convert_form_to_xy_new <- function(object, new_data, na.action = na.pass,
188207

189208
#' @importFrom dplyr bind_cols
190209
# TODO slots for other roles
191-
convert_xy_to_form_fit <- function(x, y, weights = NULL, y_name = "..y") {
210+
convert_xy_to_form_fit <- function(x, y, weights = NULL, y_name = "..y",
211+
remove_intercept = TRUE) {
192212
if (is.vector(x))
193213
rlang::abort("`x` cannot be a vector.")
194214

215+
if(remove_intercept) {
216+
x <- x[, colnames(x) != "(Intercept)", drop = FALSE]
217+
}
218+
195219
rn <- rownames(x)
196220

197221
if (!is.data.frame(x))

0 commit comments

Comments
 (0)