Skip to content

Commit 81d9536

Browse files
Merge pull request #1121 from tidymodels/sparse-input
let `fit_xy()` take dgCMatrix input
2 parents 4acfa5e + 1971f05 commit 81d9536

File tree

10 files changed

+231
-12
lines changed

10 files changed

+231
-12
lines changed

DESCRIPTION

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ Imports:
3232
prettyunits,
3333
purrr (>= 1.0.0),
3434
rlang (>= 1.1.0),
35+
sparsevctrs (>= 0.1.0.9000),
3536
stats,
3637
tibble (>= 2.1.1),
3738
tidyr (>= 1.3.0),
@@ -52,6 +53,7 @@ Suggests:
5253
LiblineaR,
5354
MASS,
5455
Matrix,
56+
methods,
5557
mgcv,
5658
modeldata,
5759
nlme,
@@ -77,4 +79,6 @@ Config/testthat/edition: 3
7779
Encoding: UTF-8
7880
LazyData: true
7981
Roxygen: list(markdown = TRUE)
82+
Remotes:
83+
r-lib/sparsevctrs
8084
RoxygenNote: 7.3.2

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* `fit_xy()` can now take `dgCMatrix` input for the `x` argument (#1121).
4+
35
* Transitioned package errors and warnings to use cli (#1147 and #1148 by
46
@shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160,
57
#1161).

R/convert_data.R

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,12 +374,21 @@ maybe_matrix <- function(x) {
374374
"converted to numeric matrix: {non_num_cols}.")
375375
rlang::abort(msg)
376376
}
377-
x <- as.matrix(x)
377+
x <- maybe_sparse_matrix(x)
378378
}
379379
# leave alone if matrix or sparse matrix
380380
x
381381
}
382382

383+
# Convert a data frame to a matrix, keeping sparsity when it is present.
#
# `x` is iterated column by column. If at least one column is a sparsevctrs
# sparse vector, the result comes from
# `sparsevctrs::coerce_to_sparse_matrix()`; otherwise `as.matrix()` is used.
maybe_sparse_matrix <- function(x) {
  has_sparse_col <- any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))

  if (!has_sparse_col) {
    return(as.matrix(x))
  }

  sparsevctrs::coerce_to_sparse_matrix(x)
}
391+
383392
#' @rdname maybe_matrix
384393
#' @export
385394
maybe_data_frame <- function(x) {

R/fit.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@
5555
#' a "reverse Kaplan-Meier" curve that models the probability of censoring. This
5656
#' may be used later to compute inverse probability censoring weights for
5757
#' performance measures.
58+
#'
59+
#' Sparse data is supported through the `x` argument of `fit_xy()`. See the
60+
#' `allow_sparse_x` column of [parsnip::get_encoding()] for sparse input
61+
#' compatibility.
62+
#'
5863
#' @examplesIf !parsnip:::is_cran_check()
5964
#' # Although `glm()` only has a formula interface, different
6065
#' # methods for specifying the model can be used
@@ -274,6 +279,8 @@ fit_xy.model_spec <-
274279
}
275280
}
276281

