Skip to content

Commit 6cb9225

Browse files
Consolidate glmnet predict methods (#868)
* consolidate glmnet predict methods * Update R/glmnet.R Co-authored-by: Emil Hvitfeldt <[email protected]> * add suggestions from code review --------- Co-authored-by: Emil Hvitfeldt <[email protected]>
1 parent ff20477 commit 6cb9225

File tree

4 files changed

+168
-278
lines changed

4 files changed

+168
-278
lines changed

R/glmnet.R

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# glmnet call stack using `predict()` when object has
2+
# classes "_<glmnet-class>" and "model_fit":
3+
#
4+
# predict()
5+
# predict._<glmnet-class>(penalty = NULL)
6+
# predict_glmnet(penalty = NULL) <-- checks and sets penalty
7+
# predict.model_fit() <-- checks for extra vars in ...
8+
# predict_numeric()
9+
# predict_numeric._<glmnet-class>()
10+
# predict_numeric_glmnet()
11+
# predict_numeric.model_fit()
12+
# predict.<glmnet-class>()
13+
14+
15+
# glmnet call stack using `multi_predict` when object has
16+
# classes "_<glmnet-class>" and "model_fit":
17+
#
18+
# multi_predict()
19+
# multi_predict._<glmnet-class>(penalty = NULL)
20+
# predict._<glmnet-class>(multi = TRUE)
21+
# predict_glmnet(multi = TRUE) <-- checks and sets penalty
22+
# predict.model_fit() <-- checks for extra vars in ...
23+
# predict_raw()
24+
# predict_raw._<glmnet-class>()
25+
# predict_raw_glmnet()
26+
# predict_raw.model_fit(opts = list(s = penalty))
27+
# predict.<glmnet-class>()
28+
29+
30+
predict_glmnet <- function(object,
31+
new_data,
32+
type = NULL,
33+
opts = list(),
34+
penalty = NULL,
35+
multi = FALSE,
36+
...) {
37+
38+
if (any(names(enquos(...)) == "newdata")) {
39+
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
40+
}
41+
42+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
43+
if (is.null(penalty) & !is.null(object$spec$args$penalty)) {
44+
penalty <- object$spec$args$penalty
45+
}
46+
47+
object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi)
48+
49+
object$spec <- eval_args(object$spec)
50+
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
51+
}
52+
53+
predict_numeric_glmnet <- function(object, new_data, ...) {
54+
if (any(names(enquos(...)) == "newdata")) {
55+
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
56+
}
57+
58+
object$spec <- eval_args(object$spec)
59+
predict_numeric.model_fit(object, new_data = new_data, ...)
60+
}
61+
62+
predict_class_glmnet <- function(object, new_data, ...) {
63+
if (any(names(enquos(...)) == "newdata")) {
64+
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
65+
}
66+
67+
object$spec <- eval_args(object$spec)
68+
predict_class.model_fit(object, new_data = new_data, ...)
69+
}
70+
71+
predict_classprob_glmnet <- function(object, new_data, ...) {
72+
if (any(names(enquos(...)) == "newdata")) {
73+
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
74+
}
75+
76+
object$spec <- eval_args(object$spec)
77+
predict_classprob.model_fit(object, new_data = new_data, ...)
78+
}
79+
80+
predict_raw_glmnet <- function(object, new_data, opts = list(), ...) {
81+
if (any(names(enquos(...)) == "newdata")) {
82+
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
83+
}
84+
85+
object$spec <- eval_args(object$spec)
86+
87+
opts$s <- object$spec$args$penalty
88+
89+
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
90+
}
91+
92+
multi_predict_glmnet <- function(object,
93+
new_data,
94+
type = NULL,
95+
penalty = NULL,
96+
...) {
97+
98+
if (any(names(enquos(...)) == "newdata")) {
99+
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
100+
}
101+
102+
if (object$spec$mode == "classification") {
103+
if (is_quosure(penalty)) {
104+
penalty <- eval_tidy(penalty)
105+
}
106+
}
107+
108+
dots <- list(...)
109+
110+
object$spec <- eval_args(object$spec)
111+
112+
if (is.null(penalty)) {
113+
# See discussion in https://github.com/tidymodels/parsnip/issues/195
114+
if (!is.null(object$spec$args$penalty)) {
115+
penalty <- object$spec$args$penalty
116+
} else {
117+
penalty <- object$fit$lambda
118+
}
119+
}
120+
121+
if (object$spec$mode == "classification") {
122+
if (is.null(type)) {
123+
type <- "class"
124+
}
125+
if (!(type %in% c("class", "prob", "link", "raw"))) {
126+
rlang::abort("`type` should be either 'class', 'link', 'raw', or 'prob'.")
127+
}
128+
if (type == "prob") {
129+
dots$type <- "response"
130+
} else {
131+
dots$type <- type
132+
}
133+
}
134+
135+
pred <- predict(object, new_data = new_data, type = "raw",
136+
opts = dots, penalty = penalty, multi = TRUE)
137+
138+
model_type <- class(object$spec)[1]
139+
res <- switch(
140+
model_type,
141+
"linear_reg" = format_glmnet_multi_linear_reg(pred, penalty = penalty),
142+
"logistic_reg" = format_glmnet_multi_logistic_reg(pred,
143+
penalty = penalty,
144+
type = dots$type,
145+
lvl = object$lvl),
146+
"multinom_reg" = format_glmnet_multi_multinom_reg(pred,
147+
penalty = penalty,
148+
type = type,
149+
n_rows = nrow(new_data),
150+
lvl = object$lvl)
151+
)
152+
153+
res
154+
}

