Skip to content

Commit 3a3c134

Browse files
committed
Merge IT ALL
Merge branch 'master' into encoding-options # Conflicts: # R/linear_reg_data.R # R/svm_poly_data.R # R/svm_rbf_data.R # tests/testthat/test_svm_poly.R # tests/testthat/test_svm_rbf.R
2 parents a420749 + bf507af commit 3a3c134

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+366
-870
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ export(predict_quantile.model_fit)
143143
export(predict_raw)
144144
export(predict_raw.model_fit)
145145
export(rand_forest)
146+
export(repair_call)
146147
export(rpart_train)
147148
export(set_args)
148149
export(set_dependency)

NEWS.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,20 @@
44

55
* `tidyr` >= 1.0.0 is now required.
66

7+
* SVM models produced by `kernlab` now use the formula method. This change was due to how `ksvm()` made indicator variables for factor predictors (with one-hot encodings). Since the ordinary formula method did not do this, the data are passed as-is to `ksvm()` so that the results are closer to what one would get if `ksvm()` were called directly.
8+
9+
* MARS models produced by `earth` now use the formula method.
10+
11+
* Under-the-hood changes were made so that non-standard data arguments in the modeling packages can be accommodated. (#315)
12+
713
## New Features
814

915
* A new main argument was added to `boost_tree()` called `stop_iter` for early stopping. The `xgb_train()` function gained arguments for early stopping and a percentage of data to leave out for a validation set.
1016

17+
* If `fit()` is used and the underlying model uses a formula, the _actual_ formula is passed to the model (instead of a placeholder). This makes the model call better.
18+
19+
* A function named `repair_call()` was added. This can help change the underlying model's `call` object to better reflect what would have been obtained if the model function had been used directly (instead of via `parsnip`). This is only useful when the user chooses a formula interface and the model uses a formula interface. It will also be of limited use when a recipe is used to construct the feature set in `workflows` or `tune`.
20+
1121
# parsnip 0.1.1
1222

1323
## New Features

R/aaa.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ utils::globalVariables(
6666
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
6767
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
6868
"max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees",
69-
"sub_neighbors", ".pred_class")
69+
"sub_neighbors", ".pred_class", "x", "y")
7070
)
7171

7272
# nocov end

