Skip to content

Commit 659b5ad

Browse files
committed
added encoding field for sparse matrices for tidymodels/tidymodels#42
1 parent 0cb413c commit 659b5ad

16 files changed

+139
-60
lines changed

R/aaa_models.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -819,7 +819,8 @@ check_encodings <- function(x) {
819819
}
820820
req_args <- list(predictor_indicators = rlang::na_chr,
821821
compute_intercept = rlang::na_lgl,
822-
remove_intercept = rlang::na_lgl)
822+
remove_intercept = rlang::na_lgl,
823+
allow_sparse_x = rlang::na_lgl)
823824

824825
missing_args <- setdiff(names(req_args), names(x))
825826
if (length(missing_args) > 0) {
@@ -896,7 +897,8 @@ get_encoding <- function(model) {
896897
model = model,
897898
predictor_indicators = "traditional",
898899
compute_intercept = TRUE,
899-
remove_intercept = TRUE
900+
remove_intercept = TRUE,
901+
allow_sparse_x = FALSE
900902
) %>%
901903
dplyr::select(model, engine, mode, predictor_indicators,
902904
compute_intercept, remove_intercept)

R/boost_tree_data.R

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,8 @@ set_encoding(
9494
options = list(
9595
predictor_indicators = "one_hot",
9696
compute_intercept = FALSE,
97-
remove_intercept = TRUE
97+
remove_intercept = TRUE,
98+
allow_sparse_x = TRUE
9899
)
99100
)
100101

@@ -143,7 +144,8 @@ set_encoding(
143144
options = list(
144145
predictor_indicators = "one_hot",
145146
compute_intercept = FALSE,
146-
remove_intercept = TRUE
147+
remove_intercept = TRUE,
148+
allow_sparse_x = TRUE
147149
)
148150
)
149151

@@ -250,7 +252,8 @@ set_encoding(
250252
options = list(
251253
predictor_indicators = "none",
252254
compute_intercept = FALSE,
253-
remove_intercept = FALSE
255+
remove_intercept = FALSE,
256+
allow_sparse_x = FALSE
254257
)
255258
)
256259

@@ -384,7 +387,8 @@ set_encoding(
384387
options = list(
385388
predictor_indicators = "none",
386389
compute_intercept = FALSE,
387-
remove_intercept = FALSE
390+
remove_intercept = FALSE,
391+
allow_sparse_x = FALSE
388392
)
389393
)
390394

@@ -408,7 +412,8 @@ set_encoding(
408412
options = list(
409413
predictor_indicators = "none",
410414
compute_intercept = FALSE,
411-
remove_intercept = FALSE
415+
remove_intercept = FALSE,
416+
allow_sparse_x = FALSE
412417
)
413418
)
414419

R/decision_tree_data.R

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ set_encoding(
5555
options = list(
5656
predictor_indicators = "none",
5757
compute_intercept = FALSE,
58-
remove_intercept = FALSE
58+
remove_intercept = FALSE,
59+
allow_sparse_x = FALSE
5960
)
6061
)
6162

@@ -78,7 +79,8 @@ set_encoding(
7879
options = list(
7980
predictor_indicators = "none",
8081
compute_intercept = FALSE,
81-
remove_intercept = FALSE
82+
remove_intercept = FALSE,
83+
allow_sparse_x = FALSE
8284
)
8385
)
8486

@@ -187,7 +189,8 @@ set_encoding(
187189
options = list(
188190
predictor_indicators = "none",
189191
compute_intercept = FALSE,
190-
remove_intercept = FALSE
192+
remove_intercept = FALSE,
193+
allow_sparse_x = FALSE
191194
)
192195
)
193196

@@ -285,7 +288,8 @@ set_encoding(
285288
options = list(
286289
predictor_indicators = "none",
287290
compute_intercept = FALSE,
288-
remove_intercept = FALSE
291+
remove_intercept = FALSE,
292+
allow_sparse_x = FALSE
289293
)
290294
)
291295

@@ -310,7 +314,8 @@ set_encoding(
310314
options = list(
311315
predictor_indicators = "none",
312316
compute_intercept = FALSE,
313-
remove_intercept = FALSE
317+
remove_intercept = FALSE,
318+
allow_sparse_x = FALSE
314319
)
315320
)
316321

R/fit.R

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -342,9 +342,18 @@ check_interface <- function(formula, data, cl, model) {
342342
}
343343

344344
check_xy_interface <- function(x, y, cl, model) {
345-
# TODO Do we need a model spec attribute that is something like
346-
# 'allow_sparse' to make this conditional on that?
347-
inher(x, c("data.frame", "matrix", "dgCMatrix"), cl)
345+
346+
sparse_ok <- allow_sparse(model)
347+
sparse_x <- inherits(x, "dgCMatrix")
348+
if (!sparse_ok & sparse_x) {
349+
rlang::abort("Sparse matrices not supported by this model/engine combination.")
350+
}
351+
352+
if (sparse_ok) {
353+
inher(x, c("data.frame", "matrix", "dgCMatrix"), cl)
354+
} else {
355+
inher(x, c("data.frame", "matrix"), cl)
356+
}
348357

349358
# `y` can be a vector (which is not a class), or a factor (which is not a vector)
350359
if (!is.null(y) && !is.vector(y))
@@ -359,22 +368,33 @@ check_xy_interface <- function(x, y, cl, model) {
359368
)
360369
)
361370

362-
# Determine the `fit()` interface
363-
# TODO conditional here too?
364-
matrix_interface <- !is.null(x) & !is.null(y) && (is.matrix(x) | inherits(x, "dgCMatrix"))
371+
372+
if (sparse_ok) {
373+
matrix_interface <- !is.null(x) & !is.null(y) && (is.matrix(x) | sparse_x)
374+
} else {
375+
matrix_interface <- !is.null(x) & !is.null(y) && is.matrix(x)
376+
}
377+
365378
df_interface <- !is.null(x) & !is.null(y) && is.data.frame(x)
366379

367-
if (inherits(model, "surv_reg") &&
368-
(matrix_interface | df_interface))
380+
if (inherits(model, "surv_reg") && (matrix_interface | df_interface)) {
369381
rlang::abort("Survival models must use the formula interface.")
382+
}
370383

371-
if (matrix_interface)
384+
if (matrix_interface) {
372385
return("data.frame")
373-
if (df_interface)
386+
}
387+
if (df_interface) {
374388
return("data.frame")
389+
}
375390
rlang::abort("Error when checking the interface")
376391
}
377392

