Skip to content

Commit fc21c9e

Browse files
authored
Merge pull request #486 from tidymodels/path-values
Get correct coefs for ridge regression
2 parents bc125e9 + 8ff0a28 commit fc21c9e

13 files changed

+256
-71
lines changed

NEWS.md

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,33 @@
11
# parsnip (development version)
22

3-
* `generics::required_pkgs()` was extended for `parsnip` objects.
4-
5-
* The `liquidSVM` engine for `svm_rbf()` was deprecated due to that package's removal from CRAN. (#425)
3+
## Model Specification Changes
64

75
* A new linear SVM model `svm_linear()` is now available with the `LiblineaR` engine (#424) and the `kernlab` engine (#438), and the `LiblineaR` engine is available for `logistic_reg()` as well (#429). These models can use sparse matrices via `fit_xy()` (#447) and have a `tidy` method (#474).
86

7+
* For models with `glmnet` engines:
8+
9+
- A single value is required for `penalty` (either a single numeric value or a value of `tune()`) (#481).
10+
- A special argument called `path_values` can be used to set the `lambda` path as a specific set of numbers (independent of the value of `penalty`). A pure ridge regression models (i.e., `mixture = 1`) will generate incorrect values if the path does not include zero. See issue #431 for discussion (#486).
11+
12+
* The `liquidSVM` engine for `svm_rbf()` was deprecated due to that package's removal from CRAN. (#425)
13+
914
* New model specification `survival_reg()` for the new mode `"censored regression"` (#444). `surv_reg()` is now soft-deprecated (#448).
1015

1116
* New model specification `proportional_hazards()` for the `"censored regression"` mode (#451).
1217

18+
## Other Changes
19+
1320
* Re-licensed package from GPL-2 to MIT. See [consent from copyright holders here](https://github.com/tidymodels/parsnip/issues/462).
1421

1522
* `set_mode()` now checks if `mode` is compatible with the model class, similar to `new_model_spec()` (@jtlandis, #467).
1623

1724
* Re-organized model documentation for `update` methods (#479).
1825

26+
27+
28+
* `generics::required_pkgs()` was extended for `parsnip` objects.
29+
30+
1931

2032
# parsnip 0.1.5
2133

R/linear_reg.R

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,23 @@ translate.linear_reg <- function(x, engine = x$engine, ...) {
107107
x <- translate.default(x, engine, ...)
108108

109109
if (engine == "glmnet") {
110-
# See discussion in https://github.com/tidymodels/parsnip/issues/195
111-
x$method$fit$args$lambda <- NULL
110+
check_glmnet_penalty(x)
111+
if (any(names(x$eng_args) == "path_values")) {
112+
# Since we decouple the parsnip `penalty` argument from being the same
113+
# as the glmnet `lambda` value, `path_values` allows users to set the
114+
# path differently from the default that glmnet uses. See
115+
# https://github.com/tidymodels/parsnip/issues/431
116+
x$method$fit$args$lambda <- x$eng_args$path_values
117+
x$eng_args$path_values <- NULL
118+
x$method$fit$args$path_values <- NULL
119+
} else {
120+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
121+
x$method$fit$args$lambda <- NULL
122+
}
112123
# Since the `fit` information is gone for the penalty, we need to have an
113124
# evaluated value for the parameter.
114125
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
115-
check_glmnet_penalty(x)
116126
}
117-
118127
x
119128
}
120129

R/logistic_reg.R

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,23 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) {
108108
arg_vals <- x$method$fit$args
109109
arg_names <- names(arg_vals)
110110

111-
112111
if (engine == "glmnet") {
113-
# See discussion in https://github.com/tidymodels/parsnip/issues/195
114-
arg_vals$lambda <- NULL
112+
check_glmnet_penalty(x)
113+
if (any(names(x$eng_args) == "path_values")) {
114+
# Since we decouple the parsnip `penalty` argument from being the same
115+
# as the glmnet `lambda` value, `path_values` allows users to set the
116+
# path differently from the default that glmnet uses. See
117+
# https://github.com/tidymodels/parsnip/issues/431
118+
x$method$fit$args$lambda <- x$eng_args$path_values
119+
x$eng_args$path_values <- NULL
120+
x$method$fit$args$path_values <- NULL
121+
} else {
122+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
123+
x$method$fit$args$lambda <- NULL
124+
}
115125
# Since the `fit` information is gone for the penalty, we need to have an
116126
# evaluated value for the parameter.
117127
x$args$penalty <- rlang::eval_tidy(x$args$penalty)
118-
check_glmnet_penalty(x)
119128
}
120129

121130
if (engine == "LiblineaR") {
@@ -134,11 +143,8 @@ translate.logistic_reg <- function(x, engine = x$engine, ...) {
134143
rlang::abort("For the LiblineaR engine, mixture must be 0 or 1.")
135144
}
136145
}
137-
146+
x$method$fit$args <- arg_vals
138147
}
139-
140-
x$method$fit$args <- arg_vals
141-
142148
x
143149
}
144150

R/misc.R

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,12 @@ stan_conf_int <- function(object, newdata) {
324324
}
325325

326326
check_glmnet_penalty <- function(x) {
327-
if (length(x$args$penalty) != 1) {
327+
pen <- rlang::eval_tidy(x$args$penalty)
328+
329+
if (length(pen) != 1) {
328330
rlang::abort(c(
329331
"For the glmnet engine, `penalty` must be a single number (or a value of `tune()`).",
330-
glue::glue("There are {length(x$args$penalty)} values for `penalty`."),
332+
glue::glue("There are {length(pen)} values for `penalty`."),
331333
"To try multiple values for total regularization, use the tune package.",
332334
"To predict multiple penalties, use `multi_predict()`"
333335
))

man/linear_reg.Rd

Lines changed: 32 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/logistic_reg.Rd

Lines changed: 32 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/multinom_reg.Rd

Lines changed: 32 additions & 9 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/rmd/linear-reg.Rmd

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,29 @@ linear_reg(penalty = 0.1) %>%
2121
translate()
2222
```
2323

24-
For `glmnet` models, the full regularization path is always fit regardless of the
25-
value given to `penalty`. Also, there is the option to pass multiple values (or
26-
no values) to the `penalty` argument. When using the `predict()` method in these
27-
cases, the return value depends on the value of `penalty`. When using
28-
`predict()`, only a single value of the penalty can be used. When predicting on
29-
multiple penalties, the `multi_predict()` function can be used. It returns a
30-
tibble with a list column called `.pred` that contains a tibble with all of the
31-
penalty results.
24+
The glmnet engine requires a single value for the `penalty` argument (a number
25+
or `tune()`), but the full regularization path is always fit
26+
regardless of the value given to `penalty`. To pass in a custom sequence of
27+
values for glmnet's `lambda`, use the argument `path_values` in `set_engine()`.
28+
This will assign the value of the glmnet `lambda` parameter without disturbing
29+
the value given of `linear_reg(penalty)`. For example:
30+
31+
```{r glmnet-path}
32+
linear_reg(penalty = .1) %>%
33+
set_engine("glmnet", path_values = c(0, 10^seq(-10, 1, length.out = 20))) %>%
34+
translate()
35+
```
36+
37+
When fitting a pure ridge regression model (i.e., `penalty = 0`), we _strongly
38+
suggest_ that you pass in a vector for `path_values` that includes zero. See
39+
[issue #431](https://github.com/tidymodels/parsnip/issues/431) for a discussion.
40+
41+
When using `predict()`, the single `penalty` value used for prediction is the
42+
one specified in `linear_reg()`.
43+
44+
To predict on multiple penalties, use the `multi_predict()` function.
45+
This function returns a tibble with a list column called `.pred` containing
46+
all of the penalty results.
3247

3348
## stan
3449

man/rmd/logistic-reg.Rmd

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,30 @@ logistic_reg(penalty = 0.1) %>%
2222
translate()
2323
```
2424

25-
For `glmnet` models, the full regularization path is always fit regardless of the
26-
value given to `penalty`. Also, there is the option to pass multiple values (or
27-
no values) to the `penalty` argument. When using the `predict()` method in these
28-
cases, the return value depends on the value of `penalty`. When using
29-
`predict()`, only a single value of the penalty can be used. When predicting on
30-
multiple penalties, the `multi_predict()` function can be used. It returns a
31-
tibble with a list column called `.pred` that contains a tibble with all of the
32-
penalty results.
25+
The glmnet engine requires a single value for the `penalty` argument (a number
26+
or `tune()`), but the full regularization path is always fit
27+
regardless of the value given to `penalty`. To pass in a custom sequence of
28+
values for glmnet's `lambda`, use the argument `path_values` in `set_engine()`.
29+
This will assign the value of the glmnet `lambda` parameter without disturbing
30+
the value given of `logistic_reg(penalty)`. For example:
31+
32+
```{r glmnet-path}
33+
logistic_reg(penalty = .1) %>%
34+
set_engine("glmnet", path_values = c(0, 10^seq(-10, 1, length.out = 20))) %>%
35+
translate()
36+
```
37+
38+
When fitting a pure ridge regression model (i.e., `penalty = 0`), we _strongly
39+
suggest_ that you pass in a vector for `path_values` that includes zero. See
40+
[issue #431](https://github.com/tidymodels/parsnip/issues/431) for a discussion.
41+
42+
When using `predict()`, the single `penalty` value used for prediction is the
43+
one specified in `logistic_reg()`.
44+
45+
To predict on multiple penalties, use the `multi_predict()` function.
46+
This function returns a tibble with a list column called `.pred` containing
47+
all of the penalty results.
48+
3349

3450
## LiblineaR
3551

man/rmd/multinom-reg.Rmd

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,31 @@ multinom_reg(penalty = 0.1) %>%
1414
translate()
1515
```
1616

17-
For `glmnet` models, the full regularization path is always fit regardless of the
18-
value given to `penalty`. Also, there is the option to pass multiple values (or
19-
no values) to the `penalty` argument. When using the `predict()` method in these
20-
cases, the return value depends on the value of `penalty`. When using
21-
`predict()`, only a single value of the penalty can be used. When predicting on
22-
multiple penalties, the `multi_predict()` function can be used. It returns a
23-
tibble with a list column called `.pred` that contains a tibble with all of the
24-
penalty results.
17+
The glmnet engine requires a single value for the `penalty` argument (a number
18+
or `tune()`), but the full regularization path is always fit
19+
regardless of the value given to `penalty`. To pass in a custom sequence of
20+
values for glmnet's `lambda`, use the argument `path_values` in `set_engine()`.
21+
This will assign the value of the glmnet `lambda` parameter without disturbing
22+
the value given of `multinom_reg(penalty)`. For example:
23+
24+
25+
```{r glmnet-path}
26+
multinom_reg(penalty = .1) %>%
27+
set_engine("glmnet", path_values = c(0, 10^seq(-10, 1, length.out = 20))) %>%
28+
translate()
29+
```
30+
31+
When fitting a pure ridge regression model (i.e., `penalty = 0`), we _strongly
32+
suggest_ that you pass in a vector for `path_values` that includes zero. See
33+
[issue #431](https://github.com/tidymodels/parsnip/issues/431) for a discussion.
34+
35+
When using `predict()`, the single `penalty` value used for prediction is the
36+
one specified in `multinom_reg()`.
37+
38+
To predict on multiple penalties, use the `multi_predict()` function.
39+
This function returns a tibble with a list column called `.pred` containing
40+
all of the penalty results.
41+
2542

2643
## nnet
2744

0 commit comments

Comments
 (0)