Skip to content

let fit_xy() take dgCMatrix input #1121

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Aug 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Imports:
prettyunits,
purrr (>= 1.0.0),
rlang (>= 1.1.0),
sparsevctrs (>= 0.1.0.9000),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it uses dev version because of this bug fix: r-lib/sparsevctrs@9c22ca9

sparsevctrs will of course be merged in time for parsnip release

stats,
tibble (>= 2.1.1),
tidyr (>= 1.3.0),
Expand All @@ -52,6 +53,7 @@ Suggests:
LiblineaR,
MASS,
Matrix,
methods,
mgcv,
modeldata,
nlme,
Expand All @@ -77,4 +79,6 @@ Config/testthat/edition: 3
Encoding: UTF-8
LazyData: true
Roxygen: list(markdown = TRUE)
Remotes:
r-lib/sparsevctrs
RoxygenNote: 7.3.2
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# parsnip (development version)

* `fit_xy()` can now take dgCMatrix input for the `x` argument (#1121).

* Transitioned package errors and warnings to use cli (#1147 and #1148 by
@shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160,
#1161).
Expand Down
11 changes: 10 additions & 1 deletion R/convert_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -374,12 +374,21 @@ maybe_matrix <- function(x) {
"converted to numeric matrix: {non_num_cols}.")
rlang::abort(msg)
}
x <- as.matrix(x)
x <- maybe_sparse_matrix(x)
}
# leave alone if matrix or sparse matrix
x
}

# Convert `x` to a matrix, picking the representation from its columns.
#
# `x` is iterated element-wise with `vapply()` (presumably a data frame whose
# elements are columns -- confirm against the caller). If at least one element
# is a sparsevctrs sparse vector, the result comes from
# `sparsevctrs::coerce_to_sparse_matrix()`; otherwise from `as.matrix()`.
maybe_sparse_matrix <- function(x) {
  has_sparse_column <- any(
    vapply(x, sparsevctrs::is_sparse_vector, logical(1))
  )

  # No sparse columns: fall back to the ordinary dense conversion.
  if (!has_sparse_column) {
    return(as.matrix(x))
  }

  sparsevctrs::coerce_to_sparse_matrix(x)
}

#' @rdname maybe_matrix
#' @export
maybe_data_frame <- function(x) {
Expand Down
9 changes: 8 additions & 1 deletion R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@
#' a "reverse Kaplan-Meier" curve that models the probability of censoring. This
#' may be used later to compute inverse probability censoring weights for
#' performance measures.
#'
#' Sparse data is supported through the `x` argument of `fit_xy()`. See the
#' `allow_sparse_x` column of [parsnip::get_encoding()] for sparse input
#' compatibility.
#'
#' @examplesIf !parsnip:::is_cran_check()
#' # Although `glm()` only has a formula interface, different
#' # methods for specifying the model can be used
Expand Down Expand Up @@ -274,6 +279,8 @@ fit_xy.model_spec <-
}
}

x <- to_sparse_data_frame(x, object)

cl <- match.call(expand.dots = TRUE)
eval_env <- rlang::env()
eval_env$x <- x
Expand Down Expand Up @@ -380,7 +387,7 @@ inher <- function(x, cls, cl) {

check_interface <- function(formula, data, cl, model) {
inher(formula, "formula", cl)
inher(data, c("data.frame", "tbl_spark"), cl)
inher(data, c("data.frame", "dgCMatrix", "tbl_spark"), cl)

# Determine the `fit()` interface
form_interface <- !is.null(formula) & !is.null(data)
Expand Down
12 changes: 12 additions & 0 deletions R/sparsevctrs.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Prepare `x` when it is a sparse matrix.
#
# `x` is returned untouched unless `methods::is(x, "sparseMatrix")` holds.
# A sparse matrix is converted with
# `sparsevctrs::coerce_to_sparse_data_frame()` when `allow_sparse(object)` is
# true; otherwise an error is raised that names the model function (first
# class of `object`) and its engine.
to_sparse_data_frame <- function(x, object) {
  # Anything that is not a sparse matrix passes straight through.
  if (!methods::is(x, "sparseMatrix")) {
    return(x)
  }

  if (!allow_sparse(object)) {
    # The message literal is kept verbatim, including its line break.
    cli::cli_abort(
      "{.arg x} is a sparse matrix, but {.fn {class(object)[1]}} with
engine {.code {object$engine}} doesn't accept that."
    )
  }

  sparsevctrs::coerce_to_sparse_data_frame(x)
}
4 changes: 4 additions & 0 deletions man/fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

65 changes: 65 additions & 0 deletions tests/testthat/_snaps/sparsevctrs.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# sparse matrices can be passed to `fit_xy()

Code
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1])
Condition
Error in `to_sparse_data_frame()`:
! `x` is a sparse matrix, but `linear_reg()` with engine `lm` doesn't accept that.

# to_sparse_data_frame() is used correctly

Code
fit_xy(spec, x = mtcars[, -1], y = mtcars[, 1])
Condition
Error in `to_sparse_data_frame()`:
! x is not sparse

---

Code
fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
Condition
Error in `to_sparse_data_frame()`:
! x is spare, and sparse is not allowed

---

Code
fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
Condition
Error in `to_sparse_data_frame()`:
! x is spare, and sparse is allowed

# maybe_sparse_matrix() is used correctly

Code
fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
Condition
Error in `maybe_sparse_matrix()`:
! sparse vectors detected

---

Code
fit_xy(spec, x = mtcars[, -1], y = mtcars[, 1])
Condition
Error in `maybe_sparse_matrix()`:
! no sparse vectors detected

---

Code
fit_xy(spec, x = as.data.frame(mtcars)[, -1], y = as.data.frame(mtcars)[, 1])
Condition
Error in `maybe_sparse_matrix()`:
! no sparse vectors detected

