Skip to content

Commit 5e79150

Browse files
authored
Better mtry check for formulas (#700)
* move dim checks for mtry and min_n to wrapper function * better assessments of data dims * export function and updates for multivariate cases * pkgdown entry
1 parent 993d038 commit 5e79150

File tree

6 files changed

+76
-8
lines changed

6 files changed

+76
-8
lines changed

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ export(make_classes)
219219
export(make_engine_list)
220220
export(make_seealso_list)
221221
export(mars)
222+
export(max_mtry_formula)
222223
export(maybe_data_frame)
223224
export(maybe_matrix)
224225
export(min_cols)

R/partykit.R

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ cforest_train <-
9090
force(mtry)
9191
opts <- rlang::list2(...)
9292

93+
mtry <- max_mtry_formula(mtry, formula, data)
94+
minsplit <- min(minsplit, nrow(data))
95+
9396
if (any(names(opts) == "control")) {
9497
opts$control$minsplit <- minsplit
9598
opts$control$maxdepth <- maxdepth
@@ -124,3 +127,24 @@ cforest_train <-
124127
)
125128
rlang::eval_tidy(forest_call)
126129
}
130+
131+
# ------------------------------------------------------------------------------
132+
133+
#' Determine largest value of mtry from formula.
#'
#' This function potentially caps the value of `mtry` based on a formula and
#' data set. This is a safe approach for survival and/or multivariate models.
#'
#' The cap is the number of right-hand-side terms in the expanded formula
#' (after `.` has been resolved against `data`). The result is never smaller
#' than 1, including for formulas that have no predictor terms at all
#' (e.g. `y ~ 1`).
#'
#' @param mtry An initial value of `mtry` (which may be too large).
#' @param formula A model formula.
#' @param data The training set (data frame).
#' @return A value for `mtry`: at least 1 and at most the number of terms on
#'   the right-hand side of `formula`.
#' @examples
#' # should be 9
#' max_mtry_formula(200, cbind(wt, mpg) ~ ., data = mtcars)
#' @export
max_mtry_formula <- function(mtry, formula, data) {
  # Only the terms object is needed, so the first few rows are enough to
  # expand `.` and resolve the variables.
  preds <- stats::model.frame(formula, utils::head(data))
  trms <- attr(preds, "terms")

  # Count terms via `term.labels` rather than ncol() of the "factors"
  # attribute: for an intercept-only formula "factors" is `integer(0)`, its
  # ncol() is NULL, and min(mtry, NULL) would leave `mtry` uncapped.
  p <- length(attr(trms, "term.labels"))

  max(min(mtry, p), 1L)
}

R/rand_forest.R

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,21 +163,15 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {
163163
## -----------------------------------------------------------------------------
164164
# Protect some arguments based on data dimensions
165165

166-
if (any(names(arg_vals) == "mtry") & engine != "cforest") {
166+
if (any(names(arg_vals) == "mtry") & engine != "partykit") {
167167
arg_vals$mtry <- rlang::call2("min_cols", arg_vals$mtry, expr(x))
168168
}
169-
if (any(names(arg_vals) == "mtry") & engine == "cforest") {
170-
arg_vals$mtry <- rlang::call2("min_cols", arg_vals$mtry, expr(data))
171-
}
172169

173170
if (any(names(arg_vals) == "min.node.size")) {
174171
arg_vals$min.node.size <-
175172
rlang::call2("min_rows", arg_vals$min.node.size, expr(x))
176173
}
177-
if (any(names(arg_vals) == "minsplit" & engine == "cforest")) {
178-
arg_vals$minsplit <-
179-
rlang::call2("min_rows", arg_vals$minsplit, expr(data))
180-
}
174+
181175
if (any(names(arg_vals) == "nodesize")) {
182176
arg_vals$nodesize <-
183177
rlang::call2("min_rows", arg_vals$nodesize, expr(x))

_pkgdown.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ reference:
9595
- set_new_model
9696
- maybe_matrix
9797
- min_cols
98+
- max_mtry_formula
9899
- required_pkgs
99100
- required_pkgs.model_spec
100101
- req_pkgs

man/max_mtry_formula.Rd

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_misc.R

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,23 @@ test_that('control class', {
8383
)
8484
})
8585

86+
# ------------------------------------------------------------------------------
87+
88+
# `max_mtry_formula()` should cap `mtry` at the number of formula terms and
# never return a value below one.
test_that('correct mtry', {
  skip_if_not_installed("modeldata")
  data(ames, package = "modeldata")

  explicit_terms <- Sale_Price ~ Longitude + Latitude + Year_Built
  dot_terms <- Sale_Price ~ .
  two_outcomes <- cbind(wt, mpg) ~ .

  # Three explicit predictors: values are clamped to the range [1, 3].
  expect_equal(max_mtry_formula(2, explicit_terms, ames), 2)
  expect_equal(max_mtry_formula(5, explicit_terms, ames), 3)
  expect_equal(max_mtry_formula(0, explicit_terms, ames), 1)

  # `.` expands to every column except the single outcome.
  expect_equal(max_mtry_formula(2000, dot_terms, ames), ncol(ames) - 1)
  expect_equal(max_mtry_formula(2, dot_terms, ames), 2)

  # Multivariate outcome: both left-hand-side columns are excluded from `.`.
  expect_equal(
    max_mtry_formula(200, two_outcomes, data = mtcars),
    ncol(mtcars) - 2
  )
})
86105

0 commit comments

Comments
 (0)