Skip to content

Commit 79e39b3

Browse files
committed
setting up for future party package engines
1 parent b8dbf45 commit 79e39b3

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

R/rand_forest.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,14 +202,21 @@ translate.rand_forest <- function(x, engine = x$engine, ...) {
202202
## -----------------------------------------------------------------------------
203203
# Protect some arguments based on data dimensions
204204

205-
if (any(names(arg_vals) == "mtry")) {
205+
if (any(names(arg_vals) == "mtry") & engine != "cforest") {
206206
arg_vals$mtry <- rlang::call2("min_cols", arg_vals$mtry, expr(x))
207207
}
208+
if (any(names(arg_vals) == "mtry") & engine == "cforest") {
209+
arg_vals$mtry <- rlang::call2("min_cols", arg_vals$mtry, expr(data))
210+
}
208211

209212
if (any(names(arg_vals) == "min.node.size")) {
210213
arg_vals$min.node.size <-
211214
rlang::call2("min_rows", arg_vals$min.node.size, expr(x))
212215
}
216+
if (any(names(arg_vals) == "minsplit" & engine == "cforest")) {
217+
arg_vals$minsplit <-
218+
rlang::call2("min_rows", arg_vals$minsplit, expr(data))
219+
}
213220
if (any(names(arg_vals) == "nodesize")) {
214221
arg_vals$nodesize <-
215222
rlang::call2("min_rows", arg_vals$nodesize, expr(x))

0 commit comments

Comments
 (0)