Skip to content

Commit 8af904b

Browse files
authored
Merge pull request #377 from tidymodels/check-arg-dimensions
Make arguments robust to values outside of data dimensions
2 parents 981bc68 + 79e39b3 commit 8af904b

18 files changed

+436
-169
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ export(make_classes)
132132
export(mars)
133133
export(maybe_data_frame)
134134
export(maybe_matrix)
135+
export(min_cols)
136+
export(min_rows)
135137
export(mlp)
136138
export(model_printer)
137139
export(multi_predict)

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44

55
* For three models (`glmnet`, `xgboost`, and `ranger`), enable sparse matrix use via `fit_xy()` (#373).
66

7+
* Protections were added for function arguments that depend on the data dimensions (e.g., `mtry`, `neighbors`, `min_n`). (#184)
8+
9+
* Infrastructure was improved for running `parsnip` models in parallel using PSOCK clusters on Windows.
10+
711
# parsnip 0.1.3
812

913
* A `glance()` method for `model_fit` objects was added (#325)

R/arguments.R

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,66 @@ make_xy_call <- function(object, target) {
204204

205205
fit_call
206206
}
207+
208+
## -----------------------------------------------------------------------------
209+
#' Execution-time data dimension checks
210+
#'
211+
#' For some tuning parameters, the range of values depends on the data
212+
#' dimensions (e.g. `mtry`). Some packages will fail if the parameter values are
213+
#' outside of these ranges. Since the model might receive resampled versions of
214+
#' the data, these ranges can't be set prior to the point where the model is
215+
#' fit. These functions check the possible range given the data and adjust the
216+
#' parameter values if needed (with a warning).
217+
#'
218+
#' @param num_cols,num_rows The parameter value requested by the user.
219+
#' @param source A data frame for the data to be used in the fit. If the source
220+
#' is named "data", it is assumed that one column of the data corresponds to
221+
#' an outcome (and is subtracted off).
222+
#' @param offset A number subtracted from the number of rows available in the
223+
#' data.
224+
#' @return An integer (and perhaps a warning).
225+
#' @examples
226+
227+
#' nearest_neighbor(neighbors= 100) %>%
228+
#' set_engine("kknn") %>%
229+
#' set_mode("regression") %>%
230+
#' translate()
231+
#'
232+
#' library(ranger)
233+
#' rand_forest(mtry = 2, min_n = 100, trees = 3) %>%
234+
#' set_engine("ranger") %>%
235+
#' set_mode("regression") %>%
236+
#' fit(mpg ~ ., data = mtcars)
237+
#' @export
238+
min_cols <- function(num_cols, source) {
239+
cl <- match.call()
240+
src_name <- rlang::expr_text(cl$source)
241+
if (cl$source == "data") {
242+
p <- ncol(source) - 1
243+
} else {
244+
p <- ncol(source)
245+
}
246+
if (num_cols > p) {
247+
msg <- paste0(num_cols, " columns were requested but there were ", p,
248+
" predictors in the data. ", p, " will be used.")
249+
rlang::warn(msg)
250+
num_cols <- p
251+
}
252+
253+
as.integer(num_cols)
254+
}
255+
256+
#' @export
257+
#' @rdname min_cols
258+
min_rows <- function(num_rows, source, offset = 0) {
259+
n <- nrow(source)
260+
261+
if (num_rows > n - offset) {
262+
msg <- paste0(num_rows, " samples were requested but there were ", n,
263+
" rows in the data. ", n - offset, " will be used.")
264+
rlang::warn(msg)
265+
num_rows <- n - offset
266+
}
267+
268+
as.integer(num_rows)
269+
}

R/boost_tree.R

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,10 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
221221
}
222222
x <- translate.default(x, engine, ...)
223223

224+
## -----------------------------------------------------------------------------
225+
226+
arg_vals <- x$method$fit$args
227+
224228
if (engine == "spark") {
225229
if (x$mode == "unknown") {
226230
rlang::abort(
@@ -230,9 +234,23 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
230234
)
231235
)
232236
} else {
233-
x$method$fit$args$type <- x$mode
237+
arg_vals$type <- x$mode
234238
}
235239
}
240+
241+
## -----------------------------------------------------------------------------
242+
# Protect some arguments based on data dimensions
243+
244+
# min_n parameters
245+
if (any(names(arg_vals) == "min_instances_per_node")) {
246+
arg_vals$min_instances_per_node <-
247+
rlang::call2("min_rows", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(x))
248+
}
249+
250+
## -----------------------------------------------------------------------------
251+
252+
x$method$fit$args <- arg_vals
253+
236254
x
237255
}
238256

