Skip to content

Commit 0cb413c

Browse files
committed
changes for sparse matrices with ranger
1 parent 99950ec commit 0cb413c

File tree

6 files changed

+53
-45
lines changed

6 files changed

+53
-45
lines changed

R/boost_tree.R

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -375,7 +375,7 @@ xgb_train <- function(
375375
#' @importFrom stats binomial
376376
xgb_pred <- function(object, newdata, ...) {
377377
if (!inherits(newdata, "xgb.DMatrix")) {
378-
newdata <- as_matrix(newdata)
378+
newdata <- maybe_matrix(newdata)
379379
newdata <- xgboost::xgb.DMatrix(data = newdata, missing = NA)
380380
}
381381

R/rand_forest_data.R

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -120,8 +120,8 @@ set_fit(
120120
eng = "ranger",
121121
mode = "classification",
122122
value = list(
123-
interface = "formula",
124-
protect = c("formula", "data", "case.weights"),
123+
interface = "data.frame",
124+
protect = c("x", "y", "case.weights"),
125125
func = c(pkg = "ranger", fun = "ranger"),
126126
defaults =
127127
list(
@@ -148,8 +148,8 @@ set_fit(
148148
eng = "ranger",
149149
mode = "regression",
150150
value = list(
151-
interface = "formula",
152-
protect = c("formula", "data", "case.weights"),
151+
interface = "data.frame",
152+
protect = c("x", "y", "case.weights"),
153153
func = c(pkg = "ranger", fun = "ranger"),
154154
defaults =
155155
list(

man/as_matrix.Rd

Lines changed: 0 additions & 18 deletions
This file was deleted.

man/rand_forest.Rd

Lines changed: 6 additions & 6 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_rand_forest.R

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -16,8 +16,8 @@ test_that('primary arguments', {
1616
mtry_spark <- translate(mtry %>% set_engine("spark"))
1717
expect_equal(mtry_ranger$method$fit$args,
1818
list(
19-
formula = expr(missing_arg()),
20-
data = expr(missing_arg()),
19+
x = expr(missing_arg()),
20+
y = expr(missing_arg()),
2121
case.weights = expr(missing_arg()),
2222
mtry = new_empty_quosure(4),
2323
num.threads = 1,
@@ -47,8 +47,8 @@ test_that('primary arguments', {
4747
trees_spark <- translate(trees %>% set_engine("spark"))
4848
expect_equal(trees_ranger$method$fit$args,
4949
list(
50-
formula = expr(missing_arg()),
51-
data = expr(missing_arg()),
50+
x = expr(missing_arg()),
51+
y = expr(missing_arg()),
5252
case.weights = expr(missing_arg()),
5353
num.trees = new_empty_quosure(1000),
5454
num.threads = 1,
@@ -80,8 +80,8 @@ test_that('primary arguments', {
8080
min_n_spark <- translate(min_n %>% set_engine("spark"))
8181
expect_equal(min_n_ranger$method$fit$args,
8282
list(
83-
formula = expr(missing_arg()),
84-
data = expr(missing_arg()),
83+
x = expr(missing_arg()),
84+
y = expr(missing_arg()),
8585
case.weights = expr(missing_arg()),
8686
min.node.size = new_empty_quosure(5),
8787
num.threads = 1,
@@ -112,8 +112,8 @@ test_that('primary arguments', {
112112
mtry_v_spark <- translate(mtry_v %>% set_engine("spark"))
113113
expect_equal(mtry_v_ranger$method$fit$args,
114114
list(
115-
formula = expr(missing_arg()),
116-
data = expr(missing_arg()),
115+
x = expr(missing_arg()),
116+
y = expr(missing_arg()),
117117
case.weights = expr(missing_arg()),
118118
mtry = new_empty_quosure(varying()),
119119
num.threads = 1,
@@ -145,8 +145,8 @@ test_that('primary arguments', {
145145
trees_v_spark <- translate(trees_v %>% set_engine("spark"))
146146
expect_equal(trees_v_ranger$method$fit$args,
147147
list(
148-
formula = expr(missing_arg()),
149-
data = expr(missing_arg()),
148+
x = expr(missing_arg()),
149+
y = expr(missing_arg()),
150150
case.weights = expr(missing_arg()),
151151
num.trees = new_empty_quosure(varying()),
152152
num.threads = 1,
@@ -177,8 +177,8 @@ test_that('primary arguments', {
177177
min_n_v_spark <- translate(min_n_v %>% set_engine("spark"))
178178
expect_equal(min_n_v_ranger$method$fit$args,
179179
list(
180-
formula = expr(missing_arg()),
181-
data = expr(missing_arg()),
180+
x = expr(missing_arg()),
181+
y = expr(missing_arg()),
182182
case.weights = expr(missing_arg()),
183183
min.node.size = new_empty_quosure(varying()),
184184
num.threads = 1,
@@ -210,8 +210,8 @@ test_that('engine arguments', {
210210
ranger_imp <- rand_forest(mode = "classification")
211211
expect_equal(translate(ranger_imp %>% set_engine("ranger", importance = "impurity"))$method$fit$args,
212212
list(
213-
formula = expr(missing_arg()),
214-
data = expr(missing_arg()),
213+
x = expr(missing_arg()),
214+
y = expr(missing_arg()),
215215
case.weights = expr(missing_arg()),
216216
importance = new_empty_quosure("impurity"),
217217
num.threads = 1,
@@ -246,8 +246,8 @@ test_that('engine arguments', {
246246
translate(ranger_samp_frac %>%
247247
set_engine("ranger", sample.fraction = varying()))$method$fit$args,
248248
list(
249-
formula = expr(missing_arg()),
250-
data = expr(missing_arg()),
249+
x = expr(missing_arg()),
250+
y = expr(missing_arg()),
251251
case.weights = expr(missing_arg()),
252252
sample.fraction = new_empty_quosure(varying()),
253253
num.threads = 1,

tests/testthat/test_rand_forest_ranger.R

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -439,3 +439,29 @@ test_that('ranger classification intervals', {
439439

440440
})
441441

442+
443+
444+
test_that('ranger and sparse matrices', {
445+
skip_if_not_installed("ranger")
446+
447+
mtcar_x <- mtcars[, -1]
448+
mtcar_mat <- as.matrix(mtcar_x)
449+
mtcar_smat <- Matrix::Matrix(mtcar_mat, sparse = TRUE)
450+
451+
rf_spec <-
452+
rand_forest(trees = 10) %>%
453+
set_engine("ranger", seed = 2) %>%
454+
set_mode("regression")
455+
456+
set.seed(1)
457+
from_df <- rf_spec %>% fit_xy(mtcar_x, mtcars$mpg)
458+
set.seed(1)
459+
from_mat <- rf_spec %>% fit_xy(mtcar_mat, mtcars$mpg)
460+
set.seed(1)
461+
from_sparse <- rf_spec %>% fit_xy(mtcar_smat, mtcars$mpg)
462+
463+
expect_equal(from_df$fit, from_mat$fit)
464+
expect_equal(from_df$fit, from_sparse$fit)
465+
466+
})
467+

0 commit comments

Comments
 (0)