Skip to content

Commit aa29bac

Browse files
authored
Merge pull request #319 from tidymodels/encoding-options
Add engine specification field for predictor encodings
2 parents ab221e1 + 534987e commit aa29bac

20 files changed

+503
-54
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ export(fit_control)
110110
export(fit_xy)
111111
export(fit_xy.model_spec)
112112
export(get_dependency)
113+
export(get_encoding)
113114
export(get_fit)
114115
export(get_from_env)
115116
export(get_model_env)
@@ -146,6 +147,7 @@ export(repair_call)
146147
export(rpart_train)
147148
export(set_args)
148149
export(set_dependency)
150+
export(set_encoding)
149151
export(set_engine)
150152
export(set_env_val)
151153
export(set_fit)

R/aaa_models.R

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,11 @@ check_interface_val <- function(x) {
323323
#' below, depending on context.
324324
#' @param pre,post Optional functions for pre- and post-processing of prediction
325325
#' results.
326+
#' @param options A list of options for engine-specific encodings. Currently,
327+
#' the option implemented is `predictor_indicators` which tells `parsnip`
328+
#' whether the pre-processing should make indicator/dummy variables from factor
329+
#' predictors. This only affects cases when [fit.model_spec()] is used and the
330+
#' underlying model has an x/y interface.
326331
#' @param ... Optional arguments that should be passed into the `args` slot for
327332
#' prediction objects.
328333
#' @keywords internal
@@ -780,3 +785,77 @@ pred_value_template <- function(pre = NULL, post = NULL, func, ...) {
780785
list(pre = pre, post = post, func = func, args = list(...))
781786
}
782787

788+
# ------------------------------------------------------------------------------
789+
790+
check_encodings <- function(x) {
791+
if (!is.list(x)) {
792+
rlang::abort("`values` should be a list.")
793+
}
794+
req_args <- list(predictor_indicators = TRUE)
795+
796+
missing_args <- setdiff(names(req_args), names(x))
797+
if (length(missing_args) > 0) {
798+
rlang::abort(
799+
glue::glue(
800+
"The values passed to `set_encoding()` are missing arguments: ",
801+
paste0("'", missing_args, "'", collapse = ", ")
802+
)
803+
)
804+
}
805+
extra_args <- setdiff(names(x), names(req_args))
806+
if (length(extra_args) > 0) {
807+
rlang::abort(
808+
glue::glue(
809+
"The values passed to `set_encoding()` had extra arguments: ",
810+
paste0("'", extra_args, "'", collapse = ", ")
811+
)
812+
)
813+
}
814+
invisible(x)
815+
}
816+
817+
#' @export
818+
#' @rdname set_new_model
819+
#' @keywords internal
820+
set_encoding <- function(model, mode, eng, options) {
821+
check_model_exists(model)
822+
check_eng_val(eng)
823+
check_mode_val(mode)
824+
check_encodings(options)
825+
826+
keys <- tibble::tibble(model = model, engine = eng, mode = mode)
827+
options <- tibble::as_tibble(options)
828+
new_values <- dplyr::bind_cols(keys, options)
829+
830+
831+
current_db_list <- ls(envir = get_model_env())
832+
nm <- paste(model, "encoding", sep = "_")
833+
if (any(current_db_list == nm)) {
834+
current <- get_from_env(nm)
835+
dup_check <-
836+
current %>%
837+
dplyr::inner_join(new_values, by = c("model", "engine", "mode", "predictor_indicators"))
838+
if (nrow(dup_check)) {
839+
rlang::abort(glue::glue("Engine '{eng}' and mode '{mode}' already have defined encodings."))
840+
}
841+
842+
} else {
843+
current <- NULL
844+
}
845+
846+
db_values <- dplyr::bind_rows(current, new_values)
847+
set_env_val(nm, db_values)
848+
849+
invisible(NULL)
850+
}
851+
852+
853+
#' @rdname set_new_model
854+
#' @keywords internal
855+
#' @export
856+
get_encoding <- function(model) {
857+
check_model_exists(model)
858+
nm <- paste0(model, "_encoding")
859+
rlang::env_get(get_model_env(), nm)
860+
}
861+

R/boost_tree_data.R

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,13 @@ set_fit(
8787
)
8888
)
8989

90+
set_encoding(
91+
model = "boost_tree",
92+
eng = "xgboost",
93+
mode = "regression",
94+
options = list(predictor_indicators = TRUE)
95+
)
96+
9097
set_pred(
9198
model = "boost_tree",
9299
eng = "xgboost",
@@ -125,6 +132,13 @@ set_fit(
125132
)
126133
)
127134

135+
set_encoding(
136+
model = "boost_tree",
137+
eng = "xgboost",
138+
mode = "classification",
139+
options = list(predictor_indicators = TRUE)
140+
)
141+
128142
set_pred(
129143
model = "boost_tree",
130144
eng = "xgboost",
@@ -221,6 +235,13 @@ set_fit(
221235
)
222236
)
223237

