Skip to content

Commit bf507af

Browse files
authored
Merge pull request #316 from tidymodels/alt-argument-names
Handling alternative data argument names and fixing calls
2 parents 7d32c02 + 9e580d5 commit bf507af

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+338
-139
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ export(predict_quantile.model_fit)
142142
export(predict_raw)
143143
export(predict_raw.model_fit)
144144
export(rand_forest)
145+
export(repair_call)
145146
export(rpart_train)
146147
export(set_args)
147148
export(set_dependency)

NEWS.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,20 @@
44

55
* `tidyr` >= 1.0.0 is now required.
66

7+
* SVM models produced by `kernlab` now use the formula method. This change was due to how `ksvm()` made indicator variables for factor predictors (with one-hot encodings). Since the ordinary formula method did not do this, the data are passed as-is to `ksvm()` so that the results are closer to what one would get if `ksvm()` were called directly.
8+
9+
* MARS models produced by `earth` now use the formula method.
10+
11+
* Under-the-hood changes were made so that non-standard data arguments in the modeling packages can be accommodated. (#315)
12+
713
## New Features
814

915
* A new main argument was added to `boost_tree()` called `stop_iter` for early stopping. The `xgb_train()` function gained arguments for early stopping and a percentage of data to leave out for a validation set.
1016

17+
* If `fit()` is used and the underlying model uses a formula, the _actual_ formula is passed to the model (instead of a placeholder). This makes the model call better.
18+
19+
* A function named `repair_call()` was added. This can help change the underlying model's `call` object to better reflect what the user would have obtained if the model function had been used directly (instead of via `parsnip`). This is only useful when the user chooses a formula interface and the model uses a formula interface. It will also be of limited use when a recipe is used to construct the feature set in `workflows` or `tune`.
20+
1121
# parsnip 0.1.1
1222

1323
## New Features

R/aaa.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ utils::globalVariables(
6666
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
6767
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
6868
"max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees",
69-
"sub_neighbors", ".pred_class")
69+
"sub_neighbors", ".pred_class", "x", "y")
7070
)
7171

7272
# nocov end

