Skip to content

Commit b8dbf45

Browse files
committed
Add warnings when the argument is corrected
1 parent 636de7c commit b8dbf45

17 files changed

+195
-37
lines changed

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ export(make_classes)
132132
export(mars)
133133
export(maybe_data_frame)
134134
export(maybe_matrix)
135+
export(min_cols)
136+
export(min_rows)
135137
export(mlp)
136138
export(model_printer)
137139
export(multi_predict)

NEWS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44

55
* For three models (`glmnet`, `xgboost`, and `ranger`), enable sparse matrix use via `fit_xy()` (#373).
66

7-
* Some added protections were added for function arguments that are dependent on the data dimensions (e.g., `mtry`, `neighbors`, `min_n`, etc).
7+
* Protections were added for function arguments that depend on the data dimensions (e.g., `mtry`, `neighbors`, `min_n`, etc.). (#184)
8+
9+
* Infrastructure was improved for running `parsnip` models in parallel using PSOCK clusters on Windows.
810

911
# parsnip 0.1.3
1012

R/arguments.R

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,3 +204,66 @@ make_xy_call <- function(object, target) {
204204

205205
fit_call
206206
}
207+
208+
## -----------------------------------------------------------------------------
209+
#' Execution-time data dimension checks
210+
#'
211+
#' For some tuning parameters, the range of values depends on the data
212+
#' dimensions (e.g. `mtry`). Some packages will fail if the parameter values are
213+
#' outside of these ranges. Since the model might receive resampled versions of
214+
#' the data, these ranges can't be set prior to the point where the model is
215+
#' fit. These functions check the range allowed by the data and adjust the value
216+
#' if needed (with a warning).
217+
#'
218+
#' @param num_cols,num_rows The parameter value requested by the user.
219+
#' @param source A data frame for the data to be used in the fit. If the source
220+
#' is named "data", it is assumed that one column of the data corresponds to
221+
#' an outcome (and is subtracted off).
222+
#' @param offset A number subtracted off of the number of rows available in the
223+
#' data.
224+
#' @return An integer (and perhaps a warning).
225+
#' @examples
226+
227+
#' nearest_neighbor(neighbors = 100) %>%
228+
#' set_engine("kknn") %>%
229+
#' set_mode("regression") %>%
230+
#' translate()
231+
#'
232+
#' library(ranger)
233+
#' rand_forest(mtry = 2, min_n = 100, trees = 3) %>%
234+
#' set_engine("ranger") %>%
235+
#' set_mode("regression") %>%
236+
#' fit(mpg ~ ., data = mtcars)
237+
#' @export
238+
min_cols <- function(num_cols, source) {
  # Capture the call so the *expression* supplied for `source` can be inspected
  # (not its value): the name the caller used decides how columns are counted.
  cl <- match.call()

  # When the argument is literally the symbol `data`, one column is assumed to
  # hold the outcome and is not counted as a predictor. Comparing with
  # `identical()` against a symbol avoids relying on `==` implicitly deparsing
  # a language object, and yields FALSE for any other expression
  # (e.g. `x` or `data[, -1]`).
  source_is_data <- identical(cl$source, quote(data))

  if (source_is_data) {
    p <- ncol(source) - 1
  } else {
    p <- ncol(source)
  }

  # Cap the request at the number of available predictors, warning the user
  # that the value was changed.
  if (num_cols > p) {
    msg <- paste0(num_cols, " columns were requested but there were ", p,
                  " predictors in the data. ", p, " will be used.")
    rlang::warn(msg)
    num_cols <- p
  }

  # Always return an integer, whether or not the value was adjusted.
  as.integer(num_cols)
}
255+
256+
#' @export
257+
#' @rdname min_cols
258+
min_rows <- function(num_rows, source, offset = 0) {
  # Rows present in the data, and the largest value that may be requested once
  # `offset` rows have been subtracted.
  available <- nrow(source)
  upper <- available - offset

  # Request is within bounds: hand it back unchanged (as an integer).
  if (!(num_rows > upper)) {
    return(as.integer(num_rows))
  }

  # Request exceeds what the data allow: warn and fall back to the cap.
  rlang::warn(
    paste0(
      num_rows, " samples were requested but there were ", available,
      " rows in the data. ", upper, " will be used."
    )
  )
  as.integer(upper)
}

R/boost_tree.R

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
244244
# min_n parameters
245245
if (any(names(arg_vals) == "min_instances_per_node")) {
246246
arg_vals$min_instances_per_node <-
247-
rlang::call2("min", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(nrow(x)))
247+
rlang::call2("min_rows", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(x))
248248
}
249249

250250
## -----------------------------------------------------------------------------
@@ -357,6 +357,13 @@ xgb_train <- function(
357357
colsample_bytree <- 1
358358
}
359359

360+
if (min_child_weight > n) {
361+
msg <- paste0(min_child_weight, " samples were requested but there were ",
362+
n, " rows in the data. ", n, " will be used.")
363+
rlang::warn(msg)
364+
min_child_weight <- min(min_child_weight, n)
365+
}
366+
360367
arg_list <- list(
361368
eta = eta,
362369
max_depth = max_depth,
@@ -537,8 +544,21 @@ C5.0_train <-
537544
ctrl_args <- other_args[names(other_args) %in% c_names]
538545
fit_args <- other_args[names(other_args) %in% f_names]
539546

547+
n <- nrow(x)
548+
if (n == 0) {
549+
rlang::abort("There are zero rows in the predictor set.")
550+
}
551+
552+
540553
ctrl <- call2("C5.0Control", .ns = "C50")
541-
ctrl$minCases <- min(minCases, nrow(x))
554+
if (minCases > n) {
555+
msg <- paste0(minCases, " samples were requested but there were ",
556+
n, " rows in the data. ", n, " will be used.")
557+
rlang::warn(msg)
558+
minCases <- n
559+
}
560+
ctrl$minCases <- minCases
561+
542562
ctrl$sample <- sample
543563
ctrl <- rlang::call_modify(ctrl, !!!ctrl_args)
544564

R/decision_tree.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,11 +185,11 @@ translate.decision_tree <- function(x, engine = x$engine, ...) {
185185

186186
if (any(names(arg_vals) == "minsplit")) {
187187
arg_vals$minsplit <-
188-
rlang::call2("min", rlang::eval_tidy(arg_vals$minsplit), expr(nrow(data)))
188+
rlang::call2("min_rows", rlang::eval_tidy(arg_vals$minsplit), expr(data))
189189
}
190190
if (any(names(arg_vals) == "min_instances_per_node")) {
191191
arg_vals$min_instances_per_node <-
192-
rlang::call2("min", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(nrow(x)))
192+
rlang::call2("min_rows", rlang::eval_tidy(arg_vals$min_instances_per_node), expr(x))
193193
}
194194

195195
## -----------------------------------------------------------------------------

R/nearest_neighbor.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ translate.nearest_neighbor <- function(x, engine = x$engine, ...) {
181181

182182
if (any(names(arg_vals) == "ks")) {
183183
arg_vals$ks <-
184-
rlang::call2("min", rlang::eval_tidy(arg_vals$ks), expr(nrow(data) - 5))
184+
rlang::call2("min_rows", rlang::eval_tidy(arg_vals$ks), expr(data), 5)
185185
}
186186
}
187187

R/rand_forest.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,20 +203,20 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {
203203
# Protect some arguments based on data dimensions
204204

205205
if (any(names(arg_vals) == "mtry")) {
206-
arg_vals$mtry <- rlang::call2("min", arg_vals$mtry, expr(ncol(x)))
206+
arg_vals$mtry <- rlang::call2("min_cols", arg_vals$mtry, expr(x))
207207
}
208208

209209
if (any(names(arg_vals) == "min.node.size")) {
210210
arg_vals$min.node.size <-
211-
rlang::call2("min", arg_vals$min.node.size, expr(nrow(x)))
211+
rlang::call2("min_rows", arg_vals$min.node.size, expr(x))
212212
}
213213
if (any(names(arg_vals) == "nodesize")) {
214214
arg_vals$nodesize <-
215-
rlang::call2("min", arg_vals$nodesize, expr(nrow(x)))
215+
rlang::call2("min_rows", arg_vals$nodesize, expr(x))
216216
}
217217
if (any(names(arg_vals) == "min_instances_per_node")) {
218218
arg_vals$min_instances_per_node <-
219-
rlang::call2("min", arg_vals$min_instances_per_node, expr(nrow(x)))
219+
rlang::call2("min_rows", arg_vals$min_instances_per_node, expr(x))
220220
}
221221

222222
## -----------------------------------------------------------------------------

man/boost_tree.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/min_cols.Rd

Lines changed: 44 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/nearest_neighbor.Rd

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_boost_tree.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,6 @@ test_that('argument checks for data dimensions', {
162162
set_mode("classification")
163163

164164
args <- translate(spec)$method$fit$args
165-
expect_equal(args$min_instances_per_node, expr(min(1000, nrow(x))))
165+
expect_equal(args$min_instances_per_node, expr(min_rows(1000, x)))
166166
})
167167

tests/testthat/test_decision_tree.R

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ test_that('primary arguments', {
5959
formula = expr(missing_arg()),
6060
data = expr(missing_arg()),
6161
weights = expr(missing_arg()),
62-
minsplit = expr(min(15, nrow(data)))
62+
minsplit = expr(min_rows(15, data))
6363
)
6464
)
6565

@@ -163,8 +163,14 @@ test_that('argument checks for data dimensions', {
163163
set_engine("rpart") %>%
164164
set_mode("regression")
165165

166-
f_fit <- spec %>% fit(body_mass_g ~ ., data = penguins)
167-
xy_fit <- spec %>% fit_xy(x = penguins[, -6], y = penguins$body_mass_g)
166+
expect_warning(
167+
f_fit <- spec %>% fit(body_mass_g ~ ., data = penguins),
168+
"1000 samples were requested but there were 333 rows in the data. 333 will be used."
169+
)
170+
expect_warning(
171+
xy_fit <- spec %>% fit_xy(x = penguins[, -6], y = penguins$body_mass_g),
172+
"1000 samples were requested but there were 333 rows in the data. 333 will be used."
173+
)
168174

169175
expect_equal(f_fit$fit$control$minsplit, nrow(penguins))
170176
expect_equal(xy_fit$fit$control$minsplit, nrow(penguins))
@@ -175,6 +181,6 @@ test_that('argument checks for data dimensions', {
175181
set_mode("regression")
176182

177183
args <- translate(spec)$method$fit$args
178-
expect_equal(args$min_instances_per_node, rlang::expr(min(1000, nrow(x))))
184+
expect_equal(args$min_instances_per_node, rlang::expr(min_rows(1000, x)))
179185

180186
})

tests/testthat/test_nearest_neighbor.R

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ test_that('primary arguments', {
1818
expected = list(
1919
formula = expr(missing_arg()),
2020
data = expr(missing_arg()),
21-
ks = expr(min(5, nrow(data) - 5))
21+
ks = expr(min_rows(5, data, 5))
2222
)
2323
)
2424

@@ -30,7 +30,7 @@ test_that('primary arguments', {
3030
expected = list(
3131
formula = expr(missing_arg()),
3232
data = expr(missing_arg()),
33-
ks = expr(min(2, nrow(data) - 5))
33+
ks = expr(min_rows(2, data, 5))
3434
)
3535
)
3636

@@ -43,7 +43,7 @@ test_that('primary arguments', {
4343
formula = expr(missing_arg()),
4444
data = expr(missing_arg()),
4545
kernel = new_empty_quosure("triangular"),
46-
ks = expr(min(5, nrow(data) - 5))
46+
ks = expr(min_rows(5, data, 5))
4747
)
4848
)
4949

@@ -56,7 +56,7 @@ test_that('primary arguments', {
5656
formula = expr(missing_arg()),
5757
data = expr(missing_arg()),
5858
distance = new_empty_quosure(2),
59-
ks = expr(min(5, nrow(data) - 5))
59+
ks = expr(min_rows(5, data, 5))
6060
)
6161
)
6262

@@ -72,7 +72,7 @@ test_that('engine arguments', {
7272
formula = expr(missing_arg()),
7373
data = expr(missing_arg()),
7474
scale = new_empty_quosure(FALSE),
75-
ks = expr(min(5, nrow(data) - 5))
75+
ks = expr(min_rows(5, data, 5))
7676
)
7777
)
7878

tests/testthat/test_nearest_neighbor_kknn.R

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,15 @@ test_that('argument checks for data dimensions', {
207207
set_engine("kknn") %>%
208208
set_mode("regression")
209209

210-
f_fit <- spec %>% fit(body_mass_g ~ ., data = penguins)
211-
xy_fit <- spec %>% fit_xy(x = penguins[, -6], y = penguins$body_mass_g)
210+
expect_warning(
211+
f_fit <- spec %>% fit(body_mass_g ~ ., data = penguins),
212+
"1000 samples were requested but there were 333 rows in the data. 328 will be used."
213+
)
214+
215+
expect_warning(
216+
xy_fit <- spec %>% fit_xy(x = penguins[, -6], y = penguins$body_mass_g),
217+
"1000 samples were requested but there were 333 rows in the data. 328 will be used."
218+
)
212219

213220
expect_equal(f_fit$fit$best.parameters$k, nrow(penguins) - 5)
214221
expect_equal(xy_fit$fit$best.parameters$k, nrow(penguins) - 5)

0 commit comments

Comments
 (0)