Skip to content

Commit 88e23c4

Browse files
authored
Merge pull request #373 from tidymodels/sparsity
Changes to allow sparsity predictor representations
2 parents cf62381 + 9ddf191 commit 88e23c4

28 files changed

+374
-124
lines changed

DESCRIPTION

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ Imports:
3131
prettyunits,
3232
vctrs (>= 0.2.0)
3333
Roxygen: list(markdown = TRUE)
34-
RoxygenNote: 7.1.1
34+
RoxygenNote: 7.1.1.9000
3535
Suggests:
3636
testthat,
3737
knitr,
@@ -46,9 +46,10 @@ Suggests:
4646
kernlab,
4747
kknn,
4848
randomForest,
49-
ranger,
49+
ranger (>= 0.12.0),
5050
rpart,
5151
MASS,
5252
nlme,
5353
modeldata,
54-
liquidSVM
54+
liquidSVM,
55+
Matrix

NAMESPACE

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ export(linear_reg)
130130
export(logistic_reg)
131131
export(make_classes)
132132
export(mars)
133+
export(maybe_data_frame)
134+
export(maybe_matrix)
133135
export(mlp)
134136
export(model_printer)
135137
export(multi_predict)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
* `show_engines()` will provide information on the current set for a model.
44

5+
* For three models (`glmnet`, `xgboost`, and `ranger`), enable sparse matrix use via `fit_xy()` (#373).
6+
57
# parsnip 0.1.3
68

79
* A `glance()` method for `model_fit` objects was added (#325)

R/aaa_models.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,8 @@ check_encodings <- function(x) {
819819
}
820820
req_args <- list(predictor_indicators = rlang::na_chr,
821821
compute_intercept = rlang::na_lgl,
822-
remove_intercept = rlang::na_lgl)
822+
remove_intercept = rlang::na_lgl,
823+
allow_sparse_x = rlang::na_lgl)
823824

824825
missing_args <- setdiff(names(req_args), names(x))
825826
if (length(missing_args) > 0) {
@@ -896,7 +897,8 @@ get_encoding <- function(model) {
896897
model = model,
897898
predictor_indicators = "traditional",
898899
compute_intercept = TRUE,
899-
remove_intercept = TRUE
900+
remove_intercept = TRUE,
901+
allow_sparse_x = FALSE
900902
) %>%
901903
dplyr::select(model, engine, mode, predictor_indicators,
902904
compute_intercept, remove_intercept)

R/arguments.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,8 +191,8 @@ make_xy_call <- function(object, target) {
191191
switch(
192192
target,
193193
none = rlang::expr(x),
194-
data.frame = rlang::expr(as.data.frame(x)),
195-
matrix = rlang::expr(as.matrix(x)),
194+
data.frame = rlang::expr(maybe_data_frame(x)),
195+
matrix = rlang::expr(maybe_matrix(x)),
196196
rlang::abort(glue::glue("Invalid data type target: {target}."))
197197
)
198198

R/boost_tree.R

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -290,11 +290,8 @@ xgb_train <- function(
290290
min_child_weight = 1, gamma = 0, subsample = 1, validation = 0,
291291
early_stop = NULL, ...) {
292292

293-
if (length(levels(y)) > 2) {
294-
num_class <- length(levels(y))
295-
} else {
296-
num_class <- NULL
297-
}
293+
num_class <- length(levels(y))
294+
298295
if (!is.numeric(validation) || validation < 0 || validation >= 1) {
299296
rlang::abort("`validation` should be on [0, 1).")
300297
}
@@ -311,36 +308,17 @@ xgb_train <- function(
311308
if (is.numeric(y)) {
312309
loss <- "reg:squarederror"
313310
} else {
314-
lvl <- levels(y)
315-
y <- as.numeric(y) - 1
316-
if (length(lvl) == 2) {
311+
if (num_class == 2) {
317312
loss <- "binary:logistic"
318313
} else {
319314
loss <- "multi:softprob"
320315
}
321316
}
322317

323-
if (is.data.frame(x)) {
324-
x <- as.matrix(x) # maybe use model.matrix here?
325-
}
326-
327318
n <- nrow(x)
328319
p <- ncol(x)
329320

330-
if (!inherits(x, "xgb.DMatrix")) {
331-
if (validation > 0) {
332-
trn_index <- sample(1:n, size = floor(n * validation) + 1)
333-
wlist <-
334-
list(validation = xgboost::xgb.DMatrix(x[-trn_index, ], label = y[-trn_index], missing = NA))
335-
x <- xgboost::xgb.DMatrix(x[trn_index, ], label = y[trn_index], missing = NA)
336-
337-
} else {
338-
x <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
339-
wlist <- list(training = x)
340-
}
341-
} else {
342-
xgboost::setinfo(x, "label", y)
343-
}
321+
x <- as_xgb_data(x, y, validation)
344322

345323
# translate `subsample` and `colsample_bytree` to be on (0, 1] if not
346324
if (subsample > 1) {
@@ -366,17 +344,15 @@ xgb_train <- function(
366344
subsample = subsample
367345
)
368346

369-
# eval if contains expressions?
370-
371347
main_args <- list(
372-
data = quote(x),
373-
watchlist = quote(wlist),
348+
data = quote(x$data),
349+
watchlist = quote(x$watchlist),
374350
params = arg_list,
375351
nrounds = nrounds,
376352
objective = loss,
377353
early_stopping_rounds = early_stop
378354
)
379-
if (!is.null(num_class)) {
355+
if (!is.null(num_class) && num_class > 2) {
380356
main_args$num_class <- num_class
381357
}
382358

@@ -399,7 +375,7 @@ xgb_train <- function(
399375
#' @importFrom stats binomial
400376
xgb_pred <- function(object, newdata, ...) {
401377
if (!inherits(newdata, "xgb.DMatrix")) {
402-
newdata <- as.matrix(newdata)
378+
newdata <- maybe_matrix(newdata)
403379
newdata <- xgboost::xgb.DMatrix(data = newdata, missing = NA)
404380
}
405381

@@ -415,6 +391,37 @@ xgb_pred <- function(object, newdata, ...) {
415391
x
416392
}
417393

394+
395+
as_xgb_data <- function(x, y, validation = 0, ...) {
396+
lvls <- levels(y)
397+
n <- nrow(x)
398+
399+
if (is.data.frame(x)) {
400+
x <- as.matrix(x)
401+
}
402+
403+
if (is.factor(y)) {
404+
y <- as.numeric(y) - 1
405+
}
406+
407+
if (!inherits(x, "xgb.DMatrix")) {
408+
if (validation > 0) {
409+
trn_index <- sample(1:n, size = floor(n * (1 - validation)) + 1)
410+
wlist <-
411+
list(validation = xgboost::xgb.DMatrix(x[-trn_index, ], label = y[-trn_index], missing = NA))
412+
dat <- xgboost::xgb.DMatrix(x[trn_index, ], label = y[trn_index], missing = NA)
413+
414+
} else {
415+
dat <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
416+
wlist <- list(training = dat)
417+
}
418+
} else {
419+
dat <- xgboost::setinfo(x, "label", y)
420+
wlist <- list(training = dat)
421+
}
422+
423+
list(data = dat, watchlist = wlist)
424+
}
418425
#' @importFrom purrr map_df
419426
#' @export
420427
#' @rdname multi_predict

R/boost_tree_data.R

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ set_encoding(
9494
options = list(
9595
predictor_indicators = "one_hot",
9696
compute_intercept = FALSE,
97-
remove_intercept = TRUE
97+
remove_intercept = TRUE,
98+
allow_sparse_x = TRUE
9899
)
99100
)
100101

@@ -143,7 +144,8 @@ set_encoding(
143144
options = list(
144145
predictor_indicators = "one_hot",
145146
compute_intercept = FALSE,
146-
remove_intercept = TRUE
147+
remove_intercept = TRUE,
148+
allow_sparse_x = TRUE
147149
)
148150
)
149151

@@ -250,7 +252,8 @@ set_encoding(
250252
options = list(
251253
predictor_indicators = "none",
252254
compute_intercept = FALSE,
253-
remove_intercept = FALSE
255+
remove_intercept = FALSE,
256+
allow_sparse_x = FALSE
254257
)
255258
)
256259

@@ -384,7 +387,8 @@ set_encoding(
384387
options = list(
385388
predictor_indicators = "none",
386389
compute_intercept = FALSE,
387-
remove_intercept = FALSE
390+
remove_intercept = FALSE,
391+
allow_sparse_x = FALSE
388392
)
389393
)
390394

@@ -408,7 +412,8 @@ set_encoding(
408412
options = list(
409413
predictor_indicators = "none",
410414
compute_intercept = FALSE,
411-
remove_intercept = FALSE
415+
remove_intercept = FALSE,
416+
allow_sparse_x = FALSE
412417
)
413418
)
414419

R/convert_data.R

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,31 @@ check_dup_names <- function(x, y) {
323323
)
324324
invisible(NULL)
325325
}
326+
327+
## -----------------------------------------------------------------------------
328+
329+
#' Fuzzy conversions
330+
#'
331+
#' These are substitutes for `as.matrix()` and `as.data.frame()` that leave
332+
#' a sparse matrix as-is.
333+
#' @param x A data frame, matrix, or sparse matrix.
334+
#' @return A data frame, matrix, or sparse matrix.
335+
#' @export
336+
maybe_matrix <- function(x) {
337+
inher(x, c("data.frame", "matrix", "dgCMatrix"), cl = match.call())
338+
if (is.data.frame(x)) {
339+
x <- as.matrix(x)
340+
}
341+
# leave alone if matrix or sparse matrix
342+
x
343+
}
344+
345+
#' @rdname maybe_matrix
346+
#' @export
347+
maybe_data_frame <- function(x) {
348+
if (!inherits(x, "dgCMatrix")) {
349+
x <- as.data.frame(x)
350+
}
351+
x
352+
}
353+

R/decision_tree_data.R

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ set_encoding(
5555
options = list(
5656
predictor_indicators = "none",
5757
compute_intercept = FALSE,
58-
remove_intercept = FALSE
58+
remove_intercept = FALSE,
59+
allow_sparse_x = FALSE
5960
)
6061
)
6162

@@ -78,7 +79,8 @@ set_encoding(
7879
options = list(
7980
predictor_indicators = "none",
8081
compute_intercept = FALSE,
81-
remove_intercept = FALSE
82+
remove_intercept = FALSE,
83+
allow_sparse_x = FALSE
8284
)
8385
)
8486

@@ -187,7 +189,8 @@ set_encoding(
187189
options = list(
188190
predictor_indicators = "none",
189191
compute_intercept = FALSE,
190-
remove_intercept = FALSE
192+
remove_intercept = FALSE,
193+
allow_sparse_x = FALSE
191194
)
192195
)
193196

@@ -285,7 +288,8 @@ set_encoding(
285288
options = list(
286289
predictor_indicators = "none",
287290
compute_intercept = FALSE,
288-
remove_intercept = FALSE
291+
remove_intercept = FALSE,
292+
allow_sparse_x = FALSE
289293
)
290294
)
291295

@@ -310,7 +314,8 @@ set_encoding(
310314
options = list(
311315
predictor_indicators = "none",
312316
compute_intercept = FALSE,
313-
remove_intercept = FALSE
317+
remove_intercept = FALSE,
318+
allow_sparse_x = FALSE
314319
)
315320
)
316321

0 commit comments

Comments
 (0)