R/aaa_models.R

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,35 @@ check_fit_info <- function(fit_obj) {
195195
if (is.null(fit_obj)) {
196196
rlang::abort("The `fit` module cannot be NULL.")
197197
}
198+
199+
# check required data elements
198200
exp_nms <- c("defaults", "func", "interface", "protect")
199-
if (!isTRUE(all.equal(sort(names(fit_obj)), exp_nms))) {
201+
has_req_nms <- exp_nms %in% names(fit_obj)
202+
203+
if (!all(has_req_nms)) {
200204
rlang::abort(
201205
glue::glue("The `fit` module should have elements: ",
202206
glue::glue_collapse(glue::glue("`{exp_nms}`"), sep = ", "))
203207
)
204208
}
205209

210+
# check optional data elements
211+
opt_nms <- c("data")
212+
other_nms <- setdiff(exp_nms, names(fit_obj))
213+
has_opt_nms <- other_nms %in% opt_nms
214+
if (any(!has_opt_nms)) {
215+
msg <- glue::glue("The `fit` module can only have optional elements: ",
216+
glue::glue_collapse(glue::glue("`{exp_nms}`"), sep = ", "))
217+
218+
rlang::abort(msg)
219+
}
220+
if (any(other_nms == "data")) {
221+
data_nms <- names(fit_obj$data)
222+
if (length(data_nms == 0) || any(data_nms == "")) {
223+
rlang::abort("All elements of the `data` argument vector must be named.")
224+
}
225+
}
226+
206227
check_interface_val(fit_obj$interface)
207228
check_func_val(fit_obj$func)
208229

R/arguments.R

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ check_eng_args <- function(args, obj, core_args) {
2424
if (length(common_args) > 0) {
2525
args <- args[!(names(args) %in% common_args)]
2626
common_args <- paste0(common_args, collapse = ", ")
27-
rlang::warn(glue::glue("The following arguments cannot be manually modified",
27+
rlang::warn(glue::glue("The following arguments cannot be manually modified ",
2828
"and were removed: {common_args}."))
2929
}
3030
args
@@ -113,3 +113,94 @@ eval_args <- function(spec, ...) {
113113
spec$eng_args <- purrr::map(spec$eng_args, maybe_eval)
114114
spec
115115
}
116+
117+
# ------------------------------------------------------------------------------
118+
119+
# In some cases, a model function that we are calling has non-standard argument
120+
# names. For example, a function foo() that only has the x/y interface might
121+
# have a signature like `foo(X, Y)`.
122+
123+
# To deal with this, we allow for the `data` element of the model
124+
# as an option to specify these actual argument names
125+
#
126+
# value = list(
127+
# interface = "xy",
128+
# data = c(x = "X", y = "Y"),
129+
# protect = c("X", "Y"),
130+
# func = c(pkg = "bar", fun = "foo"),
131+
# defaults = list()
132+
# )
133+
134+
# Build an unevaluated call to `fun` from a list of argument expressions.
#
# @param fun  Name of the function to call.
# @param ns   Namespace (package) name. When `NULL` or `NA`, the call is left
#             unqualified.
# @param args List of argument expressions. Elements that are `NULL`, or that
#             `is_missing_arg()` flags as placeholders, are dropped.
# @param ...  Not used.
# @return The call object produced by `call2()`.
make_call <- function(fun, ns, args, ...) {
  # remove any null or placeholders (`missing_args`) that remain
  discard <- vapply(
    args,
    function(x) is.null(x) || is_missing_arg(x),
    logical(1)
  )
  args <- args[!discard]

  # `&&` short-circuits, so `is.na()` is never reached for a NULL `ns`. The
  # elementwise `&` yielded a zero-length condition there and `if` failed.
  if (!is.null(ns) && !is.na(ns)) {
    out <- call2(fun, !!!args, .ns = ns)
  } else {
    out <- call2(fun, !!!args)
  }
  out
}
147+
148+
149+
# Assemble the call used to fit a model through a formula interface.
#
# The engine's fit arguments are extended with the data-related arguments:
# `formula` and `data` by default, or the names given in the fit module's
# `data` element. The formula argument is then set to `env$formula`.
make_form_call <- function(object, env = NULL) {
  fit_spec <- object$method$fit
  call_args <- fit_spec$args

  # Map of standard names (`formula`, `data`) to the engine's argument names
  arg_map <- fit_spec$data
  if (is.null(arg_map)) {
    arg_map <- c(formula = "formula", data = "data")
  }

  # Each engine argument initially refers to the symbol of its standard name
  std_names <- names(arg_map)
  for (idx in seq_along(arg_map)) {
    engine_arg <- unname(arg_map[idx])
    call_args[[engine_arg]] <- sym(std_names[idx])
  }

  # The formula argument receives the actual formula instead of a symbol
  formula_arg <- unname(arg_map["formula"])
  call_args[[formula_arg]] <- env$formula

  # For the spark engine the data are also stored in `env` under `x`
  if (object$engine == "spark") {
    env$x <- env$data
  }

  make_call(
    fun = fit_spec$func["fun"],
    ns = fit_spec$func["pkg"],
    call_args
  )
}
178+
179+
# Assemble the call used to fit a model through an x/y interface.
#
# The outcome argument is set to the symbol `y`. The predictor argument is
# set to `x`, optionally wrapped in a conversion chosen by `target`
# ("none", "data.frame" or "matrix"); any other `target` aborts. Argument
# names default to `x`/`y` unless the fit module's `data` element gives
# alternatives.
make_xy_call <- function(object, target) {
  fit_spec <- object$method$fit
  call_args <- fit_spec$args

  # Map of standard names (`x`, `y`) to the engine's argument names
  arg_map <- fit_spec$data
  if (is.null(arg_map)) {
    arg_map <- c(x = "x", y = "y")
  }

  call_args[[unname(arg_map["y"])]] <- rlang::expr(y)

  # Expression for the predictors, converted to the type the engine expects
  x_expr <- switch(
    target,
    none = rlang::expr(x),
    data.frame = rlang::expr(as.data.frame(x)),
    matrix = rlang::expr(as.matrix(x)),
    rlang::abort(glue::glue("Invalid data type target: {target}."))
  )
  call_args[[unname(arg_map["x"])]] <- x_expr

  make_call(
    fun = fit_spec$func["fun"],
    ns = fit_spec$func["pkg"],
    call_args
  )
}

R/boost_tree_data.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ set_fit(
337337
mode = "regression",
338338
value = list(
339339
interface = "formula",
340+
data = c(formula = "formula", data = "x"),
340341
protect = c("x", "formula", "type"),
341342
func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"),
342343
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))
@@ -349,6 +350,7 @@ set_fit(
349350
mode = "classification",
350351
value = list(
351352
interface = "formula",
353+
data = c(formula = "formula", data = "x"),
352354
protect = c("x", "formula", "type"),
353355
func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"),
354356
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))

R/decision_tree_data.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ set_fit(
237237
mode = "regression",
238238
value = list(
239239
interface = "formula",
240+
data = c(formula = "formula", data = "x"),
240241
protect = c("x", "formula"),
241242
func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"),
242243
defaults =
@@ -250,6 +251,7 @@ set_fit(
250251
mode = "classification",
251252
value = list(
252253
interface = "formula",
254+
data = c(formula = "formula", data = "x"),
253255
protect = c("x", "formula"),
254256
func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"),
255257
defaults =

R/fit_helpers.R

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,7 @@ form_form <-
3030
# sub in arguments to actual syntax for corresponding engine
3131
object <- translate(object, engine = object$engine)
3232

33-
fit_args <- object$method$fit$args
34-
35-
if (is_spark(object)) {
36-
fit_args$x <- quote(x)
37-
env$x <- env$data
38-
} else {
39-
fit_args$data <- quote(data)
40-
}
41-
fit_args$formula <- quote(formula)
42-
43-
fit_call <- make_call(
44-
fun = object$method$fit$func["fun"],
45-
ns = object$method$fit$func["pkg"],
46-
fit_args
47-
)
33+
fit_call <- make_form_call(object, env = env)
4834

4935
res <- list(
5036
lvl = y_levels,
@@ -89,21 +75,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
8975
# sub in arguments to actual syntax for corresponding engine
9076
object <- translate(object, engine = object$engine)
9177

92-
object$method$fit$args[["y"]] <- quote(y)
93-
object$method$fit$args[["x"]] <-
94-
switch(
95-
target,
96-
none = quote(x),
97-
data.frame = quote(as.data.frame(x)),
98-
matrix = quote(as.matrix(x)),
99-
rlang::abort(glue::glue("Invalid data type target: {target}."))
100-
)
101-
102-
fit_call <- make_call(
103-
fun = object$method$fit$func["fun"],
104-
ns = object$method$fit$func["pkg"],
105-
object$method$fit$args
106-
)
78+
fit_call <- make_xy_call(object, target)
10779

10880
res <- list(lvl = levels(env$y), spec = object)
10981

R/linear_reg_data.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,7 @@ set_fit(
312312
mode = "regression",
313313
value = list(
314314
interface = "formula",
315+
data = c(formula = "formula", data = "x"),
315316
protect = c("x", "formula", "weight_col"),
316317
func = c(pkg = "sparklyr", fun = "ml_linear_regression"),
317318
defaults = list()

R/logistic_reg_data.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ set_fit(
236236
mode = "classification",
237237
value = list(
238238
interface = "formula",
239+
data = c(formula = "formula", data = "x"),
239240
protect = c("x", "formula", "weight_col"),
240241
func = c(pkg = "sparklyr", fun = "ml_logistic_regression"),
241242
defaults =

R/mars_data.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ set_fit(
4040
eng = "earth",
4141
mode = "regression",
4242
value = list(
43-
interface = "data.frame",
44-
protect = c("x", "y", "weights"),
43+
interface = "formula",
44+
protect = c("formula", "data", "weights"),
4545
func = c(pkg = "earth", fun = "earth"),
4646
defaults = list(keepxy = TRUE)
4747
)
@@ -52,8 +52,8 @@ set_fit(
5252
eng = "earth",
5353
mode = "classification",
5454
value = list(
55-
interface = "data.frame",
56-
protect = c("x", "y", "weights"),
55+
interface = "formula",
56+
protect = c("formula", "data", "weights"),
5757
func = c(pkg = "earth", fun = "earth"),
5858
defaults = list(keepxy = TRUE)
5959
)

R/misc.R

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -115,22 +115,6 @@ convert_arg <- function(x) {
115115
x
116116
}
117117

118-
make_call <- function(fun, ns, args, ...) {
119-
120-
#args <- map(args, convert_arg)
121-
122-
# remove any null or placeholders (`missing_args`) that remain
123-
discard <-
124-
vapply(args, function(x)
125-
is_missing_arg(x) | is.null(x), logical(1))
126-
args <- args[!discard]
127-
128-
if (!is.null(ns) & !is.na(ns)) {
129-
out <- call2(fun, !!!args, .ns = ns)
130-
} else
131-
out <- call2(fun, !!!args)
132-
out
133-
}
134118

135119
levels_from_formula <- function(f, dat) {
136120
if (inherits(dat, "tbl_spark"))

R/multinom_reg_data.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ set_fit(
122122
mode = "classification",
123123
value = list(
124124
interface = "formula",
125+
data = c(formula = "formula", data = "x"),
125126
protect = c("x", "formula", "weight_col"),
126127
func = c(pkg = "sparklyr", fun = "ml_logistic_regression"),
127128
defaults = list(family = "multinomial")

R/rand_forest_data.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,7 @@ set_fit(
468468
mode = "classification",
469469
value = list(
470470
interface = "formula",
471+
data = c(formula = "formula", data = "x"),
471472
protect = c("x", "formula", "type"),
472473
func = c(pkg = "sparklyr", fun = "ml_random_forest"),
473474
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))
@@ -480,6 +481,7 @@ set_fit(
480481
mode = "regression",
481482
value = list(
482483
interface = "formula",
484+
data = c(formula = "formula", data = "x"),
483485
protect = c("x", "formula", "type"),
484486
func = c(pkg = "sparklyr", fun = "ml_random_forest"),
485487
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))

R/repair_call.R

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
#' Repair a model call object
#'
#' When the user passes a formula to `fit()` _and_ the underlying model function
#' uses a formula, the call object produced by `fit()` may not be usable by
#' other functions. For example, some arguments may still be quosures and the
#' `data` portion of the call will not correspond to the original data.
#'
#' `repair_call()` can adjust the model object's call to be usable by other
#' functions and methods.
#' @param x A fitted `parsnip` model. An error will occur if the underlying model
#' does not have a `call` element.
#' @param data A data object that is relevant to the call. In most cases, this
#' is the data frame that was given to `parsnip` for the model fit (i.e., the
#' training set data). The name of this data object is inserted into the call.
#' @return A modified `parsnip` fitted model.
#' @examples
#'
#' fitted_model <-
#'   linear_reg() %>%
#'   set_engine("lm", model = TRUE) %>%
#'   fit(mpg ~ ., data = mtcars)
#'
#' # In this call, note that `data` is not `mtcars` and the `model = ~TRUE`
#' # indicates that the `model` argument is an `rlang` quosure.
#' fitted_model$fit$call
#'
#' # All better:
#' repair_call(fitted_model, mtcars)$fit$call
#' @export
repair_call <- function(x, data) {
  # Capture the call so the expression supplied for `data` (e.g. the symbol
  # `mtcars`) can be inserted below instead of the data values.
  cl <- match.call()
  if (!("call" %in% names(x$fit))) {
    rlang::abort("No `call` object to modify.")
  }
  if (rlang::is_missing(data)) {
    rlang::abort("Please supply a data object to `data`.")
  }
  fit_call <- x$fit$call

  # Replace quosure arguments by their evaluated values. Elements are
  # addressed by position so that unnamed (positional) arguments are handled
  # too. They are processed last-to-first so that an element removed by a
  # NULL result cannot shift the positions still to be processed.
  needs_eval <- purrr::map_lgl(as.list(fit_call), rlang::is_quosure)
  for (i in rev(which(needs_eval))) {
    fit_call[[i]] <- rlang::eval_tidy(fit_call[[i]])
  }

  if ("data" %in% names(fit_call)) {
    fit_call$data <- cl$data
  }

  x$fit$call <- fit_call
  x
}

0 commit comments

Comments
 (0)