R/aaa_models.R

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,14 +195,35 @@ check_fit_info <- function(fit_obj) {
195195
if (is.null(fit_obj)) {
196196
rlang::abort("The `fit` module cannot be NULL.")
197197
}
198+
199+
# check required data elements
198200
exp_nms <- c("defaults", "func", "interface", "protect")
199-
if (!isTRUE(all.equal(sort(names(fit_obj)), exp_nms))) {
201+
has_req_nms <- exp_nms %in% names(fit_obj)
202+
203+
if (!all(has_req_nms)) {
200204
rlang::abort(
201205
glue::glue("The `fit` module should have elements: ",
202206
glue::glue_collapse(glue::glue("`{exp_nms}`"), sep = ", "))
203207
)
204208
}
205209

210+
# check optional data elements
211+
opt_nms <- c("data")
212+
other_nms <- setdiff(exp_nms, names(fit_obj))
213+
has_opt_nms <- other_nms %in% opt_nms
214+
if (any(!has_opt_nms)) {
215+
msg <- glue::glue("The `fit` module can only have optional elements: ",
216+
glue::glue_collapse(glue::glue("`{exp_nms}`"), sep = ", "))
217+
218+
rlang::abort(msg)
219+
}
220+
if (any(other_nms == "data")) {
221+
data_nms <- names(fit_obj$data)
222+
if (length(data_nms == 0) || any(data_nms == "")) {
223+
rlang::abort("All elements of the `data` argument vector must be named.")
224+
}
225+
}
226+
206227
check_interface_val(fit_obj$interface)
207228
check_func_val(fit_obj$func)
208229

R/arguments.R

Lines changed: 92 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ check_eng_args <- function(args, obj, core_args) {
2424
if (length(common_args) > 0) {
2525
args <- args[!(names(args) %in% common_args)]
2626
common_args <- paste0(common_args, collapse = ", ")
27-
rlang::warn(glue::glue("The following arguments cannot be manually modified",
27+
rlang::warn(glue::glue("The following arguments cannot be manually modified ",
2828
"and were removed: {common_args}."))
2929
}
3030
args
@@ -113,3 +113,94 @@ eval_args <- function(spec, ...) {
113113
spec$eng_args <- purrr::map(spec$eng_args, maybe_eval)
114114
spec
115115
}
116+
117+
# ------------------------------------------------------------------------------
118+
119+
# In some cases, a model function that we are calling has non-standard argument
120+
# names. For example, a function foo() that only has the x/y interface might
121+
# have a signature like `foo(X, Y)`.
122+
123+
# To deal with this, we allow for the `data` element of the model
124+
# as an option to specify these actual argument names
125+
#
126+
# value = list(
127+
# interface = "xy",
128+
# data = c(x = "X", y = "Y"),
129+
# protect = c("X", "Y"),
130+
# func = c(pkg = "bar", fun = "foo"),
131+
# defaults = list()
132+
# )
133+
134+
make_call <- function(fun, ns, args, ...) {
135+
# remove any null or placeholders (`missing_args`) that remain
136+
discard <-
137+
vapply(args, function(x)
138+
is_missing_arg(x) | is.null(x), logical(1))
139+
args <- args[!discard]
140+
141+
if (!is.null(ns) & !is.na(ns)) {
142+
out <- call2(fun, !!!args, .ns = ns)
143+
} else
144+
out <- call2(fun, !!!args)
145+
out
146+
}
147+
148+
149+
make_form_call <- function(object, env = NULL) {
150+
fit_args <- object$method$fit$args
151+
152+
# Get the arguments related to data:
153+
if (is.null(object$method$fit$data)) {
154+
data_args <- c(formula = "formula", data = "data")
155+
} else {
156+
data_args <- object$method$fit$data
157+
}
158+
159+
# add data arguments
160+
for (i in seq_along(data_args)) {
161+
fit_args[[ unname(data_args[i]) ]] <- sym(names(data_args)[i])
162+
}
163+
164+
# sub in actual formula
165+
fit_args[[ unname(data_args["formula"]) ]] <- env$formula
166+
167+
if (object$engine == "spark") {
168+
env$x <- env$data
169+
}
170+
171+
fit_call <- make_call(
172+
fun = object$method$fit$func["fun"],
173+
ns = object$method$fit$func["pkg"],
174+
fit_args
175+
)
176+
fit_call
177+
}
178+
179+
make_xy_call <- function(object, target) {
180+
fit_args <- object$method$fit$args
181+
182+
# Get the arguments related to data:
183+
if (is.null(object$method$fit$data)) {
184+
data_args <- c(x = "x", y = "y")
185+
} else {
186+
data_args <- object$method$fit$data
187+
}
188+
189+
object$method$fit$args[[ unname(data_args["y"]) ]] <- rlang::expr(y)
190+
object$method$fit$args[[ unname(data_args["x"]) ]] <-
191+
switch(
192+
target,
193+
none = rlang::expr(x),
194+
data.frame = rlang::expr(as.data.frame(x)),
195+
matrix = rlang::expr(as.matrix(x)),
196+
rlang::abort(glue::glue("Invalid data type target: {target}."))
197+
)
198+
199+
fit_call <- make_call(
200+
fun = object$method$fit$func["fun"],
201+
ns = object$method$fit$func["pkg"],
202+
object$method$fit$args
203+
)
204+
205+
fit_call
206+
}

R/boost_tree_data.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,7 @@ set_fit(
358358
mode = "regression",
359359
value = list(
360360
interface = "formula",
361+
data = c(formula = "formula", data = "x"),
361362
protect = c("x", "formula", "type"),
362363
func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"),
363364
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))
@@ -377,6 +378,7 @@ set_fit(
377378
mode = "classification",
378379
value = list(
379380
interface = "formula",
381+
data = c(formula = "formula", data = "x"),
380382
protect = c("x", "formula", "type"),
381383
func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"),
382384
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))

R/decision_tree.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ print.decision_tree <- function(x, ...) {
102102

103103
#' @export
104104
#' @inheritParams update.boost_tree
105-
#' @param object A random forest model specification.
105+
#' @param object A decision tree model specification.
106106
#' @examples
107107
#' model <- decision_tree(cost_complexity = 10, min_n = 3)
108108
#' model

R/decision_tree_data.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ set_fit(
258258
mode = "regression",
259259
value = list(
260260
interface = "formula",
261+
data = c(formula = "formula", data = "x"),
261262
protect = c("x", "formula"),
262263
func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"),
263264
defaults =
@@ -278,6 +279,7 @@ set_fit(
278279
mode = "classification",
279280
value = list(
280281
interface = "formula",
282+
data = c(formula = "formula", data = "x"),
281283
protect = c("x", "formula"),
282284
func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"),
283285
defaults =

R/fit_helpers.R

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,7 @@ form_form <-
3030
# sub in arguments to actual syntax for corresponding engine
3131
object <- translate(object, engine = object$engine)
3232

33-
fit_args <- object$method$fit$args
34-
35-
if (is_spark(object)) {
36-
fit_args$x <- quote(x)
37-
env$x <- env$data
38-
} else {
39-
fit_args$data <- quote(data)
40-
}
41-
fit_args$formula <- quote(formula)
42-
43-
fit_call <- make_call(
44-
fun = object$method$fit$func["fun"],
45-
ns = object$method$fit$func["pkg"],
46-
fit_args
47-
)
33+
fit_call <- make_form_call(object, env = env)
4834

4935
res <- list(
5036
lvl = y_levels,
@@ -89,21 +75,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
8975
# sub in arguments to actual syntax for corresponding engine
9076
object <- translate(object, engine = object$engine)
9177

92-
object$method$fit$args[["y"]] <- quote(y)
93-
object$method$fit$args[["x"]] <-
94-
switch(
95-
target,
96-
none = quote(x),
97-
data.frame = quote(as.data.frame(x)),
98-
matrix = quote(as.matrix(x)),
99-
rlang::abort(glue::glue("Invalid data type target: {target}."))
100-
)
101-
102-
fit_call <- make_call(
103-
fun = object$method$fit$func["fun"],
104-
ns = object$method$fit$func["pkg"],
105-
object$method$fit$args
106-
)
78+
fit_call <- make_xy_call(object, target)
10779

10880
res <- list(lvl = levels(env$y), spec = object)
10981

R/linear_reg_data.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,7 @@ set_fit(
314314
mode = "regression",
315315
value = list(
316316
interface = "formula",
317+
data = c(formula = "formula", data = "x"),
317318
protect = c("x", "formula", "weight_col"),
318319
func = c(pkg = "sparklyr", fun = "ml_linear_regression"),
319320
defaults = list()

R/logistic_reg_data.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,7 @@ set_fit(
249249
mode = "classification",
250250
value = list(
251251
interface = "formula",
252+
data = c(formula = "formula", data = "x"),
252253
protect = c("x", "formula", "weight_col"),
253254
func = c(pkg = "sparklyr", fun = "ml_logistic_regression"),
254255
defaults =

R/mars_data.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ set_fit(
4040
eng = "earth",
4141
mode = "regression",
4242
value = list(
43-
interface = "data.frame",
44-
protect = c("x", "y", "weights"),
43+
interface = "formula",
44+
protect = c("formula", "data", "weights"),
4545
func = c(pkg = "earth", fun = "earth"),
4646
defaults = list(keepxy = TRUE)
4747
)
@@ -59,8 +59,8 @@ set_fit(
5959
eng = "earth",
6060
mode = "classification",
6161
value = list(
62-
interface = "data.frame",
63-
protect = c("x", "y", "weights"),
62+
interface = "formula",
63+
protect = c("formula", "data", "weights"),
6464
func = c(pkg = "earth", fun = "earth"),
6565
defaults = list(keepxy = TRUE)
6666
)

R/misc.R

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -115,22 +115,6 @@ convert_arg <- function(x) {
115115
x
116116
}
117117

118-
make_call <- function(fun, ns, args, ...) {
119-
120-
#args <- map(args, convert_arg)
121-
122-
# remove any null or placeholders (`missing_args`) that remain
123-
discard <-
124-
vapply(args, function(x)
125-
is_missing_arg(x) | is.null(x), logical(1))
126-
args <- args[!discard]
127-
128-
if (!is.null(ns) & !is.na(ns)) {
129-
out <- call2(fun, !!!args, .ns = ns)
130-
} else
131-
out <- call2(fun, !!!args)
132-
out
133-
}
134118

135119
levels_from_formula <- function(f, dat) {
136120
if (inherits(dat, "tbl_spark"))

R/mlp.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
#' @param dropout A number between 0 (inclusive) and 1 denoting the proportion
3838
#' of model parameters randomly set to zero during model training.
3939
#' @param epochs An integer for the number of training iterations.
40-
#' @param activation A single character strong denoting the type of relationship
40+
#' @param activation A single character string denoting the type of relationship
4141
#' between the original predictors and the hidden unit layer. The activation
4242
#' function between the hidden and output layers is automatically set to either
4343
#' "linear" or "softmax" depending on the type of outcome. Possible values are:
@@ -105,7 +105,7 @@ print.mlp <- function(x, ...) {
105105
#'
106106
#' @export
107107
#' @inheritParams update.boost_tree
108-
#' @param object A random forest model specification.
108+
#' @param object A multilayer perceptron model specification.
109109
#' @examples
110110
#' model <- mlp(hidden_units = 10, dropout = 0.30)
111111
#' model

R/multinom_reg_data.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ set_fit(
128128
mode = "classification",
129129
value = list(
130130
interface = "formula",
131+
data = c(formula = "formula", data = "x"),
131132
protect = c("x", "formula", "weight_col"),
132133
func = c(pkg = "sparklyr", fun = "ml_logistic_regression"),
133134
defaults = list(family = "multinomial")

R/predict.R

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ format_num <- function(x) {
191191
names(x) <- paste0(".pred_", names(x))
192192
}
193193
} else {
194-
x <- tibble(.pred = x)
194+
x <- tibble(.pred = unname(x))
195195
}
196196

197197
x
@@ -201,14 +201,15 @@ format_class <- function(x) {
201201
if (inherits(x, "tbl_spark"))
202202
return(x)
203203

204-
tibble(.pred_class = x)
204+
tibble(.pred_class = unname(x))
205205
}
206206

207207
format_classprobs <- function(x) {
208208
if (!any(grepl("^\\.pred_", names(x)))) {
209209
names(x) <- paste0(".pred_", names(x))
210210
}
211211
x <- as_tibble(x)
212+
x <- purrr::map_dfr(x, rlang::set_names, NULL)
212213
x
213214
}
214215

R/rand_forest_data.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,7 @@ set_fit(
496496
mode = "classification",
497497
value = list(
498498
interface = "formula",
499+
data = c(formula = "formula", data = "x"),
499500
protect = c("x", "formula", "type"),
500501
func = c(pkg = "sparklyr", fun = "ml_random_forest"),
501502
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))
@@ -515,6 +516,7 @@ set_fit(
515516
mode = "regression",
516517
value = list(
517518
interface = "formula",
519+
data = c(formula = "formula", data = "x"),
518520
protect = c("x", "formula", "type"),
519521
func = c(pkg = "sparklyr", fun = "ml_random_forest"),
520522
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))

0 commit comments

Comments
 (0)