R/linear_reg.R

Lines changed: 4 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -156,94 +156,19 @@ check_args.linear_reg <- function(object) {
156156
res
157157
}
158158

159-
# ------------------------------------------------------------------------------
160-
# glmnet call stack for linear regression using `predict` when object has
161-
# classes "_elnet" and "model_fit":
162-
#
163-
# predict()
164-
# predict._elnet(penalty = NULL) <-- checks and sets penalty
165-
# predict.model_fit() <-- checks for extra vars in ...
166-
# predict_numeric()
167-
# predict_numeric._elnet()
168-
# predict_numeric.model_fit()
169-
# predict.elnet()
170-
171-
172-
# glmnet call stack for linear regression using `multi_predict` when object has
173-
# classes "_elnet" and "model_fit":
174-
#
175-
# multi_predict()
176-
# multi_predict._elnet(penalty = NULL)
177-
# predict._elnet(multi = TRUE) <-- checks and sets penalty
178-
# predict.model_fit() <-- checks for extra vars in ...
179-
# predict_raw()
180-
# predict_raw._elnet()
181-
# predict_raw.model_fit(opts = list(s = penalty))
182-
# predict.elnet()
183-
184-
185159
#' @export
186-
predict._elnet <-
187-
function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
188-
if (any(names(enquos(...)) == "newdata"))
189-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
190-
191-
# See discussion in https://github.com/tidymodels/parsnip/issues/195
192-
if (is.null(penalty) & !is.null(object$spec$args$penalty)) {
193-
penalty <- object$spec$args$penalty
194-
}
195-
196-
object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi)
197-
198-
object$spec <- eval_args(object$spec)
199-
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
200-
}
160+
predict._elnet <- predict_glmnet
201161

202162
#' @export
203-
predict_numeric._elnet <- function(object, new_data, ...) {
204-
if (any(names(enquos(...)) == "newdata"))
205-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
206-
207-
object$spec <- eval_args(object$spec)
208-
predict_numeric.model_fit(object, new_data = new_data, ...)
209-
}
163+
predict_numeric._elnet <- predict_numeric_glmnet
210164

211165
#' @export
212-
predict_raw._elnet <- function(object, new_data, opts = list(), ...) {
213-
if (any(names(enquos(...)) == "newdata"))
214-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
215-
216-
object$spec <- eval_args(object$spec)
217-
opts$s <- object$spec$args$penalty
218-
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
219-
}
166+
predict_raw._elnet <- predict_raw_glmnet
220167

221168
#' @export
222169
#'@rdname multi_predict
223170
#' @param penalty A numeric vector of penalty values.
224-
multi_predict._elnet <-
225-
function(object, new_data, type = NULL, penalty = NULL, ...) {
226-
if (any(names(enquos(...)) == "newdata"))
227-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
228-
229-
dots <- list(...)
230-
231-
object$spec <- eval_args(object$spec)
232-
233-
if (is.null(penalty)) {
234-
# See discussion in https://github.com/tidymodels/parsnip/issues/195
235-
if (!is.null(object$spec$args$penalty)) {
236-
penalty <- object$spec$args$penalty
237-
} else {
238-
penalty <- object$fit$lambda
239-
}
240-
}
241-
242-
pred <- predict._elnet(object, new_data = new_data, type = "raw",
243-
opts = dots, penalty = penalty, multi = TRUE)
244-
245-
format_glmnet_multi_linear_reg(pred, penalty = penalty)
246-
}
171+
multi_predict._elnet <- multi_predict_glmnet
247172

