Skip to content

Commit 592ab0e

Browse files
authored
Merge branch 'master' into master
2 parents 1393f96 + aab7e0c commit 592ab0e

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+930
-240
lines changed

.github/workflows/R-CMD-check.yaml

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@ jobs:
1515
fail-fast: false
1616
matrix:
1717
config:
18-
- { os: windows-latest, r: '3.6'}
19-
- { os: windows-latest, r: '4.0'}
20-
- { os: windows-latest, r: 'devel'}
21-
- { os: ubuntu-16.04, r: '3.5', cran: "https://demo.rstudiopm.com/all/__linux__/xenial/latest"}
22-
- { os: ubuntu-16.04, r: '3.6', cran: "https://demo.rstudiopm.com/all/__linux__/xenial/latest"}
18+
- {os: macOS-latest, r: 'devel'}
19+
- {os: macOS-latest, r: 'release'}
20+
- {os: windows-latest, r: 'release'}
21+
- {os: windows-latest, r: '3.6'}
22+
- {os: ubuntu-16.04, r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/xenial/latest"}
23+
- {os: ubuntu-16.04, r: 'oldrel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/xenial/latest"}
24+
- {os: ubuntu-16.04, r: '3.5', rspm: "https://packagemanager.rstudio.com/cran/__linux__/xenial/latest"}
2325

2426
env:
2527
R_REMOTES_NO_ERRORS_FROM_WARNINGS: true
@@ -65,11 +67,20 @@ jobs:
6567
remotes::install_cran("rcmdcheck")
6668
shell: Rscript {0}
6769

68-
- name: Install TensorFlow
70+
- name: Install Miniconda
6971
run: |
72+
Rscript -e "remotes::install_github('rstudio/reticulate')"
7073
Rscript -e "reticulate::install_miniconda()"
71-
Rscript -e "reticulate::conda_create('r-reticulate', packages = 'python==3.6.9')"
72-
Rscript -e "tensorflow::install_tensorflow(version='1.14.0')"
74+
75+
- name: Find Miniconda on macOS
76+
if: runner.os == 'macOS'
77+
run: echo "options(reticulate.conda_binary = reticulate:::miniconda_conda())" >> .Rprofile
78+
79+
- name: Install TensorFlow
80+
run: |
81+
reticulate::conda_create('r-reticulate', packages = c('python==3.6.9'))
82+
tensorflow::install_tensorflow(version='1.14.0')
83+
shell: Rscript {0}
7384

7485
- name: Session info
7586
run: |

.github/workflows/test-coverage.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ jobs:
4343

4444
- name: Install TensorFlow
4545
run: |
46+
Rscript -e "remotes::install_github('rstudio/reticulate')"
4647
Rscript -e "reticulate::install_miniconda()"
48+
echo "options(reticulate.conda_binary = reticulate:::miniconda_conda())" >> .Rprofile
4749
Rscript -e "reticulate::conda_create('r-reticulate', packages = 'python==3.6.9')"
4850
Rscript -e "tensorflow::install_tensorflow(version='1.14.0')"
4951

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ export(add_rowindex)
101101
export(boost_tree)
102102
export(check_empty_ellipse)
103103
export(check_final_param)
104+
export(contr_one_hot)
104105
export(control_parsnip)
105106
export(convert_stan_interval)
106107
export(decision_tree)

NEWS.md

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
# parsnip (development version)
22

3+
## Breaking Changes
4+
5+
* `parsnip` now has options to set specific types of predictor encodings for different models. For example, `ranger` models run using `parsnip` and `workflows` do the same thing by _not_ creating indicator variables. These encodings can be overridden using the `blueprint` options in `workflows`. As a consequence, it is possible to get a different model fit that previous versions of `parsnip`. More details about specific encoding changes are below. (#326)
6+
37
## Other Changes
48

59
* `tidyr` >= 1.0.0 is now required.
610

7-
* SVM models produced by `kernlab` now use the formula method. This change was due to how `ksvm()` made indicator variables for factor predictors (with one-hot encodings). Since the ordinary formula method did not do this, the data are passed as-is to `ksvm()` so that the results are closer to what one would get if `ksmv()` were called directly.
11+
* SVM models produced by `kernlab` now use the formula method (see breaking change notice above). This change was due to how `ksvm()` made indicator variables for factor predictors (with one-hot encodings). Since the ordinary formula method did not do this, the data are passed as-is to `ksvm()` so that the results are closer to what one would get if `ksmv()` were called directly.
812

913
* MARS models produced by `earth` now use the formula method.
1014

11-
* Under-the-hood changes were made so that non-standard data arguments in the modeling packages can be accomodated. (#315)
15+
* For `xgboost`, a one-hot encoding is used when indicator variables are created.
16+
17+
* Under-the-hood changes were made so that non-standard data arguments in the modeling packages can be accommodated. (#315)
1218

1319
## New Features
1420

R/aaa.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ utils::globalVariables(
3939
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
4040
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
4141
"max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees",
42-
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators")
42+
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators",
43+
"compute_intercept", "remove_intercept")
4344
)
4445

4546
# nocov end

R/aaa_models.R

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -323,11 +323,8 @@ check_interface_val <- function(x) {
323323
#' below, depending on context.
324324
#' @param pre,post Optional functions for pre- and post-processing of prediction
325325
#' results.
326-
#' @param options A list of options for engine-specific encodings. Currently,
327-
#' the option implemented is `predictor_indicators` which tells `parsnip`
328-
#' whether the pre-processing should make indicator/dummy variables from factor
329-
#' predictors. This only affects cases when [fit.model_spec()] is used and the
330-
#' underlying model has an x/y interface.
326+
#' @param options A list of options for engine-specific preprocessing encodings.
327+
#' See Details below.
331328
#' @param ... Optional arguments that should be passed into the `args` slot for
332329
#' prediction objects.
333330
#' @keywords internal
@@ -347,6 +344,36 @@ check_interface_val <- function(x) {
347344
#' already been registered. `check_model_doesnt_exist()` checks the model value
348345
#' and also checks to see if it is novel in the environment.
349346
#'
347+
#' The options for engine-specific encodings dictate how the predictors should be
348+
#' handled. These options ensure that the data
349+
#' that `parsnip` gives to the underlying model allows for a model fit that is
350+
#' as similar as possible to what it would have produced directly.
351+
#'
352+
#' For example, if `fit()` is used to fit a model that does not have
353+
#' a formula interface, typically some predictor preprocessing must
354+
#' be conducted. `glmnet` is a good example of this.
355+
#'
356+
#' There are three options that can be used for the encodings:
357+
#'
358+
#' `predictor_indicators` describes whether and how to create indicator/dummy
359+
#' variables from factor predictors. There are three options: `"none"` (do not
360+
#' expand factor predictors), `"traditional"` (apply the standard
361+
#' `model.matrix()` encodings), and `"one_hot"` (create the complete set
362+
#' including the baseline level for all factors). This encoding only affects
363+
#' cases when [fit.model_spec()] is used and the underlying model has an x/y
364+
#' interface.
365+
#'
366+
#' Another option is `compute_intercept`; this controls whether `model.matrix()`
367+
#' should include the intercept in its formula. This affects more than the
368+
#' inclusion of an intercept column. With an intercept, `model.matrix()`
369+
#' computes dummy variables for all but one factor levels. Without an
370+
#' intercept, `model.matrix()` computes a full set of indicators for the
371+
#' _first_ factor variable, but an incomplete set for the remainder.
372+
#'
373+
#' Finally, the option `remove_intercept` will remove the intercept column
374+
#' _after_ `model.matrix()` is finished. This can be useful if the model
375+
#' function (e.g. `lm()`) automatically generates an intercept.
376+
#'
350377
#' @references "Making a parsnip model from scratch"
351378
#' \url{https://tidymodels.github.io/parsnip/articles/articles/Scratch.html}
352379
#' @examples
@@ -791,7 +818,9 @@ check_encodings <- function(x) {
791818
if (!is.list(x)) {
792819
rlang::abort("`values` should be a list.")
793820
}
794-
req_args <- list(predictor_indicators = TRUE)
821+
req_args <- list(predictor_indicators = rlang::na_chr,
822+
compute_intercept = rlang::na_lgl,
823+
remove_intercept = rlang::na_lgl)
795824

796825
missing_args <- setdiff(names(req_args), names(x))
797826
if (length(missing_args) > 0) {
@@ -834,9 +863,12 @@ set_encoding <- function(model, mode, eng, options) {
834863
current <- get_from_env(nm)
835864
dup_check <-
836865
current %>%
837-
dplyr::inner_join(new_values, by = c("model", "engine", "mode", "predictor_indicators"))
866+
dplyr::inner_join(
867+
new_values,
868+
by = c("model", "engine", "mode", "predictor_indicators")
869+
)
838870
if (nrow(dup_check)) {
839-
rlang::abort(glue::glue("Engine '{eng}' and mode '{mode}' already have defined encodings."))
871+
rlang::abort(glue::glue("Engine '{eng}' and mode '{mode}' already have defined encodings for model '{model}'."))
840872
}
841873

842874
} else {
@@ -856,6 +888,19 @@ set_encoding <- function(model, mode, eng, options) {
856888
get_encoding <- function(model) {
857889
check_model_exists(model)
858890
nm <- paste0(model, "_encoding")
859-
rlang::env_get(get_model_env(), nm)
891+
res <- try(get_from_env(nm), silent = TRUE)
892+
if (inherits(res, "try-error")) {
893+
# for objects made before encodings were specified in parsnip
894+
res <-
895+
get_from_env(model) %>%
896+
dplyr::mutate(
897+
model = model,
898+
predictor_indicators = "traditional",
899+
compute_intercept = TRUE,
900+
remove_intercept = TRUE
901+
) %>%
902+
dplyr::select(model, engine, mode, predictor_indicators,
903+
compute_intercept, remove_intercept)
904+
}
905+
res
860906
}
861-

R/boost_tree.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ xgb_train <- function(
301301

302302

303303
if (is.numeric(y)) {
304-
loss <- "reg:linear"
304+
loss <- "reg:squarederror"
305305
} else {
306306
lvl <- levels(y)
307307
y <- as.numeric(y) - 1
@@ -399,7 +399,7 @@ xgb_pred <- function(object, newdata, ...) {
399399

400400
x = switch(
401401
object$params$objective,
402-
"reg:linear" = , "reg:logistic" = , "binary:logistic" = res,
402+
"reg:squarederror" = , "reg:logistic" = , "binary:logistic" = res,
403403
"binary:logitraw" = stats::binomial()$linkinv(res),
404404
"multi:softprob" = matrix(res, ncol = object$params$num_class, byrow = TRUE),
405405
res

R/boost_tree_data.R

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,11 @@ set_encoding(
9191
model = "boost_tree",
9292
eng = "xgboost",
9393
mode = "regression",
94-
options = list(predictor_indicators = TRUE)
94+
options = list(
95+
predictor_indicators = "one_hot",
96+
compute_intercept = FALSE,
97+
remove_intercept = TRUE
98+
)
9599
)
96100

97101
set_pred(
@@ -136,7 +140,11 @@ set_encoding(
136140
model = "boost_tree",
137141
eng = "xgboost",
138142
mode = "classification",
139-
options = list(predictor_indicators = TRUE)
143+
options = list(
144+
predictor_indicators = "one_hot",
145+
compute_intercept = FALSE,
146+
remove_intercept = TRUE
147+
)
140148
)
141149

142150
set_pred(
@@ -239,7 +247,11 @@ set_encoding(
239247
model = "boost_tree",
240248
eng = "C5.0",
241249
mode = "classification",
242-
options = list(predictor_indicators = FALSE)
250+
options = list(
251+
predictor_indicators = "none",
252+
compute_intercept = FALSE,
253+
remove_intercept = FALSE
254+
)
243255
)
244256

245257
set_pred(
@@ -369,7 +381,11 @@ set_encoding(
369381
model = "boost_tree",
370382
eng = "spark",
371383
mode = "regression",
372-
options = list(predictor_indicators = TRUE)
384+
options = list(
385+
predictor_indicators = "none",
386+
compute_intercept = FALSE,
387+
remove_intercept = FALSE
388+
)
373389
)
374390

375391
set_fit(
@@ -389,7 +405,11 @@ set_encoding(
389405
model = "boost_tree",
390406
eng = "spark",
391407
mode = "classification",
392-
options = list(predictor_indicators = TRUE)
408+
options = list(
409+
predictor_indicators = "none",
410+
compute_intercept = FALSE,
411+
remove_intercept = FALSE
412+
)
393413
)
394414

395415
set_pred(

R/contr_one_hot.R

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
#' Contrast function for one-hot encodings
2+
#'
3+
#' This contrast function produces a model matrix with indicator columns for
4+
#' each level of each factor.
5+
#'
6+
#' @param n A vector of character factor levels or the number of unique levels.
7+
#' @param contrasts This argument is for backwards compatibility and only the
8+
#' default of `TRUE` is supported.
9+
#' @param sparse This argument is for backwards compatibility and only the
10+
#' default of `FALSE` is supported.
11+
#'
12+
#' @includeRmd man/rmd/one-hot.Rmd details
13+
#'
14+
#' @return A diagonal matrix that is `n`-by-`n`.
15+
#'
16+
#' @export
17+
contr_one_hot <- function(n, contrasts = TRUE, sparse = FALSE) {
18+
if (sparse) {
19+
rlang::warn("`sparse = TRUE` not implemented for `contr_one_hot()`.")
20+
}
21+
22+
if (!contrasts) {
23+
rlang::warn("`contrasts = FALSE` not implemented for `contr_one_hot()`.")
24+
}
25+
26+
if (is.character(n)) {
27+
names <- n
28+
n <- length(names)
29+
} else if (is.numeric(n)) {
30+
n <- as.integer(n)
31+
32+
if (length(n) != 1L) {
33+
rlang::abort("`n` must have length 1 when an integer is provided.")
34+
}
35+
36+
names <- as.character(seq_len(n))
37+
} else {
38+
rlang::abort("`n` must be a character vector or an integer of size 1.")
39+
}
40+
41+
out <- diag(n)
42+
43+
rownames(out) <- names
44+
colnames(out) <- names
45+
46+
out
47+
}

0 commit comments

Comments
 (0)