Skip to content

Commit 53722db

Browse files
authored
Merge pull request #382 from tidymodels/glmnet-column-fix
Reorder columns at prediction time for glmnet
2 parents 0c23061 + 3419f0e commit 53722db

File tree

3 files changed

+8
-8
lines changed

3 files changed

+8
-8
lines changed

R/linear_reg_data.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ set_pred(
168168
args =
169169
list(
170170
object = expr(object$fit),
171-
newx = expr(as.matrix(new_data)),
171+
newx = expr(as.matrix(new_data[, rownames(object$fit$beta)])),
172172
type = "response",
173173
s = expr(object$spec$args$penalty)
174174
)

R/logistic_reg_data.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,10 +186,10 @@ set_pred(
186186
func = c(fun = "predict"),
187187
args =
188188
list(
189-
object = quote(object$fit),
190-
newx = quote(as.matrix(new_data)),
189+
object = expr(object$fit),
190+
newx = expr(as.matrix(new_data[, rownames(object$fit$beta)])),
191191
type = "response",
192-
s = quote(object$spec$args$penalty)
192+
s = expr(object$spec$args$penalty)
193193
)
194194
)
195195
)

R/multinom_reg_data.R

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ set_pred(
6161
args =
6262
list(
6363
object = quote(object$fit),
64-
newx = quote(as.matrix(new_data)),
64+
newx = quote(as.matrix(new_data[, rownames(object$fit$beta[[1]])])),
6565
type = "class",
6666
s = quote(object$spec$args$penalty)
6767
)
@@ -79,10 +79,10 @@ set_pred(
7979
func = c(fun = "predict"),
8080
args =
8181
list(
82-
object = quote(object$fit),
83-
newx = quote(as.matrix(new_data)),
82+
object = expr(object$fit),
83+
newx = expr(as.matrix(new_data[, rownames(object$fit$beta[[1]])])),
8484
type = "response",
85-
s = quote(object$spec$args$penalty)
85+
s = expr(object$spec$args$penalty)
8686
)
8787
)
8888
)

0 commit comments

Comments
 (0)