Skip to content

Commit a469d0a

Browse files
authored
Merge pull request #344 from tidymodels/tidy-glmnet
better tidy glmnet methods
2 parents bed8ca9 + c769cf0 commit a469d0a

File tree

10 files changed

+174
-9
lines changed

10 files changed

+174
-9
lines changed

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ S3method(print,svm_poly)
5555
S3method(print,svm_rbf)
5656
S3method(req_pkgs,model_fit)
5757
S3method(req_pkgs,model_spec)
58+
S3method(tidy,"_elnet")
59+
S3method(tidy,"_fishnet")
60+
S3method(tidy,"_lognet")
61+
S3method(tidy,"_multnet")
5862
S3method(tidy,model_fit)
5963
S3method(tidy,nullmodel)
6064
S3method(translate,boost_tree)
@@ -234,6 +238,7 @@ importFrom(stats,.checkMFClasses)
234238
importFrom(stats,.getXlevels)
235239
importFrom(stats,as.formula)
236240
importFrom(stats,binomial)
241+
importFrom(stats,coef)
237242
importFrom(stats,delete.response)
238243
importFrom(stats,model.frame)
239244
importFrom(stats,model.matrix)

NEWS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# parsnip (development version)
22

3+
* Added `tidy()` methods for `glmnet` models fit via `parsnip`. They return the coefficients at the specific penalty value used by the fitted `parsnip` model.
4+
35
# parsnip 0.1.2
46

57
## Breaking Changes

R/aaa.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ utils::globalVariables(
4040
"neighbors", ".submodels", "has_submodel", "max_neighbor", "max_penalty",
4141
"max_terms", "max_tree", "model", "name", "num_terms", "penalty", "trees",
4242
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators",
43-
"compute_intercept", "remove_intercept")
43+
"compute_intercept", "remove_intercept", "estimate", "term")
4444
)
4545

4646
# nocov end

R/predict.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,12 @@
8181
#' lm_model <-
8282
#' linear_reg() %>%
8383
#' set_engine("lm") %>%
84-
#' fit(mpg ~ ., data = mtcars %>% slice(11:32))
84+
#' fit(mpg ~ ., data = mtcars %>% dplyr::slice(11:32))
8585
#'
8686
#' pred_cars <-
8787
#' mtcars %>%
88-
#' slice(1:10) %>%
89-
#' select(-mpg)
88+
#' dplyr::slice(1:10) %>%
89+
#' dplyr::select(-mpg)
9090
#'
9191
#' predict(lm_model, pred_cars)
9292
#'

R/tidy_glmnet.R

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#' tidy methods for glmnet models
#'
#' These `tidy()` methods cover the `glmnet` model classes (`_elnet`,
#' `_lognet`, `_multnet`, and `_fishnet`) and return the coefficients at a
#' single penalty value.
#'
#' @param x A fitted `parsnip` model that used the `glmnet` engine.
#' @param penalty A _single_ numeric value. When left as `NULL`, the value
#'   given in the model specification is used.
#' @param ... Not used.
#' @return A tibble with the columns `term`, `estimate`, and `penalty`. For
#'   multinomial models there is an additional `class` column.
#' @importFrom stats coef
#' @export
tidy._elnet <- function(x, penalty = NULL, ...) {
  # All of the glmnet classes share one implementation.
  tidy_glmnet(x, penalty = penalty)
}
16+
17+
#' @rdname tidy._elnet
#' @export
tidy._lognet <- function(x, penalty = NULL, ...) {
  # Delegates to the shared glmnet implementation; `...` is not forwarded.
  tidy_glmnet(x, penalty = penalty)
}
22+
23+
#' @rdname tidy._elnet
#' @export
tidy._multnet <- function(x, penalty = NULL, ...) {
  # Delegates to the shared glmnet implementation; `...` is not forwarded.
  tidy_glmnet(x, penalty = penalty)
}
28+
29+
#' @rdname tidy._elnet
#' @export
tidy._fishnet <- function(x, penalty = NULL, ...) {
  # Delegates to the shared glmnet implementation; `...` is not forwarded.
  tidy_glmnet(x, penalty = penalty)
}
34+
35+
## -----------------------------------------------------------------------------
36+
37+
# Extract the coefficients of `x` at `penalty` as a tibble.
#
# `coef(x, s = penalty)` is converted to a matrix and reshaped into the columns
# `term`, `estimate`, and `penalty`. If `estimate` comes back as a list column
# (presumably the multinomial case, with one element per outcome class --
# confirm against glmnet), each element is expanded into its own rows and the
# result has the columns `class`, `term`, `estimate`, and `penalty`.
get_glmn_coefs <- function(x, penalty = 0.01) {
  coef_mat <- as.matrix(coef(x, s = penalty))
  colnames(coef_mat) <- "estimate"
  row_labels <- rownames(coef_mat)

  out <- tibble::as_tibble(coef_mat)
  out <- dplyr::mutate(out, term = row_labels, penalty = penalty)
  out <- dplyr::select(out, term, estimate, penalty)

  # Nothing to expand when `estimate` is an ordinary (non-list) column.
  if (!is.list(out$estimate)) {
    return(out)
  }

  out$estimate <- purrr::map(
    out$estimate,
    function(.x) tibble::as_tibble(as.matrix(.x), rownames = "term")
  )
  # The nested tibbles carry their own `term` column, so names are left
  # unrepaired here and then assigned by position.
  out <- tidyr::unnest(out, cols = c(estimate), names_repair = "minimal")
  names(out) <- c("class", "term", "estimate", "penalty")
  out
}
51+
52+
# Shared implementation behind the glmnet `tidy()` methods.
#
# `x` is a fitted parsnip model: `x$spec` supplies the default penalty and
# `x$fit` is the object handed to `get_glmn_coefs()`. `penalty` may be `NULL`,
# in which case `x$spec$args$penalty` is used. `...` is accepted but not used.
#
# Aborts unless the resolved penalty is a single numeric value; otherwise the
# result of `get_glmn_coefs()` is returned.
tidy_glmnet <- function(x, penalty = NULL, ...) {
  check_installs(x$spec)
  # NOTE(review): presumably attaches the engine package so that its `coef()`
  # method can be dispatched on `x$fit` -- confirm against `load_libs()`.
  load_libs(x$spec, quiet = TRUE, attach = TRUE)

  if (is.null(penalty)) {
    penalty <- x$spec$args$penalty
  }

  # Validate user-supplied values as well as the specification's value.
  # `get_glmn_coefs()` labels exactly one coefficient column, so anything other
  # than one numeric penalty would otherwise fail there with an unrelated
  # error message.
  if (!is.numeric(penalty) || length(penalty) != 1) {
    rlang::abort("Please pick a single value of `penalty`.")
  }

  get_glmn_coefs(x$fit, penalty = penalty)
}