@@ -242,14 +260,18 @@ check_args.boost_tree <- function(object) {
242260

243261
args <- lapply(object$args, rlang::eval_tidy)
244262

245-
if (is.numeric(args$trees) && args$trees < 0)
263+
if (is.numeric(args$trees) && args$trees < 0) {
246264
rlang::abort("`trees` should be >= 1.")
247-
if (is.numeric(args$sample_size) && (args$sample_size < 0 | args$sample_size > 1))
265+
}
266+
if (is.numeric(args$sample_size) && (args$sample_size < 0 | args$sample_size > 1)) {
248267
rlang::abort("`sample_size` should be within [0,1].")
249-
if (is.numeric(args$tree_depth) && args$tree_depth < 0)
268+
}
269+
if (is.numeric(args$tree_depth) && args$tree_depth < 0) {
250270
rlang::abort("`tree_depth` should be >= 1.")
251-
if (is.numeric(args$min_n) && args$min_n < 0)
271+
}
272+
if (is.numeric(args$min_n) && args$min_n < 0) {
252273
rlang::abort("`min_n` should be >= 1.")
274+
}
253275

254276
invisible(object)
255277
}
@@ -335,12 +357,19 @@ xgb_train <- function(
335357
colsample_bytree <- 1
336358
}
337359

360+
if (min_child_weight > n) {
361+
msg <- paste0(min_child_weight, " samples were requested but there were ",
362+
n, " rows in the data. ", n, " will be used.")
363+
rlang::warn(msg)
364+
min_child_weight <- min(min_child_weight, n)
365+
}
366+
338367
arg_list <- list(
339368
eta = eta,
340369
max_depth = max_depth,
341370
gamma = gamma,
342371
colsample_bytree = colsample_bytree,
343-
min_child_weight = min_child_weight,
372+
min_child_weight = min(min_child_weight, n),
344373
subsample = subsample
345374
)
346375

@@ -515,8 +544,21 @@ C5.0_train <-
515544
ctrl_args <- other_args[names(other_args) %in% c_names]
516545
fit_args <- other_args[names(other_args) %in% f_names]
517546

547+
n <- nrow(x)
548+
if (n == 0) {
549+
rlang::abort("There are zero rows in the predictor set.")
550+
}
551+
552+
518553
ctrl <- call2("C5.0Control", .ns = "C50")
554+
if (minCases > n) {
555+
msg <- paste0(minCases, " samples were requested but there were ",
556+
n, " rows in the data. ", n, " will be used.")
557+
rlang::warn(msg)
558+
minCases <- n
559+
}
519560
ctrl$minCases <- minCases
561+
520562
ctrl$sample <- sample
521563
ctrl <- rlang::call_modify(ctrl, !!!ctrl_args)
522564

R/decision_tree.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,22 @@ translate.decision_tree <- function(x, engine = x$engine, ...) {
180180
}
181181
}
182182

183+
## -----------------------------------------------------------------------------
184+
# Protect some arguments based on data dimensions
185+
186+
if (any(names(arg_vals) == "minsplit")) {
187+
arg_vals$minsplit <-
188+
rlang::call2("min_rows", rlang::eval_tidy(arg_vals$minsplit), expr(data))
189+
}
190+
if (any(names(arg_vals) == "min_instances_per_node")) {
191+
arg_vals$min_instances_per_node <-
192+
rlang::call2("min_rows", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(x))
193+
}
194+
195+
## -----------------------------------------------------------------------------
196+
197+
x$method$fit$args <- arg_vals
198+
183199
x
184200
}
185201

R/nearest_neighbor.R

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,12 +168,25 @@ translate.nearest_neighbor <- function(x, engine = x$engine, ...) {
168168
}
169169
x <- translate.default(x, engine, ...)
170170

