Skip to content

Handling alternative data argument names and fixing calls #316

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ export(predict_quantile.model_fit)
export(predict_raw)
export(predict_raw.model_fit)
export(rand_forest)
export(repair_call)
export(rpart_train)
export(set_args)
export(set_dependency)
Expand Down
10 changes: 10 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,20 @@

* `tidyr` >= 1.0.0 is now required.

* SVM models produced by `kernlab` now use the formula method. This change was due to how `ksvm()` made indicator variables for factor predictors (with one-hot encodings). Since the ordinary formula method did not do this, the data are passed as-is to `ksvm()` so that the results are closer to what one would get if `ksmv()` were called directly.

* MARS models produced by `earth` now use the formula method.

* Under-the-hood changes were made so that non-standard data arguments in the modeling packages can be accomodated. (#315)

## New Features

* A new main argument was added to `boost_tree()` called `stop_iter` for early stopping. The `xgb_train()` function gained arguments for early stopping and a percentage of data to leave out for a validation set.

* If `fit()` is used and the underlying model uses a formula, the _actual_ formula is pass to the model (instead of a placeholder). This makes the model call better.

* A function named `repair_call()` was added. This can help change the underlying models `call` object to better reflect what they would have obtained if the model function had been used directly (instead of via `parsnip`). This is only useful when the user chooses a formula interface and the model uses a formula interface. It will also be of limited use when a recipes is used to construct the feature set in `workflows` or `tune`.

# parsnip 0.1.1

## New Features
Expand Down
2 changes: 1 addition & 1 deletion R/aaa.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ utils::globalVariables(
'lab', 'original', 'predicted_label', 'prediction', 'value', 'type',
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
"max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees",
"sub_neighbors", ".pred_class")
"sub_neighbors", ".pred_class", "x", "y")
)

