Skip to content

Commit 86330e6

Browse files
committed
multinom_reg add factor lvls
1 parent 6bed856 commit 86330e6

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

R/multinom_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ multi_predict._multnet <-
290290
pred <-
291291
tibble(
292292
.row = rep(1:nrow(new_data), length(penalty)),
293-
.pred_class = factor(as.vector(pred)),
293+
.pred_class = factor(as.vector(pred), levels = object$lvl),
294294
penalty = rep(penalty, each = nrow(new_data))
295295
)
296296
}

tests/testthat/test_multinom_reg_glmnet.R

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ test_that('glmnet probabilities, mulitiple lambda', {
121121
multi_predict(xy_fit, iris[rows, 1:4], penalty = lams, type = "prob")$.pred
122122
)
123123

124-
mult_class <- factor(names(mult_probs)[apply(mult_probs, 1, which.max)])
124+
mult_class <- factor(names(mult_probs)[apply(mult_probs, 1, which.max)],
125+
levels = xy_fit$lvl)
125126
mult_class <- tibble(
126127
.pred_class = mult_class,
127128
penalty = rep(lams, each = 3),

0 commit comments

Comments
 (0)