Skip to content

Commit 92b2bd5

Browse files
authored
Support glmnet models with base-R families (#890)
* add methods for `glmnetfit` objects; these objects result from using base-R families * use `type = "response"` for all logistic regression prediction because `type = "class"` is not available in glmnet for `glmnetfit` objects * add NEWS bullet * bump version for tests in extratests
1 parent 070f1b2 commit 92b2bd5

File tree

6 files changed

+60
-26
lines changed

6 files changed

+60
-26
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.0.4.9001
3+
Version: 1.0.4.9002
44
Authors@R: c(
55
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
@@ -34,7 +34,7 @@ Imports:
3434
rlang (>= 0.3.1),
3535
stats,
3636
tibble (>= 2.1.1),
37-
tidyr (>= 1.0.0),
37+
tidyr (>= 1.3.0),
3838
utils,
3939
vctrs (>= 0.4.1),
4040
withr

NAMESPACE

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ S3method(has_multi_predict,workflow)
1616
S3method(multi_predict,"_C5.0")
1717
S3method(multi_predict,"_earth")
1818
S3method(multi_predict,"_elnet")
19+
S3method(multi_predict,"_glmnetfit")
1920
S3method(multi_predict,"_lognet")
2021
S3method(multi_predict,"_multnet")
2122
S3method(multi_predict,"_torch_mlp")
@@ -27,25 +28,30 @@ S3method(multi_predict_args,model_fit)
2728
S3method(multi_predict_args,workflow)
2829
S3method(nullmodel,default)
2930
S3method(predict,"_elnet")
31+
S3method(predict,"_glmnetfit")
3032
S3method(predict,"_lognet")
3133
S3method(predict,"_multnet")
3234
S3method(predict,censoring_model_reverse_km)
3335
S3method(predict,model_fit)
3436
S3method(predict,model_spec)
3537
S3method(predict,nullmodel)
38+
S3method(predict_class,"_glmnetfit")
3639
S3method(predict_class,"_lognet")
3740
S3method(predict_class,"_multnet")
3841
S3method(predict_class,model_fit)
42+
S3method(predict_classprob,"_glmnetfit")
3943
S3method(predict_classprob,"_lognet")
4044
S3method(predict_classprob,"_multnet")
4145
S3method(predict_classprob,model_fit)
4246
S3method(predict_confint,model_fit)
4347
S3method(predict_hazard,model_fit)
4448
S3method(predict_linear_pred,model_fit)
4549
S3method(predict_numeric,"_elnet")
50+
S3method(predict_numeric,"_glmnetfit")
4651
S3method(predict_numeric,model_fit)
4752
S3method(predict_quantile,model_fit)
4853
S3method(predict_raw,"_elnet")
54+
S3method(predict_raw,"_glmnetfit")
4955
S3method(predict_raw,"_lognet")
5056
S3method(predict_raw,"_multnet")
5157
S3method(predict_raw,model_fit)

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# parsnip (development version)
22

3+
* glmnet models fitted with base-R family objects are now supported for `linear_reg()`, `logistic_reg()`, and `multinom_reg()` (#890).
34

45
* Made `fit()` behave consistently with respect to missingness in the classification setting. Previously, `fit()` erroneously raised an error about the class of the outcome when there were no complete cases, and now always passes along complete cases to be handled by the modeling function (#888).
56

R/glmnet.R

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,14 +97,17 @@ multi_predict_glmnet <- function(object,
9797
}
9898
}
9999

100+
model_type <- class(object$spec)[1]
101+
100102
if (object$spec$mode == "classification") {
101103
if (is.null(type)) {
102104
type <- "class"
103105
}
104106
if (!(type %in% c("class", "prob", "link", "raw"))) {
105107
rlang::abort("`type` should be either 'class', 'link', 'raw', or 'prob'.")
106108
}
107-
if (type == "prob") {
109+
if (type == "prob" |
110+
model_type == "logistic_reg") {
108111
dots$type <- "response"
109112
} else {
110113
dots$type <- type
@@ -114,13 +117,13 @@ multi_predict_glmnet <- function(object,
114117
pred <- predict(object, new_data = new_data, type = "raw",
115118
opts = dots, penalty = penalty, multi = TRUE)
116119

117-
model_type <- class(object$spec)[1]
120+
118121
res <- switch(
119122
model_type,
120123
"linear_reg" = format_glmnet_multi_linear_reg(pred, penalty = penalty),
121124
"logistic_reg" = format_glmnet_multi_logistic_reg(pred,
122125
penalty = penalty,
123-
type = dots$type,
126+
type = type,
124127
lvl = object$lvl),
125128
"multinom_reg" = format_glmnet_multi_multinom_reg(pred,
126129
penalty = penalty,
@@ -132,9 +135,25 @@ multi_predict_glmnet <- function(object,
132135
res
133136
}
134137

138+
#' @export
139+
predict._glmnetfit <- predict_glmnet
135140

136-
# -------------------------------------------------------------------------
141+
#' @export
142+
predict_numeric._glmnetfit <- predict_numeric_glmnet
137143

144+
#' @export
145+
predict_class._glmnetfit <- predict_class_glmnet
146+
147+
#' @export
148+
predict_classprob._glmnetfit <- predict_classprob_glmnet
149+
150+
#' @export
151+
predict_raw._glmnetfit <- predict_raw_glmnet
152+
153+
#' @export
154+
multi_predict._glmnetfit <- multi_predict_glmnet
155+
156+
# -------------------------------------------------------------------------
138157

139158
set_glmnet_penalty_path <- function(x) {
140159
if (any(names(x$eng_args) == "path_values")) {

R/logistic_reg.R

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -184,34 +184,41 @@ predict._lognet <- predict_glmnet
184184
multi_predict._lognet <- multi_predict_glmnet
185185

186186
format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) {
187-
param_key <- tibble(group = colnames(pred), penalty = penalty)
187+
188+
type <- rlang::arg_match(type, c("class", "prob"))
189+
190+
penalty_key <- tibble(s = colnames(pred), penalty = penalty)
191+
188192
pred <- as_tibble(pred)
189-
pred$.row <- 1:nrow(pred)
190-
pred <- gather(pred, group, .pred_class, -.row)
193+
pred$.row <- seq_len(nrow(pred))
194+
pred <- tidyr::pivot_longer(pred, -.row, names_to = "s", values_to = ".pred")
195+
191196
if (type == "class") {
192-
pred[[".pred_class"]] <- factor(pred[[".pred_class"]], levels = lvl)
197+
pred <- pred %>%
198+
dplyr::mutate(.pred_class = dplyr::if_else(.pred >= 0.5, lvl[2], lvl[1]),
199+
.pred_class = factor(.pred_class, levels = lvl),
200+
.keep = "unused")
193201
} else {
194-
if (type == "response") {
195-
pred[[".pred2"]] <- 1 - pred[[".pred_class"]]
196-
names(pred) <- c(".row", "group", paste0(".pred_", rev(lvl)))
197-
pred <- pred[, c(".row", "group", paste0(".pred_", lvl))]
198-
}
202+
pred <- pred %>%
203+
dplyr::mutate(.pred_class_2 = 1 - .pred) %>%
204+
rlang::set_names(c(".row", "s", paste0(".pred_", rev(lvl)))) %>%
205+
dplyr::select(c(".row", "s", paste0(".pred_", lvl)))
199206
}
207+
200208
if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
201-
pred <- full_join(param_key, pred, by = "group", multiple = "all")
209+
pred <- dplyr::full_join(penalty_key, pred, by = "s", multiple = "all")
202210
} else {
203-
pred <- full_join(param_key, pred, by = "group")
211+
pred <- dplyr::full_join(penalty_key, pred, by = "s")
204212
}
205-
pred$group <- NULL
206-
pred <- arrange(pred, .row, penalty)
207-
.row <- pred$.row
208-
pred$.row <- NULL
209-
pred <- split(pred, .row)
210-
names(pred) <- NULL
211-
tibble(.pred = pred)
212-
}
213213

214+
pred <- pred %>%
215+
dplyr::select(-s) %>%
216+
dplyr::arrange(penalty) %>%
217+
tidyr::nest(.by = .row, .key = ".pred") %>%
218+
dplyr::select(-.row)
214219

220+
pred
221+
}
215222

216223
#' @export
217224
predict_class._lognet <- predict_class_glmnet

R/parsnip-package.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ utils::globalVariables(
4242
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators",
4343
"compute_intercept", "remove_intercept", "estimate", "term",
4444
"call_info", "component", "component_id", "func", "tunable", "label",
45-
"pkg", ".order", "item", "tunable", "has_ext", "id", "weights", "has_wts", "protect"
45+
"pkg", ".order", "item", "tunable", "has_ext", "id", "weights", "has_wts",
46+
"protect", "s"
4647
)
4748
)
4849

0 commit comments

Comments
 (0)