Skip to content

Support glmnet models with base-R families #890

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: parsnip
Title: A Common API to Modeling and Analysis Functions
Version: 1.0.4.9001
Version: 1.0.4.9002
Authors@R: c(
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
person("Davis", "Vaughan", , "[email protected]", role = "aut"),
Expand Down Expand Up @@ -34,7 +34,7 @@ Imports:
rlang (>= 0.3.1),
stats,
tibble (>= 2.1.1),
tidyr (>= 1.0.0),
tidyr (>= 1.3.0),
utils,
vctrs (>= 0.4.1),
withr
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ S3method(has_multi_predict,workflow)
S3method(multi_predict,"_C5.0")
S3method(multi_predict,"_earth")
S3method(multi_predict,"_elnet")
S3method(multi_predict,"_glmnetfit")
S3method(multi_predict,"_lognet")
S3method(multi_predict,"_multnet")
S3method(multi_predict,"_torch_mlp")
Expand All @@ -27,25 +28,30 @@ S3method(multi_predict_args,model_fit)
S3method(multi_predict_args,workflow)
S3method(nullmodel,default)
S3method(predict,"_elnet")
S3method(predict,"_glmnetfit")
S3method(predict,"_lognet")
S3method(predict,"_multnet")
S3method(predict,censoring_model_reverse_km)
S3method(predict,model_fit)
S3method(predict,model_spec)
S3method(predict,nullmodel)
S3method(predict_class,"_glmnetfit")
S3method(predict_class,"_lognet")
S3method(predict_class,"_multnet")
S3method(predict_class,model_fit)
S3method(predict_classprob,"_glmnetfit")
S3method(predict_classprob,"_lognet")
S3method(predict_classprob,"_multnet")
S3method(predict_classprob,model_fit)
S3method(predict_confint,model_fit)
S3method(predict_hazard,model_fit)
S3method(predict_linear_pred,model_fit)
S3method(predict_numeric,"_elnet")
S3method(predict_numeric,"_glmnetfit")
S3method(predict_numeric,model_fit)
S3method(predict_quantile,model_fit)
S3method(predict_raw,"_elnet")
S3method(predict_raw,"_glmnetfit")
S3method(predict_raw,"_lognet")
S3method(predict_raw,"_multnet")
S3method(predict_raw,model_fit)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# parsnip (development version)

* glmnet models fitted with base-R family objects are now supported for `linear_reg()`, `logistic_reg()`, and `multinomial_reg()` (#890).

* Made `fit()` behave consistently with respect to missingness in the classification setting. Previously, `fit()` erroneously raised an error about the class of the outcome when there were no complete cases, and now always passes along complete cases to be handled by the modeling function (#888).

Expand Down
27 changes: 23 additions & 4 deletions R/glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,14 +97,17 @@ multi_predict_glmnet <- function(object,
}
}

model_type <- class(object$spec)[1]

if (object$spec$mode == "classification") {
if (is.null(type)) {
type <- "class"
}
if (!(type %in% c("class", "prob", "link", "raw"))) {
rlang::abort("`type` should be either 'class', 'link', 'raw', or 'prob'.")
}
if (type == "prob") {
if (type == "prob" |
model_type == "logistic_reg") {
dots$type <- "response"
} else {
dots$type <- type
Expand All @@ -114,13 +117,13 @@ multi_predict_glmnet <- function(object,
pred <- predict(object, new_data = new_data, type = "raw",
opts = dots, penalty = penalty, multi = TRUE)

model_type <- class(object$spec)[1]

res <- switch(
model_type,
"linear_reg" = format_glmnet_multi_linear_reg(pred, penalty = penalty),
"logistic_reg" = format_glmnet_multi_logistic_reg(pred,
penalty = penalty,
type = dots$type,
type = type,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A note to help speed up @topepo's review: format_glmnet_multi_logistic_reg() used to take whichever type was passed along to glmnet, and its internals read:

if (type == "class") {
  # ...
} else {
  # ...
}

That helper is now supplied the "parsnip" type instead, in this case one of "class" or "prob", so that the above conditional has only one possible value for the else. The format_glmnet_multi_multinom_reg() helper does this already.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@simonpcouch Thanks for adding this! 🙌

lvl = object$lvl),
"multinom_reg" = format_glmnet_multi_multinom_reg(pred,
penalty = penalty,
Expand All @@ -132,9 +135,25 @@ multi_predict_glmnet <- function(object,
res
}

#' @export
predict._glmnetfit <- predict_glmnet

# -------------------------------------------------------------------------
#' @export
predict_numeric._glmnetfit <- predict_numeric_glmnet

#' @export
predict_class._glmnetfit <- predict_class_glmnet

#' @export
predict_classprob._glmnetfit <- predict_classprob_glmnet

#' @export
predict_raw._glmnetfit <- predict_raw_glmnet

#' @export
multi_predict._glmnetfit <- multi_predict_glmnet

# -------------------------------------------------------------------------

set_glmnet_penalty_path <- function(x) {
if (any(names(x$eng_args) == "path_values")) {
Expand Down
45 changes: 26 additions & 19 deletions R/logistic_reg.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,34 +184,41 @@ predict._lognet <- predict_glmnet
multi_predict._lognet <- multi_predict_glmnet

format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) {
param_key <- tibble(group = colnames(pred), penalty = penalty)

type <- rlang::arg_match(type, c("class", "prob"))

penalty_key <- tibble(s = colnames(pred), penalty = penalty)

pred <- as_tibble(pred)
pred$.row <- 1:nrow(pred)
pred <- gather(pred, group, .pred_class, -.row)
pred$.row <- seq_len(nrow(pred))
pred <- tidyr::pivot_longer(pred, -.row, names_to = "s", values_to = ".pred")

if (type == "class") {
pred[[".pred_class"]] <- factor(pred[[".pred_class"]], levels = lvl)
pred <- pred %>%
dplyr::mutate(.pred_class = dplyr::if_else(.pred >= 0.5, lvl[2], lvl[1]),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

apart from this line adding the translation from probabilities to classes, the changes to this function are only to make use of dplyr and tidyr

.pred_class = factor(.pred_class, levels = lvl),
.keep = "unused")
} else {
if (type == "response") {
pred[[".pred2"]] <- 1 - pred[[".pred_class"]]
names(pred) <- c(".row", "group", paste0(".pred_", rev(lvl)))
pred <- pred[, c(".row", "group", paste0(".pred_", lvl))]
}
pred <- pred %>%
dplyr::mutate(.pred_class_2 = 1 - .pred) %>%
rlang::set_names(c(".row", "s", paste0(".pred_", rev(lvl)))) %>%
dplyr::select(c(".row", "s", paste0(".pred_", lvl)))
}

if (utils::packageVersion("dplyr") >= "1.0.99.9000") {
pred <- full_join(param_key, pred, by = "group", multiple = "all")
pred <- dplyr::full_join(penalty_key, pred, by = "s", multiple = "all")
} else {
pred <- full_join(param_key, pred, by = "group")
pred <- dplyr::full_join(penalty_key, pred, by = "s")
}
pred$group <- NULL
pred <- arrange(pred, .row, penalty)
.row <- pred$.row
pred$.row <- NULL
pred <- split(pred, .row)
names(pred) <- NULL
tibble(.pred = pred)
}

pred <- pred %>%
dplyr::select(-s) %>%
dplyr::arrange(penalty) %>%
tidyr::nest(.by = .row, .key = ".pred") %>%
dplyr::select(-.row)

pred
}

#' @export
predict_class._lognet <- predict_class_glmnet
Expand Down
3 changes: 2 additions & 1 deletion R/parsnip-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ utils::globalVariables(
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators",
"compute_intercept", "remove_intercept", "estimate", "term",
"call_info", "component", "component_id", "func", "tunable", "label",
"pkg", ".order", "item", "tunable", "has_ext", "id", "weights", "has_wts", "protect"
"pkg", ".order", "item", "tunable", "has_ext", "id", "weights", "has_wts",
"protect", "s"
)
)

Expand Down