282+
x <- to_sparse_data_frame(x, object)
283+
277284
cl <- match.call(expand.dots = TRUE)
278285
eval_env <- rlang::env()
279286
eval_env$x <- x
@@ -380,7 +387,7 @@ inher <- function(x, cls, cl) {
380387

381388
check_interface <- function(formula, data, cl, model) {
382389
inher(formula, "formula", cl)
383-
inher(data, c("data.frame", "tbl_spark"), cl)
390+
inher(data, c("data.frame", "dgCMatrix", "tbl_spark"), cl)
384391

385392
# Determine the `fit()` interface
386393
form_interface <- !is.null(formula) & !is.null(data)

R/sparsevctrs.R

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Turn a sparse-matrix `x` into a sparse data frame when the model allows it.
#
# - `x` that is not a "sparseMatrix" is returned unchanged.
# - `x` that is a "sparseMatrix" is converted with
#   `sparsevctrs::coerce_to_sparse_data_frame()` if `allow_sparse(object)` is
#   true.
# - Otherwise an error is raised naming the model function (first class of
#   `object`) and its engine.
to_sparse_data_frame <- function(x, object) {
  # Non-sparse input passes straight through.
  if (!methods::is(x, "sparseMatrix")) {
    return(x)
  }

  if (!allow_sparse(object)) {
    cli::cli_abort(
      "{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
        engine {.code {object$engine}} doesn't accept that.")
  }

  sparsevctrs::coerce_to_sparse_data_frame(x)
}

man/fit.Rd

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/sparsevctrs.md

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# sparse matrices can be passed to `fit_xy()
2+
3+
Code
4+
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
5+
Condition
6+
Error in `to_sparse_data_frame()`:
7+
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.
8+
9+
# to_sparse_data_frame() is used correctly
10+
11+
Code
12+
fit_xy(spec, x = mtcars[, -1], y = mtcars[, 1])
13+
Condition
14+
Error in `to_sparse_data_frame()`:
15+
! x is not sparse
16+
17+
---
18+
19+
Code
20+
fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
21+
Condition
22+
Error in `to_sparse_data_frame()`:
23+
! x is spare, and sparse is not allowed
24+
25+
---
26+
27+
Code
28+
fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
29+
Condition
30+
Error in `to_sparse_data_frame()`:
31+
! x is spare, and sparse is allowed
32+
33+
# maybe_sparse_matrix() is used correctly
34+
35+
Code
36+
fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
37+
Condition
38+
Error in `maybe_sparse_matrix()`:
39+
! sparse vectors detected
40+
41+
---
42+
43+
Code
44+
fit_xy(spec, x = mtcars[, -1], y = mtcars[, 1])
45+
Condition
46+
Error in `maybe_sparse_matrix()`:
47+
! no sparse vectors detected
48+
49+
---
50+
51+
Code
52+
fit_xy(spec, x = as.data.frame(mtcars)[, -1], y = as.data.frame(mtcars)[, 1])
53+
Condition
54+
Error in `maybe_sparse_matrix()`:
55+
! no sparse vectors detected
56+
57+
---
58+
59+
Code
60+
fit_xy(spec, x = tibble::as_tibble(mtcars)[, -1], y = tibble::as_tibble(mtcars)[,
61+
1])
62+
Condition
63+
Error in `maybe_sparse_matrix()`:
64+
! no sparse vectors detected
65+

tests/testthat/helper-objects.R

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,30 @@ is_tf_ok <- function() {
2424
}
2525
res
2626
}
27+
28+
# ------------------------------------------------------------------------------
29+
# For sparse tibble testing
30+
31+
# Build a sparse matrix from `modeldata::hotel_rates` for sparse-input tests.
#
# The first column is `avg_price_per_room`; the remaining columns are one-hot
# indicators for `country`, `company` and `agent`, each named
# "<level>_<column>". Returns the result of `Matrix::Matrix(sparse = TRUE)`.
sparse_hotel_rates <- function() {
  # Original note: 99.2 sparsity (presumably the percentage of zero entries).
  rates <- modeldata::hotel_rates

  # One-hot encode a single column and append the column name to every
  # indicator name.
  one_hot_with_suffix <- function(column) {
    indicators <- hardhat::fct_encode_one_hot(rates[[column]])
    colnames(indicators) <- paste(colnames(indicators), column, sep = "_")
    indicators
  }

  combined <- dplyr::bind_cols(
    rates["avg_price_per_room"],
    one_hot_with_suffix("country"),
    one_hot_with_suffix("company"),
    one_hot_with_suffix("agent")
  )

  Matrix::Matrix(as.matrix(combined), sparse = TRUE)
}

tests/testthat/test-rand_forest_ranger.R

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -475,16 +475,6 @@ test_that('ranger and sparse matrices', {
475475

476476
expect_equal(extract_fit_engine(from_df), extract_fit_engine(from_mat))
477477
expect_equal(extract_fit_engine(from_df), extract_fit_engine(from_sparse))
478-
479-
rf_spec <-
480-
rand_forest(trees = 10) %>%
481-
set_engine("randomForest", seed = 2) %>%
482-
set_mode("regression")
483-
expect_error(
484-
rf_spec %>% fit_xy(mtcar_smat, mtcars$mpg),
485-
"Sparse matrices not supported"
486-
)
487-
488478
})
489479

490480

tests/testthat/test-sparsevctrs.R

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
test_that("sparse matrices can be passed to `fit_xy()", {
  skip_if_not_installed("xgboost")

  hotel_data <- sparse_hotel_rates()

  # An engine that accepts sparse predictors: fitting must not error.
  spec <- boost_tree()
  spec <- set_mode(spec, "regression")
  spec <- set_engine(spec, "xgboost")

  expect_no_error(
    xgb_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
  )

  # An engine that does not accept sparse predictors: the error is snapshotted.
  spec <- linear_reg()
  spec <- set_mode(spec, "regression")
  spec <- set_engine(spec, "lm")

  expect_snapshot(
    error = TRUE,
    lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
  )
})
23+
24+
test_that("to_sparse_data_frame() is used correctly", {
  skip_if_not_installed("xgboost")

  # Swap the helper for a probe that always errors. The message records which
  # branch `fit_xy()` would have taken, and the snapshots below capture it.
  local_mocked_bindings(
    to_sparse_data_frame = function(x, object) {
      if (!methods::is(x, "sparseMatrix")) {
        stop("x is not sparse")
      }
      if (allow_sparse(object)) {
        stop("x is spare, and sparse is allowed")
      }
      stop("x is spare, and sparse is not allowed")
    }
  )

  hotel_data <- sparse_hotel_rates()

  # Engine without sparse support: dense input first, then sparse input.
  spec <- set_engine(linear_reg(), "lm")

  expect_snapshot(
    fit_xy(spec, x = mtcars[, -1], y = mtcars[, 1]),
    error = TRUE
  )
  expect_snapshot(
    fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]),
    error = TRUE
  )

  # Engine with sparse support, given sparse input.
  spec <- set_engine(set_mode(boost_tree(), "regression"), "xgboost")

  expect_snapshot(
    fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]),
    error = TRUE
  )
})
63+
64+
test_that("maybe_sparse_matrix() is used correctly", {
  skip_if_not_installed("xgboost")

  # Swap the helper for a probe that always errors. The message records
  # whether any column reaching it is a sparse vector.
  local_mocked_bindings(
    maybe_sparse_matrix = function(x) {
      any_sparse <- any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))
      if (any_sparse) {
        stop("sparse vectors detected")
      }
      stop("no sparse vectors detected")
    }
  )

  hotel_data <- sparse_hotel_rates()

  spec <- set_engine(set_mode(boost_tree(), "regression"), "xgboost")

  # Sparse matrix input.
  expect_snapshot(
    fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1]),
    error = TRUE
  )

  # Dense inputs: mtcars as is, via as.data.frame(), and as a tibble.
  expect_snapshot(
    fit_xy(spec, x = mtcars[, -1], y = mtcars[, 1]),
    error = TRUE
  )
  expect_snapshot(
    fit_xy(spec, x = as.data.frame(mtcars)[, -1], y = as.data.frame(mtcars)[, 1]),
    error = TRUE
  )
  expect_snapshot(
    fit_xy(spec, x = tibble::as_tibble(mtcars)[, -1], y = tibble::as_tibble(mtcars)[, 1]),
    error = TRUE
  )
})

0 commit comments

Comments
 (0)