
Commit 154c1ab

juliasilge and topepo authored
Add LiblineaR engine to logistic_reg() (#429)
* Add LiblineaR engine for logistic_reg()
* Update docs, tests, NEWS for LiblineaR logistic_reg()
* Update NEWS
* Add docs about LiblineaR regularizing intercept
* Change test to engine arg of bias
* Update man/rmd/logistic-reg.Rmd

Co-authored-by: Max Kuhn <[email protected]>

* Redocument
* Test logistic_reg for varying() penalty

Co-authored-by: Max Kuhn <[email protected]>
1 parent ce24784 commit 154c1ab

File tree

7 files changed: 374 additions, 16 deletions

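For orientation, here is an editorial sketch (not part of the commit) of what this change enables: a LiblineaR-backed logistic regression specified and used through the usual parsnip interface. The penalty/mixture values are arbitrary, and the `two_class_dat` example data from the modeldata package is an assumption used only for illustration.

```r
library(parsnip)

# Illustrative values: with the LiblineaR engine, mixture must be exactly
# 0 (ridge) or 1 (lasso), and penalty must be strictly positive.
spec <- logistic_reg(penalty = 0.1, mixture = 1) %>%
  set_engine("LiblineaR") %>%
  set_mode("classification")

# Assumes the modeldata package is installed for its example data.
fitted <- fit(spec, Class ~ A + B, data = modeldata::two_class_dat)

predict(fitted, new_data = head(modeldata::two_class_dat))                 # hard classes
predict(fitted, new_data = head(modeldata::two_class_dat), type = "prob")  # class probabilities
```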

DESCRIPTION

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ Imports:
     prettyunits,
     vctrs (>= 0.2.0)
 Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.1.1.9000
+RoxygenNote: 7.1.1.9001
 Suggests:
     testthat,
     knitr,

NEWS.md

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 * The `liquidSVM` engine for `svm_rbf()` was deprecated due to that package's removal from CRAN. (#425)

-* A new linear SVM model `svm_linear()` is now available with the `LiblineaR` engine. (#424)
+* A new linear SVM model `svm_linear()` is now available with the `LiblineaR` engine (#424), and the `LiblineaR` engine is available for `logistic_reg()` as well (#429).

 # parsnip 0.1.5

R/logistic_reg.R

Lines changed: 63 additions & 5 deletions
@@ -20,21 +20,22 @@
 #' @param mode A single character string for the type of model.
 #' The only possible value for this model is "classification".
 #' @param penalty A non-negative number representing the total
-#' amount of regularization (`glmnet`, `keras`, and `spark` only).
+#' amount of regularization (`glmnet`, `LiblineaR`, `keras`, and `spark` only).
 #' For `keras` models, this corresponds to purely L2 regularization
-#' (aka weight decay) while the other models can be a combination
+#' (aka weight decay) while the other models can be either or a combination
 #' of L1 and L2 (depending on the value of `mixture`).
 #' @param mixture A number between zero and one (inclusive) that is the
 #' proportion of L1 regularization (i.e. lasso) in the model. When
 #' `mixture = 1`, it is a pure lasso model while `mixture = 0` indicates that
-#' ridge regression is being used. (`glmnet` and `spark` only).
+#' ridge regression is being used. (`glmnet`, `LiblineaR`, and `spark` only).
+#' For `LiblineaR` models, `mixture` must be exactly 0 or 1 only.
 #' @details
 #' For `logistic_reg()`, the mode will always be "classification".
 #'
 #' The model can be created using the `fit()` function using the
 #' following _engines_:
 #' \itemize{
-#' \item \pkg{R}: `"glm"` (the default) or `"glmnet"`
+#' \item \pkg{R}: `"glm"` (the default), `"glmnet"`, or `"LiblineaR"`
 #' \item \pkg{Stan}: `"stan"`
 #' \item \pkg{Spark}: `"spark"`
 #' \item \pkg{keras}: `"keras"`
@@ -101,7 +102,45 @@ print.logistic_reg <- function(x, ...) {
 }

 #' @export
-translate.logistic_reg <- translate.linear_reg
+translate.logistic_reg <- function(x, engine = x$engine, ...) {
+  x <- translate.default(x, engine, ...)
+
+  # slightly cleaner code using
+  arg_vals <- x$method$fit$args
+  arg_names <- names(arg_vals)
+
+
+  if (engine == "glmnet") {
+    # See discussion in https://github.com/tidymodels/parsnip/issues/195
+    arg_vals$lambda <- NULL
+    # Since the `fit` information is gone for the penalty, we need to have an
+    # evaluated value for the parameter.
+    x$args$penalty <- rlang::eval_tidy(x$args$penalty)
+  }
+
+  if (engine == "LiblineaR") {
+    # convert parameter arguments
+    new_penalty <- rlang::eval_tidy(x$args$penalty)
+    if (is.numeric(new_penalty))
+      arg_vals$cost <- rlang::new_quosure(1 / new_penalty, env = rlang::empty_env())
+
+    if (any(arg_names == "type")) {
+      if (is.numeric(quo_get_expr(arg_vals$type)))
+        if (quo_get_expr(x$args$mixture) == 0) {
+          arg_vals$type <- 0 ## ridge
+        } else if (quo_get_expr(x$args$mixture) == 1) {
+          arg_vals$type <- 6 ## lasso
+        } else {
+          rlang::abort("For the LiblineaR engine, mixture must be 0 or 1.")
+        }
+    }
+
+  }
+
+  x$method$fit$args <- arg_vals
+
+  x
+}

 # ------------------------------------------------------------------------------
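In this translation, parsnip's `penalty` becomes LiblineaR's `cost` via `cost = 1/penalty`, and `mixture` selects LiblineaR's `type` (0 for ridge, 6 for lasso). A rough sketch of that arithmetic with illustrative values (editorial example, not part of the diff):

```r
# Illustrative values only
penalty <- 0.5
cost <- 1 / penalty   # 2: LiblineaR's `cost` is the inverse of parsnip's `penalty`

mixture <- 0          # must be exactly 0 or 1 for this engine
type <- if (mixture == 0) 0 else if (mixture == 1) 6 else
  stop("For the LiblineaR engine, mixture must be 0 or 1.")
type                  # here 0 (L2-regularized/ridge); mixture = 1 would give type 6 (L1/lasso)
```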

@@ -169,6 +208,16 @@ check_args.logistic_reg <- function(object) {
   if (is.numeric(args$mixture) && length(args$mixture) > 1)
     rlang::abort("Only one value of `mixture` is allowed.")

+  if (object$engine == "LiblineaR") {
+    if(is.numeric(args$mixture) && !args$mixture %in% 0:1)
+      rlang::abort(c("For the LiblineaR engine, mixture must be 0 or 1.",
+                     "Choose a pure ridge model with `mixture = 0`.",
+                     "Choose a pure lasso model with `mixture = 1`.",
+                     "The Liblinear engine does not support other values."))
+    if(all(is.numeric(args$penalty)) && !all(args$penalty > 0))
+      rlang::abort("For the LiblineaR engine, penalty must be > 0.")
+  }
+
   invisible(object)
 }
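With these checks in place, invalid LiblineaR specifications fail early. A hedged sketch of the expected behaviour (editorial example; the messages are taken from the aborts above, not from a captured run):

```r
library(parsnip)

# Expected to abort with "For the LiblineaR engine, mixture must be 0 or 1."
bad_mixture <- logistic_reg(penalty = 0.1, mixture = 0.5) %>%
  set_engine("LiblineaR")

# Expected to abort with "For the LiblineaR engine, penalty must be > 0."
bad_penalty <- logistic_reg(penalty = 0, mixture = 0) %>%
  set_engine("LiblineaR")

# The aborts are raised when the model is fitted (check_args() runs then), e.g.:
# try(fit(bad_mixture, Class ~ A + B, data = modeldata::two_class_dat))
```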

@@ -346,3 +395,12 @@ predict_raw._lognet <- function(object, new_data, opts = list(), ...) {
   predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
 }

+# ------------------------------------------------------------------------------
+
+liblinear_preds <- function(results, object) {
+  results$predictions
+}
+
+liblinear_probs <- function(results, object) {
+  as_tibble(results$probabilities)
+}
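These helpers only extract the relevant pieces of what `LiblineaR`'s `predict()` method returns: a list with a `predictions` element and, when `proba = TRUE`, a `probabilities` matrix. A stand-in sketch of that post-processing (editorial illustration with made-up values):

```r
library(tibble)

# `res` mimics the shape of LiblineaR's predict() output
res <- list(
  predictions   = factor(c("Class1", "Class2")),
  probabilities = matrix(c(0.8, 0.2, 0.3, 0.7), nrow = 2, byrow = TRUE,
                         dimnames = list(NULL, c("Class1", "Class2")))
)

res$predictions               # what liblinear_preds() pulls out (hard classes)
as_tibble(res$probabilities)  # what liblinear_probs() returns (probability tibble)
```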

R/logistic_reg_data.R

Lines changed: 98 additions & 0 deletions
@@ -233,6 +233,104 @@ set_pred(

 # ------------------------------------------------------------------------------

+set_model_engine("logistic_reg", "classification", "LiblineaR")
+set_dependency("logistic_reg", "LiblineaR", "LiblineaR")
+
+set_fit(
+  model = "logistic_reg",
+  eng = "LiblineaR",
+  mode = "classification",
+  value = list(
+    interface = "matrix",
+    protect = c("x", "y", "wi"),
+    data = c(x = "data", y = "target"),
+    func = c(pkg = "LiblineaR", fun = "LiblineaR"),
+    defaults = list(verbose = FALSE)
+  )
+)
+
+set_encoding(
+  model = "logistic_reg",
+  eng = "LiblineaR",
+  mode = "classification",
+  options = list(
+    predictor_indicators = "none",
+    compute_intercept = FALSE,
+    remove_intercept = FALSE,
+    allow_sparse_x = FALSE
+  )
+)
+
+set_model_arg(
+  model = "logistic_reg",
+  eng = "LiblineaR",
+  parsnip = "penalty",
+  original = "cost",
+  func = list(pkg = "dials", fun = "penalty"),
+  has_submodel = TRUE
+)
+
+set_model_arg(
+  model = "logistic_reg",
+  eng = "LiblineaR",
+  parsnip = "mixture",
+  original = "type",
+  func = list(pkg = "dials", fun = "mixture"),
+  has_submodel = FALSE
+)
+
+set_pred(
+  model = "logistic_reg",
+  eng = "LiblineaR",
+  mode = "classification",
+  type = "class",
+  value = list(
+    pre = NULL,
+    post = liblinear_preds,
+    func = c(fun = "predict"),
+    args =
+      list(
+        object = quote(object$fit),
+        newx = expr(as.matrix(new_data))
+      )
+  )
+)
+
+set_pred(
+  model = "logistic_reg",
+  eng = "LiblineaR",
+  mode = "classification",
+  type = "prob",
+  value = list(
+    pre = NULL,
+    post = liblinear_probs,
+    func = c(fun = "predict"),
+    args =
+      list(
+        object = quote(object$fit),
+        newx = expr(as.matrix(new_data)),
+        proba = TRUE
+      )
+  )
+)
+
+set_pred(
+  model = "logistic_reg",
+  eng = "LiblineaR",
+  mode = "classification",
+  type = "raw",
+  value = list(
+    pre = NULL,
+    post = NULL,
+    func = c(fun = "predict"),
+    args = list(
+      object = quote(object$fit),
+      newx = quote(new_data))
+  )
+)
+
+# ------------------------------------------------------------------------------
+
 set_model_engine("logistic_reg", "classification", "spark")
 set_dependency("logistic_reg", "spark", "sparklyr")

man/logistic_reg.Rd

Lines changed: 32 additions & 8 deletions
(Generated file; diff not rendered here.)

man/rmd/logistic-reg.Rmd

Lines changed: 17 additions & 1 deletion
@@ -33,6 +33,22 @@ multiple penalties, the `multi_predict()` function can be used. It returns a
 tibble with a list column called `.pred` that contains a tibble with all of the
 penalty results.

+## LiblineaR
+
+```{r liblinear-reg}
+logistic_reg() %>%
+  set_engine("LiblineaR") %>%
+  set_mode("classification") %>%
+  translate()
+```
+
+For `LiblineaR` models, the value for `mixture` can either be 0 (for ridge) or 1
+(for lasso) but not other intermediate values. In the `LiblineaR` documentation,
+these correspond to types 0 (L2-regularized) and 6 (L1-regularized).
+
+Be aware that the `LiblineaR` engine regularizes the intercept. Other
+regularized regression models do not, which will result in different parameter estimates.
+
 ## stan

 ```{r stan-reg}
@@ -81,11 +97,11 @@ get_defaults_logistic_reg <- function() {
   tibble::tribble(
     ~model, ~engine, ~parsnip, ~original, ~default,
     "logistic_reg", "glmnet",    "mixture", "alpha",             get_arg("glmnet", "glmnet", "alpha"),
+    "logistic_reg", "LiblineaR", "mixture", "type",              "0",
     "logistic_reg", "spark",     "penalty", "reg_param",         get_arg("sparklyr", "ml_logistic_regression", "reg_param"),
     "logistic_reg", "spark",     "mixture", "elastic_net_param", get_arg("sparklyr", "ml_logistic_regression", "elastic_net_param"),
     "logistic_reg", "keras",     "penalty", "penalty",           get_arg("parsnip", "keras_mlp", "penalty"),
   )
 }
 convert_args("logistic_reg")
 ```
-