# nocov end
23 changes: 22 additions & 1 deletion R/aaa_models.R
Original file line number Diff line number Diff line change
Expand Up @@ -195,14 +195,35 @@ check_fit_info <- function(fit_obj) {
if (is.null(fit_obj)) {
rlang::abort("The `fit` module cannot be NULL.")
}

# check required data elements
exp_nms <- c("defaults", "func", "interface", "protect")
if (!isTRUE(all.equal(sort(names(fit_obj)), exp_nms))) {
has_req_nms <- exp_nms %in% names(fit_obj)

if (!all(has_req_nms)) {
rlang::abort(
glue::glue("The `fit` module should have elements: ",
glue::glue_collapse(glue::glue("`{exp_nms}`"), sep = ", "))
)
}

# check optional data elements
opt_nms <- c("data")
other_nms <- setdiff(exp_nms, names(fit_obj))
has_opt_nms <- other_nms %in% opt_nms
if (any(!has_opt_nms)) {
msg <- glue::glue("The `fit` module can only have optional elements: ",
glue::glue_collapse(glue::glue("`{exp_nms}`"), sep = ", "))

rlang::abort(msg)
}
if (any(other_nms == "data")) {
data_nms <- names(fit_obj$data)
if (length(data_nms == 0) || any(data_nms == "")) {
rlang::abort("All elements of the `data` argument vector must be named.")
}
}

check_interface_val(fit_obj$interface)
check_func_val(fit_obj$func)

Expand Down
93 changes: 92 additions & 1 deletion R/arguments.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ check_eng_args <- function(args, obj, core_args) {
if (length(common_args) > 0) {
args <- args[!(names(args) %in% common_args)]
common_args <- paste0(common_args, collapse = ", ")
rlang::warn(glue::glue("The following arguments cannot be manually modified",
rlang::warn(glue::glue("The following arguments cannot be manually modified ",
"and were removed: {common_args}."))
}
args
Expand Down Expand Up @@ -113,3 +113,94 @@ eval_args <- function(spec, ...) {
spec$eng_args <- purrr::map(spec$eng_args, maybe_eval)
spec
}

# ------------------------------------------------------------------------------

# In some cases, a model function that we are calling has non-standard argument
# names. For example, a function foo() that only has the x/y interface might
# have a signature like `foo(X, Y)`.

# To deal with this, we allow for the `data` element of the model
# as an option to specify these actual argument names
#
# value = list(
# interface = "xy",
# data = c(x = "X", y = "Y"),
# protect = c("X", "Y"),
# func = c(pkg = "bar", fun = "foo"),
# defaults = list()
# )

make_call <- function(fun, ns, args, ...) {
# remove any null or placeholders (`missing_args`) that remain
discard <-
vapply(args, function(x)
is_missing_arg(x) | is.null(x), logical(1))
args <- args[!discard]

if (!is.null(ns) & !is.na(ns)) {
out <- call2(fun, !!!args, .ns = ns)
} else
out <- call2(fun, !!!args)
out
}


make_form_call <- function(object, env = NULL) {
fit_args <- object$method$fit$args

# Get the arguments related to data:
if (is.null(object$method$fit$data)) {
data_args <- c(formula = "formula", data = "data")
} else {
data_args <- object$method$fit$data
}

# add data arguments
for (i in seq_along(data_args)) {
fit_args[[ unname(data_args[i]) ]] <- sym(names(data_args)[i])
}

# sub in actual formula
fit_args[[ unname(data_args["formula"]) ]] <- env$formula

if (object$engine == "spark") {
env$x <- env$data
}

fit_call <- make_call(
fun = object$method$fit$func["fun"],
ns = object$method$fit$func["pkg"],
fit_args
)
fit_call
}

make_xy_call <- function(object, target) {
fit_args <- object$method$fit$args

# Get the arguments related to data:
if (is.null(object$method$fit$data)) {
data_args <- c(x = "x", y = "y")
} else {
data_args <- object$method$fit$data
}

object$method$fit$args[[ unname(data_args["y"]) ]] <- rlang::expr(y)
object$method$fit$args[[ unname(data_args["x"]) ]] <-
switch(
target,
none = rlang::expr(x),
data.frame = rlang::expr(as.data.frame(x)),
matrix = rlang::expr(as.matrix(x)),
rlang::abort(glue::glue("Invalid data type target: {target}."))
)

fit_call <- make_call(
fun = object$method$fit$func["fun"],
ns = object$method$fit$func["pkg"],
object$method$fit$args
)

fit_call
}
2 changes: 2 additions & 0 deletions R/boost_tree_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ set_fit(
mode = "regression",
value = list(
interface = "formula",
data = c(formula = "formula", data = "x"),
protect = c("x", "formula", "type"),
func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"),
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))
Expand All @@ -349,6 +350,7 @@ set_fit(
mode = "classification",
value = list(
interface = "formula",
data = c(formula = "formula", data = "x"),
protect = c("x", "formula", "type"),
func = c(pkg = "sparklyr", fun = "ml_gradient_boosted_trees"),
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))
Expand Down
2 changes: 2 additions & 0 deletions R/decision_tree_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,7 @@ set_fit(
mode = "regression",
value = list(
interface = "formula",
data = c(formula = "formula", data = "x"),
protect = c("x", "formula"),
func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"),
defaults =
Expand All @@ -250,6 +251,7 @@ set_fit(
mode = "classification",
value = list(
interface = "formula",
data = c(formula = "formula", data = "x"),
protect = c("x", "formula"),
func = c(pkg = "sparklyr", fun = "ml_decision_tree_classifier"),
defaults =
Expand Down
32 changes: 2 additions & 30 deletions R/fit_helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,7 @@ form_form <-
# sub in arguments to actual syntax for corresponding engine
object <- translate(object, engine = object$engine)

fit_args <- object$method$fit$args

if (is_spark(object)) {
fit_args$x <- quote(x)
env$x <- env$data
} else {
fit_args$data <- quote(data)
}
fit_args$formula <- quote(formula)

fit_call <- make_call(
fun = object$method$fit$func["fun"],
ns = object$method$fit$func["pkg"],
fit_args
)
fit_call <- make_form_call(object, env = env)

res <- list(
lvl = y_levels,
Expand Down Expand Up @@ -89,21 +75,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
# sub in arguments to actual syntax for corresponding engine
object <- translate(object, engine = object$engine)

object$method$fit$args[["y"]] <- quote(y)
object$method$fit$args[["x"]] <-
switch(
target,
none = quote(x),
data.frame = quote(as.data.frame(x)),
matrix = quote(as.matrix(x)),
rlang::abort(glue::glue("Invalid data type target: {target}."))
)

fit_call <- make_call(
fun = object$method$fit$func["fun"],
ns = object$method$fit$func["pkg"],
object$method$fit$args
)
fit_call <- make_xy_call(object, target)

res <- list(lvl = levels(env$y), spec = object)

Expand Down
1 change: 1 addition & 0 deletions R/linear_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ set_fit(
mode = "regression",
value = list(
interface = "formula",
data = c(formula = "formula", data = "x"),
protect = c("x", "formula", "weight_col"),
func = c(pkg = "sparklyr", fun = "ml_linear_regression"),
defaults = list()
Expand Down
1 change: 1 addition & 0 deletions R/logistic_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ set_fit(
mode = "classification",
value = list(
interface = "formula",
data = c(formula = "formula", data = "x"),
protect = c("x", "formula", "weight_col"),
func = c(pkg = "sparklyr", fun = "ml_logistic_regression"),
defaults =
Expand Down
8 changes: 4 additions & 4 deletions R/mars_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ set_fit(
eng = "earth",
mode = "regression",
value = list(
interface = "data.frame",
protect = c("x", "y", "weights"),
interface = "formula",
protect = c("formula", "data", "weights"),
func = c(pkg = "earth", fun = "earth"),
defaults = list(keepxy = TRUE)
)
Expand All @@ -52,8 +52,8 @@ set_fit(
eng = "earth",
mode = "classification",
value = list(
interface = "data.frame",
protect = c("x", "y", "weights"),
interface = "formula",
protect = c("formula", "data", "weights"),
func = c(pkg = "earth", fun = "earth"),
defaults = list(keepxy = TRUE)
)
Expand Down
16 changes: 0 additions & 16 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,22 +115,6 @@ convert_arg <- function(x) {
x
}

make_call <- function(fun, ns, args, ...) {

#args <- map(args, convert_arg)

# remove any null or placeholders (`missing_args`) that remain
discard <-
vapply(args, function(x)
is_missing_arg(x) | is.null(x), logical(1))
args <- args[!discard]

if (!is.null(ns) & !is.na(ns)) {
out <- call2(fun, !!!args, .ns = ns)
} else
out <- call2(fun, !!!args)
out
}

levels_from_formula <- function(f, dat) {
if (inherits(dat, "tbl_spark"))
Expand Down
1 change: 1 addition & 0 deletions R/multinom_reg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ set_fit(
mode = "classification",
value = list(
interface = "formula",
data = c(formula = "formula", data = "x"),
protect = c("x", "formula", "weight_col"),
func = c(pkg = "sparklyr", fun = "ml_logistic_regression"),
defaults = list(family = "multinomial")
Expand Down
2 changes: 2 additions & 0 deletions R/rand_forest_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,7 @@ set_fit(
mode = "classification",
value = list(
interface = "formula",
data = c(formula = "formula", data = "x"),
protect = c("x", "formula", "type"),
func = c(pkg = "sparklyr", fun = "ml_random_forest"),
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))
Expand All @@ -480,6 +481,7 @@ set_fit(
mode = "regression",
value = list(
interface = "formula",
data = c(formula = "formula", data = "x"),
protect = c("x", "formula", "type"),
func = c(pkg = "sparklyr", fun = "ml_random_forest"),
defaults = list(seed = expr(sample.int(10 ^ 5, 1)))
Expand Down
52 changes: 52 additions & 0 deletions R/repair_call.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#' Repair a model call object
#'
#' When the user passes a formula to `fit()` _and_ the underyling model function
#' uses a formula, the call object produced by `fit()` may not be usable by
#' other functions. For example, some arguments may still be quosures and the
#' `data` portion of the call will not correspond to the original data.
#'
#' `repair_call()` call can adjust the model objects call to be usable by other
#' functions and methods.
#' @param x A fitted `parsnip` model. An error will occur if the underlying model
#' does not have a `call` element.
#' @param data A data object that is relavant to the call. In most cases, this
#' is the data frame that was given to `parsnip` for the model fit (i.e., the
#' training set data). The name of this data object is inserted into the call.
#' @return A modified `parsnip` fitted model.
#' @examples
#'
#' fitted_model <-
#' linear_reg() %>%
#' set_engine("lm", model = TRUE) %>%
#' fit(mpg ~ ., data = mtcars)
#'
#' # In this call, note that `data` is not `mtcars` and the `model = ~TRUE`
#' # indicates that the `model` argument is an `rlang` quosure.
#' fitted_model$fit$call
#'
#' # All better:
#' repair_call(fitted_model, mtcars)$fit$call
#' @export
repair_call <- function(x, data) {
cl <- match.call()
if (!any(names(x$fit) == "call")) {
rlang::abort("No `call` object to modify.")
}
if (rlang::is_missing(data)) {
rlang::abort("Please supply a data object to `data`.")
}
fit_call <- x$fit$call
needs_eval <- purrr::map_lgl(fit_call, rlang::is_quosure)
if (any(needs_eval)) {
eval_args <- names(needs_eval)[needs_eval]
for(arg in eval_args) {
fit_call[[arg]] <- rlang::eval_tidy(fit_call[[arg]])
}
}
if (any(names(fit_call) == "data")) {
fit_call$data <- cl$data
}

x$fit$call <- fit_call
x
}
Loading