248173
format_glmnet_multi_linear_reg <- function(pred, penalty) {
249174
param_key <- tibble(group = colnames(pred), penalty = penalty)

R/logistic_reg.R

Lines changed: 5 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -206,92 +206,14 @@ organize_glmnet_prob <- function(x, object) {
206206
res
207207
}
208208

209-
# ------------------------------------------------------------------------------
210-
# glmnet call stack for logistic regression using `predict` when object has
211-
# classes "_lognet" and "model_fit" (for class predictions):
212-
#
213-
# predict()
214-
# predict._lognet(penalty = NULL) <-- checks and sets penalty
215-
# predict.model_fit() <-- checks for extra vars in ...
216-
# predict_class()
217-
# predict_class._lognet()
218-
# predict_class.model_fit()
219-
# predict.lognet()
220-
221-
222-
# glmnet call stack for logistic regression using `multi_predict` when object has
223-
# classes "_lognet" and "model_fit" (for class predictions):
224-
#
225-
# multi_predict()
226-
# multi_predict._lognet(penalty = NULL)
227-
# predict._lognet(multi = TRUE) <-- checks and sets penalty
228-
# predict.model_fit() <-- checks for extra vars in ...
229-
# predict_raw()
230-
# predict_raw._lognet()
231-
# predict_raw.model_fit(opts = list(s = penalty))
232-
# predict.lognet()
233-
234209
# ------------------------------------------------------------------------------
235210

236211
#' @export
237-
predict._lognet <- function(object, new_data, type = NULL, opts = list(), penalty = NULL, multi = FALSE, ...) {
238-
if (any(names(enquos(...)) == "newdata"))
239-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
240-
241-
# See discussion in https://github.com/tidymodels/parsnip/issues/195
242-
if (is.null(penalty) & !is.null(object$spec$args$penalty)) {
243-
penalty <- object$spec$args$penalty
244-
}
245-
246-
object$spec$args$penalty <- .check_glmnet_penalty_predict(penalty, object, multi)
247-
248-
object$spec <- eval_args(object$spec)
249-
predict.model_fit(object, new_data = new_data, type = type, opts = opts, ...)
250-
}
251-
212+
predict._lognet <- predict_glmnet
252213

253214
#' @export
254215
#' @rdname multi_predict
255-
multi_predict._lognet <-
256-
function(object, new_data, type = NULL, penalty = NULL, ...) {
257-
if (any(names(enquos(...)) == "newdata"))
258-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
259-
260-
if (is_quosure(penalty))
261-
penalty <- eval_tidy(penalty)
262-
263-
dots <- list(...)
264-
265-
if (is.null(penalty)) {
266-
# See discussion in https://github.com/tidymodels/parsnip/issues/195
267-
if (!is.null(object$spec$args$penalty)) {
268-
penalty <- object$spec$args$penalty
269-
} else {
270-
penalty <- object$fit$lambda
271-
}
272-
}
273-
274-
if (is.null(type))
275-
type <- "class"
276-
if (!(type %in% c("class", "prob", "link", "raw"))) {
277-
rlang::abort("`type` should be either 'class', 'link', 'raw', or 'prob'.")
278-
}
279-
if (type == "prob")
280-
dots$type <- "response"
281-
else
282-
dots$type <- type
283-
284-
object$spec <- eval_args(object$spec)
285-
pred <- predict._lognet(object, new_data = new_data, type = "raw",
286-
opts = dots, penalty = penalty, multi = TRUE)
287-
288-
format_glmnet_multi_logistic_reg(
289-
pred,
290-
penalty,
291-
type = dots$type,
292-
lvl = object$lvl
293-
)
294-
}
216+
multi_predict._lognet <- multi_predict_glmnet
295217

296218
format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) {
297219
param_key <- tibble(group = colnames(pred), penalty = penalty)
@@ -324,32 +246,13 @@ format_glmnet_multi_logistic_reg <- function(pred, penalty, type, lvl) {
324246

325247

326248
#' @export
327-
predict_class._lognet <- function(object, new_data, ...) {
328-
if (any(names(enquos(...)) == "newdata"))
329-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
330-
331-
object$spec <- eval_args(object$spec)
332-
predict_class.model_fit(object, new_data = new_data, ...)
333-
}
249+
predict_class._lognet <- predict_class_glmnet
334250

335251
#' @export
336-
predict_classprob._lognet <- function(object, new_data, ...) {
337-
if (any(names(enquos(...)) == "newdata"))
338-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
339-
340-
object$spec <- eval_args(object$spec)
341-
predict_classprob.model_fit(object, new_data = new_data, ...)
342-
}
252+
predict_classprob._lognet <- predict_classprob_glmnet
343253

344254
#' @export
345-
predict_raw._lognet <- function(object, new_data, opts = list(), ...) {
346-
if (any(names(enquos(...)) == "newdata"))
347-
rlang::abort("Did you mean to use `new_data` instead of `newdata`?")
348-
349-
object$spec <- eval_args(object$spec)
350-
opts$s <- object$spec$args$penalty
351-
predict_raw.model_fit(object, new_data = new_data, opts = opts, ...)
352-
}
255+
predict_raw._lognet <- predict_raw_glmnet
353256

354257
# ------------------------------------------------------------------------------
355258

0 commit comments

Comments
 (0)