R/zzz.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,10 @@
33
.onLoad <- function(libname, pkgname) {
  # Classes whose `tidy` methods are registered for the `broom::tidy` generic
  # when the package is loaded.
  tidy_classes <- c(
    "model_fit", "nullmodel",
    "_elnet", "_lognet", "_multnet", "_fishnet"
  )
  for (cls in tidy_classes) {
    s3_register("broom::tidy", cls)
  }
}
711

812

man/predict.model_fit.Rd

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/tidy._elnet.Rd

Lines changed: 33 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_multinom_reg_glmnet.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,8 @@ test_that('glmnet probabilities, mulitiple lambda', {
143143

144144
for (i in seq_along(mult_class_res$.pred)) {
145145
expect_equal(
146-
mult_class %>% slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred")),
147-
mult_class_res %>% slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred"))
146+
mult_class %>% dplyr::slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred")),
147+
mult_class_res %>% dplyr::slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred"))
148148
)
149149
}
150150

tests/testthat/test_tidy_glmnet.R

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
context("tidy glmnet models")
2+
3+
test_that('linear regression', {
  skip_if_not_installed("glmnet")

  # Fit a penalized linear model through parsnip.
  glmnet_spec <- set_engine(linear_reg(penalty = .1), "glmnet")
  ps_mod <- fit(glmnet_spec, mpg ~ ., data = mtcars)

  # Each tidied estimate must match what glmnet itself reports at the same
  # penalty.
  ps_coefs <- tidy(ps_mod)
  gn_coefs <- as.matrix(coef(ps_mod$fit, s = .1))
  for (term_name in ps_coefs$term) {
    tidy_est <- ps_coefs$estimate[ps_coefs$term == term_name]
    expect_equal(tidy_est, gn_coefs[term_name, 1])
  }
})
17+
18+
test_that('logistic regression', {
  skip_if_not_installed("glmnet")

  data(two_class_dat, package = "modeldata")

  # Fit a penalized logistic model through parsnip.
  glmnet_spec <- set_engine(logistic_reg(penalty = .1), "glmnet")
  ps_mod <- fit(glmnet_spec, Class ~ ., data = two_class_dat)

  # Each tidied estimate must match what glmnet itself reports at the same
  # penalty.
  ps_coefs <- tidy(ps_mod)
  gn_coefs <- as.matrix(coef(ps_mod$fit, s = .1))
  for (term_name in ps_coefs$term) {
    tidy_est <- ps_coefs$estimate[ps_coefs$term == term_name]
    expect_equal(tidy_est, gn_coefs[term_name, 1])
  }
})
34+
35+
test_that('multinomial regression', {
  skip_if_not_installed("glmnet")

  data(penguins, package = "modeldata")

  # Fit a penalized multinomial model through parsnip.
  glmnet_spec <- set_engine(multinom_reg(penalty = .01), "glmnet")
  ps_mod <- fit(glmnet_spec, species ~ ., data = penguins)

  ps_coefs <- tidy(ps_mod)
  # The reference coefficients are a list that is indexed by class below.
  gn_coefs <- purrr::map(coef(ps_mod$fit, s = .01), as.matrix)

  # Compare every term/class combination with the matching glmnet value.
  for (term_name in unique(ps_coefs$term)) {
    for (class_name in unique(ps_coefs$class)) {
      is_match <- ps_coefs$term == term_name & ps_coefs$class == class_name
      expect_equal(
        ps_coefs$estimate[is_match],
        gn_coefs[[class_name]][term_name, 1]
      )
    }
  }
})
57+
58+

0 commit comments

Comments
 (0)