238+
set_encoding(
239+
model = "boost_tree",
240+
eng = "C5.0",
241+
mode = "classification",
242+
options = list(predictor_indicators = FALSE)
243+
)
244+
224245
set_pred(
225246
model = "boost_tree",
226247
eng = "C5.0",
@@ -344,6 +365,13 @@ set_fit(
344365
)
345366
)
346367

368+
set_encoding(
369+
model = "boost_tree",
370+
eng = "spark",
371+
mode = "regression",
372+
options = list(predictor_indicators = TRUE)
373+
)
374+
347375
set_fit(
348376
model = "boost_tree",
349377
eng = "spark",
@@ -357,6 +385,13 @@ set_fit(
357385
)
358386
)
359387

388+
set_encoding(
389+
model = "boost_tree",
390+
eng = "spark",
391+
mode = "classification",
392+
options = list(predictor_indicators = TRUE)
393+
)
394+
360395
set_pred(
361396
model = "boost_tree",
362397
eng = "spark",

R/convert_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
#' @importFrom stats .checkMFClasses .getXlevels delete.response
1616
#' @importFrom stats model.offset model.weights na.omit na.pass
1717

18-
convert_form_to_xy_fit <-function(
18+
convert_form_to_xy_fit <- function(
1919
formula,
2020
data,
2121
...,

R/decision_tree_data.R

Lines changed: 37 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ set_fit(
4848
)
4949
)
5050

51+
set_encoding(
52+
model = "decision_tree",
53+
eng = "rpart",
54+
mode = "regression",
55+
options = list(predictor_indicators = FALSE)
56+
)
57+
5158
set_fit(
5259
model = "decision_tree",
5360
eng = "rpart",
@@ -60,6 +67,13 @@ set_fit(
6067
)
6168
)
6269

70+
set_encoding(
71+
model = "decision_tree",
72+
eng = "rpart",
73+
mode = "classification",
74+
options = list(predictor_indicators = FALSE)
75+
)
76+
6377
set_pred(
6478
model = "decision_tree",
6579
eng = "rpart",
@@ -158,6 +172,13 @@ set_fit(
158172
)
159173
)
160174

175+
set_encoding(
176+
model = "decision_tree",
177+
eng = "C5.0",
178+
mode = "classification",
179+
options = list(predictor_indicators = FALSE)
180+
)
181+
161182
set_pred(
162183
model = "decision_tree",
163184
eng = "C5.0",
@@ -211,7 +232,7 @@ set_pred(
211232

212233
set_model_engine("decision_tree", "classification", "spark")
213234
set_model_engine("decision_tree", "regression", "spark")
214-
set_dependency("decision_tree", "spark", "spark")
235+
set_dependency("decision_tree", "spark", "sparklyr")
215236

216237
set_model_arg(
217238
model = "decision_tree",
@@ -239,12 +260,19 @@ set_fit(
239260
interface = "formula",
240261
data = c(formula = "formula", data = "x"),
241262
protect = c("x", "formula"),
242-
func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"),
263+
func = c(pkg = "sparklyr", fun = "ml_decision_tree_regressor"),
243264
defaults =
244265
list(seed = expr(sample.int(10 ^ 5, 1)))
245266
)
246267
)
247268

269+
set_encoding(
270+
model = "decision_tree",
271+
eng = "spark",
272+
mode = "regression",
273+
options = list(predictor_indicators = TRUE)
274+
)
275+
248276
set_fit(
249277
model = "decision_tree",
250278
eng = "spark",
@@ -259,6 +287,13 @@ set_fit(
259287
)
260288
)
261289

290+
set_encoding(
291+
model = "decision_tree",
292+
eng = "spark",
293+
mode = "classification",
294+
options = list(predictor_indicators = TRUE)
295+
)
296+
262297
set_pred(
263298
model = "decision_tree",
264299
eng = "spark",

R/fit.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ fit.model_spec <-
103103
eng_vals <- possible_engines(object)
104104
object$engine <- eng_vals[1]
105105
if (control$verbosity > 0) {
106-
rlang::warn("Engine set to `{object$engine}`.")
106+
rlang::warn(glue::glue("Engine set to `{object$engine}`."))
107107
}
108108
}
109109

R/fit_helpers.R

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,12 +103,17 @@ xy_xy <- function(object, env, control, target = "none", ...) {
103103
form_xy <- function(object, control, env,
104104
target = "none", ...) {
105105

106+
indicators <- get_encoding(class(object)[1]) %>%
107+
dplyr::filter(mode == object$mode,
108+
engine == object$engine) %>%
109+
dplyr::pull(predictor_indicators)
110+
106111
data_obj <- convert_form_to_xy_fit(
107112
formula = env$formula,
108113
data = env$data,
109114
...,
110-
composition = target
111-
# indicators
115+
composition = target,
116+
indicators = indicators
112117
)
113118
env$x <- data_obj$x
114119
env$y <- data_obj$y

0 commit comments

Comments
 (0)