171+
arg_vals <- x$method$fit$args
172+
171173
if (engine == "kknn") {
172-
if (!any(names(x$method$fit$args) == "ks") ||
173-
is_missing_arg(x$method$fit$args$ks)) {
174-
x$method$fit$args$ks <- 5
174+
175+
if (!any(names(arg_vals) == "ks") || is_missing_arg(arg_vals$ks)) {
176+
arg_vals$ks <- 5
177+
}
178+
179+
## -----------------------------------------------------------------------------
180+
# Protect some arguments based on data dimensions
181+
182+
if (any(names(arg_vals) == "ks")) {
183+
arg_vals$ks <-
184+
rlang::call2("min_rows", rlang::eval_tidy(arg_vals$ks), expr(data), 5)
175185
}
176186
}
187+
188+
x$method$fit$args <- arg_vals
189+
177190
x
178191
}
179192

R/rand_forest.R

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {
161161

162162
x <- translate.default(x, engine, ...)
163163

164+
## -----------------------------------------------------------------------------
165+
164166
# slightly cleaner code using a local copy of the fit arguments
165167
arg_vals <- x$method$fit$args
166168

@@ -185,14 +187,47 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {
185187

186188
# add checks to error trap or change things for this method
187189
if (engine == "ranger") {
188-
if (any(names(arg_vals) == "importance"))
189-
if (isTRUE(is.logical(quo_get_expr(arg_vals$importance))))
190+
191+
if (any(names(arg_vals) == "importance")) {
192+
if (isTRUE(is.logical(quo_get_expr(arg_vals$importance)))) {
190193
rlang::abort("`importance` should be a character value. See ?ranger::ranger.")
194+
}
195+
}
191196
# unless otherwise specified, classification models are probability forests
192-
if (x$mode == "classification" && !any(names(arg_vals) == "probability"))
197+
if (x$mode == "classification" && !any(names(arg_vals) == "probability")) {
193198
arg_vals$probability <- TRUE
199+
}
200+
}
201+
202+
## -----------------------------------------------------------------------------
203+
# Protect some arguments based on data dimensions
204+
205+
if (any(names(arg_vals) == "mtry") & engine != "cforest") {
206+
arg_vals$mtry <- rlang::call2("min_cols", arg_vals$mtry, expr(x))
207+
}
208+
if (any(names(arg_vals) == "mtry") & engine == "cforest") {
209+
arg_vals$mtry <- rlang::call2("min_cols", arg_vals$mtry, expr(data))
210+
}
194211

212+
if (any(names(arg_vals) == "min.node.size")) {
213+
arg_vals$min.node.size <-
214+
rlang::call2("min_rows", arg_vals$min.node.size, expr(x))
215+
}
216+
if (any(names(arg_vals) == "minsplit" & engine == "cforest")) {
217+
arg_vals$minsplit <-
218+
rlang::call2("min_rows", arg_vals$minsplit, expr(data))
219+
}
220+
if (any(names(arg_vals) == "nodesize")) {
221+
arg_vals$nodesize <-
222+
rlang::call2("min_rows", arg_vals$nodesize, expr(x))
195223
}
224+
if (any(names(arg_vals) == "min_instances_per_node")) {
225+
arg_vals$min_instances_per_node <-
226+
rlang::call2("min_rows", arg_vals$min_instances_per_node, expr(x))
227+
}
228+
229+
## -----------------------------------------------------------------------------
230+
196231
x$method$fit$args <- arg_vals
197232

198233
x

man/min_cols.Rd

Lines changed: 44 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/nearest_neighbor.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_boost_tree.R

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,4 +151,17 @@ test_that('bad input', {
151151
expect_error(translate(boost_tree(formula = y ~ x)))
152152
})
153153

154-
# ------------------------------------------------------------------------------
154+
155+
## -----------------------------------------------------------------------------
156+
157+
test_that('argument checks for data dimensions', {
158+
159+
spec <-
160+
boost_tree(mtry = 1000, min_n = 1000, trees = 5) %>%
161+
set_engine("spark") %>%
162+
set_mode("classification")
163+
164+
args <- translate(spec)$method$fit$args
165+
expect_equal(args$min_instances_per_node, expr(min_rows(1000, x)))
166+
})
167+

0 commit comments

Comments
 (0)