---

Code
fit_xy(spec, x = tibble::as_tibble(mtcars)[, -1], y = tibble::as_tibble(mtcars)[,
1])
Condition
Error in `maybe_sparse_matrix()`:
! no sparse vectors detected

27 changes: 27 additions & 0 deletions tests/testthat/helper-objects.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,30 @@ is_tf_ok <- function() {
}
res
}

# ------------------------------------------------------------------------------
# For sparse tibble testing

sparse_hotel_rates <- function() {
  # Builds a sparse test matrix from `modeldata::hotel_rates`: the
  # `avg_price_per_room` column followed by one-hot indicator columns for
  # `country`, `company` and `agent`.
  # Original note: "99.2 sparsity" (presumably percent -- TODO confirm).
  rates <- modeldata::hotel_rates

  # Appends `_<tag>` to every column name of `mat`.
  add_colname_suffix <- function(mat, tag) {
    colnames(mat) <- paste(colnames(mat), tag, sep = "_")
    mat
  }

  country_dummies <- add_colname_suffix(
    hardhat::fct_encode_one_hot(rates$country),
    "country"
  )
  company_dummies <- add_colname_suffix(
    hardhat::fct_encode_one_hot(rates$company),
    "company"
  )
  agent_dummies <- add_colname_suffix(
    hardhat::fct_encode_one_hot(rates$agent),
    "agent"
  )

  combined <- dplyr::bind_cols(
    rates["avg_price_per_room"],
    country_dummies,
    company_dummies,
    agent_dummies
  )

  # Go through a base matrix first, then store it in sparse form.
  Matrix::Matrix(as.matrix(combined), sparse = TRUE)
}
10 changes: 0 additions & 10 deletions tests/testthat/test-rand_forest_ranger.R
Original file line number Diff line number Diff line change
Expand Up @@ -475,16 +475,6 @@ test_that('ranger and sparse matrices', {

expect_equal(extract_fit_engine(from_df), extract_fit_engine(from_mat))
expect_equal(extract_fit_engine(from_df), extract_fit_engine(from_sparse))

rf_spec <-
rand_forest(trees = 10) %>%
set_engine("randomForest", seed = 2) %>%
set_mode("regression")
expect_error(
rf_spec %>% fit_xy(mtcar_smat, mtcars$mpg),
"Sparse matrices not supported"
)

})


Expand Down
99 changes: 99 additions & 0 deletions tests/testthat/test-sparsevctrs.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
test_that("sparse matrices can be passed to `fit_xy()", {
skip_if_not_installed("xgboost")

hotel_data <- sparse_hotel_rates()

spec <- boost_tree() %>%
set_mode("regression") %>%
set_engine("xgboost")

expect_no_error(
lm_fit <- fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if this didn't work, it would take quite a lot longer to run which we would notice

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would we? Does "didn't work" mean a failure or just inefficient?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

inefficient. it should pop up as a "this test is running a little long" from CRAN / CMD R Check

)

spec <- linear_reg() %>%
set_mode("regression") %>%
set_engine("lm")

expect_snapshot(
lm_fit <- fit_xy(spec, x = hotel_data[1:100, -1], y = hotel_data[1:100, 1]),
error = TRUE
)
})

test_that("to_sparse_data_frame() is used correctly", {
skip_if_not_installed("xgboost")

local_mocked_bindings(
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main testing strategy follows this template:

  • mock the functions that deal with sparsevctrs
  • see if we can trigger all paths inside those functions

to_sparse_data_frame = function(x, object) {
if (methods::is(x, "sparseMatrix")) {
if (allow_sparse(object)) {
stop("x is spare, and sparse is allowed")
} else {
stop("x is spare, and sparse is not allowed")
}
}
stop("x is not sparse")
}
)

hotel_data <- sparse_hotel_rates()

spec <- linear_reg() %>%
set_engine("lm")

expect_snapshot(
error = TRUE,
fit_xy(spec, x = mtcars[, -1], y = mtcars[, 1])
)
expect_snapshot(
error = TRUE,
fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
)

spec <- boost_tree() %>%
set_mode("regression") %>%
set_engine("xgboost")

expect_snapshot(
error = TRUE,
fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
)
})

# Checks that `fit_xy()` reaches `maybe_sparse_matrix()` for several kinds of
# `x`. The helper is mocked so that it always errors, and the error message
# records which branch (sparse columns present or not) was taken.
test_that("maybe_sparse_matrix() is used correctly", {
skip_if_not_installed("xgboost")

# Replace the real helper with a stub that only reports the detected branch.
local_mocked_bindings(
maybe_sparse_matrix = function(x) {
if (any(vapply(x, sparsevctrs::is_sparse_vector, logical(1)))) {
stop("sparse vectors detected")
} else {
stop("no sparse vectors detected")
}
}
)

hotel_data <- sparse_hotel_rates()

spec <- boost_tree() %>%
set_mode("regression") %>%
set_engine("xgboost")

# Sparse matrix input: the recorded snapshot reports "sparse vectors detected".
expect_snapshot(
error = TRUE,
fit_xy(spec, x = hotel_data[, -1], y = hotel_data[, 1])
)
# Dense inputs below: the recorded snapshots report
# "no sparse vectors detected" for each of them.
expect_snapshot(
error = TRUE,
fit_xy(spec, x = mtcars[, -1], y = mtcars[, 1])
)
expect_snapshot(
error = TRUE,
fit_xy(spec, x = as.data.frame(mtcars)[, -1], y = as.data.frame(mtcars)[, 1])
)
expect_snapshot(
error = TRUE,
fit_xy(spec, x = tibble::as_tibble(mtcars)[, -1], y = tibble::as_tibble(mtcars)[, 1])
)
})
Loading