393+
allow_sparse <- function(x) {
394+
res <- get_from_env(paste0(class(x)[1], "_encoding"))
395+
all(res$allow_sparse_x[res$engine == x$engine])
396+
}
397+
378398
#' @method print model_fit
379399
#' @export
380400
print.model_fit <- function(x, ...) {

R/linear_reg_data.R

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ set_encoding(
2626
options = list(
2727
predictor_indicators = "traditional",
2828
compute_intercept = TRUE,
29-
remove_intercept = TRUE
29+
remove_intercept = TRUE,
30+
allow_sparse_x = FALSE
3031
)
3132
)
3233

@@ -132,7 +133,8 @@ set_encoding(
132133
options = list(
133134
predictor_indicators = "traditional",
134135
compute_intercept = TRUE,
135-
remove_intercept = TRUE
136+
remove_intercept = TRUE,
137+
allow_sparse_x = TRUE
136138
)
137139
)
138140

@@ -212,7 +214,8 @@ set_encoding(
212214
options = list(
213215
predictor_indicators = "traditional",
214216
compute_intercept = TRUE,
215-
remove_intercept = TRUE
217+
remove_intercept = TRUE,
218+
allow_sparse_x = FALSE
216219
)
217220
)
218221

@@ -340,7 +343,8 @@ set_encoding(
340343
options = list(
341344
predictor_indicators = "traditional",
342345
compute_intercept = TRUE,
343-
remove_intercept = TRUE
346+
remove_intercept = TRUE,
347+
allow_sparse_x = FALSE
344348
)
345349
)
346350

@@ -405,7 +409,8 @@ set_encoding(
405409
options = list(
406410
predictor_indicators = "traditional",
407411
compute_intercept = TRUE,
408-
remove_intercept = TRUE
412+
remove_intercept = TRUE,
413+
allow_sparse_x = FALSE
409414
)
410415
)
411416

R/logistic_reg_data.R

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ set_encoding(
2626
options = list(
2727
predictor_indicators = "traditional",
2828
compute_intercept = TRUE,
29-
remove_intercept = TRUE
29+
remove_intercept = TRUE,
30+
allow_sparse_x = FALSE
3031
)
3132
)
3233

@@ -151,7 +152,8 @@ set_encoding(
151152
options = list(
152153
predictor_indicators = "traditional",
153154
compute_intercept = TRUE,
154-
remove_intercept = TRUE
155+
remove_intercept = TRUE,
156+
allow_sparse_x = TRUE
155157
)
156158
)
157159

@@ -274,7 +276,8 @@ set_encoding(
274276
options = list(
275277
predictor_indicators = "traditional",
276278
compute_intercept = TRUE,
277-
remove_intercept = TRUE
279+
remove_intercept = TRUE,
280+
allow_sparse_x = FALSE
278281
)
279282
)
280283

@@ -346,7 +349,8 @@ set_encoding(
346349
options = list(
347350
predictor_indicators = "traditional",
348351
compute_intercept = TRUE,
349-
remove_intercept = TRUE
352+
remove_intercept = TRUE,
353+
allow_sparse_x = FALSE
350354
)
351355
)
352356

@@ -414,7 +418,8 @@ set_encoding(
414418
options = list(
415419
predictor_indicators = "traditional",
416420
compute_intercept = TRUE,
417-
remove_intercept = TRUE
421+
remove_intercept = TRUE,
422+
allow_sparse_x = FALSE
418423
)
419424
)
420425

R/mars_data.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ set_encoding(
5454
options = list(
5555
predictor_indicators = "none",
5656
compute_intercept = FALSE,
57-
remove_intercept = FALSE
57+
remove_intercept = FALSE,
58+
allow_sparse_x = FALSE
5859
)
5960
)
6061

@@ -77,7 +78,8 @@ set_encoding(
7778
options = list(
7879
predictor_indicators = "none",
7980
compute_intercept = FALSE,
80-
remove_intercept = FALSE
81+
remove_intercept = FALSE,
82+
allow_sparse_x = FALSE
8183
)
8284
)
8385

R/mlp_data.R

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ set_encoding(
7272
options = list(
7373
predictor_indicators = "traditional",
7474
compute_intercept = TRUE,
75-
remove_intercept = TRUE
75+
remove_intercept = TRUE,
76+
allow_sparse_x = FALSE
7677
)
7778
)
7879

@@ -95,7 +96,8 @@ set_encoding(
9596
options = list(
9697
predictor_indicators = "traditional",
9798
compute_intercept = TRUE,
98-
remove_intercept = TRUE
99+
remove_intercept = TRUE,
100+
allow_sparse_x = FALSE
99101
)
100102
)
101103

@@ -242,7 +244,8 @@ set_encoding(
242244
options = list(
243245
predictor_indicators = "traditional",
244246
compute_intercept = TRUE,
245-
remove_intercept = TRUE
247+
remove_intercept = TRUE,
248+
allow_sparse_x = FALSE
246249
)
247250
)
248251

@@ -265,7 +268,8 @@ set_encoding(
265268
options = list(
266269
predictor_indicators = "traditional",
267270
compute_intercept = TRUE,
268-
remove_intercept = TRUE
271+
remove_intercept = TRUE,
272+
allow_sparse_x = FALSE
269273
)
270274
)
271275

R/multinom_reg_data.R

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ set_encoding(
4444
options = list(
4545
predictor_indicators = "traditional",
4646
compute_intercept = TRUE,
47-
remove_intercept = TRUE
47+
remove_intercept = TRUE,
48+
allow_sparse_x = TRUE
4849
)
4950
)
5051

@@ -146,7 +147,8 @@ set_encoding(
146147
options = list(
147148
predictor_indicators = "traditional",
148149
compute_intercept = TRUE,
149-
remove_intercept = TRUE
150+
remove_intercept = TRUE,
151+
allow_sparse_x = FALSE
150152
)
151153
)
152154

@@ -220,7 +222,8 @@ set_encoding(
220222
options = list(
221223
predictor_indicators = "traditional",
222224
compute_intercept = TRUE,
223-
remove_intercept = TRUE
225+
remove_intercept = TRUE,
226+
allow_sparse_x = FALSE
224227
)
225228
)
226229

@@ -294,7 +297,8 @@ set_encoding(
294297
options = list(
295298
predictor_indicators = "traditional",
296299
compute_intercept = TRUE,
297-
remove_intercept = TRUE
300+
remove_intercept = TRUE,
301+
allow_sparse_x = FALSE
298302
)
299303
)
300304

0 commit comments

Comments
 (0)