Skip to content

Commit e8c9277

Browse files
committed
FIX: multinom_reg now outputs factor predictions consistently
1 parent e039f07 commit e8c9277

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

R/multinom_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ multi_predict._multnet <-
290290
pred <-
291291
tibble(
292292
.row = rep(1:nrow(new_data), length(penalty)),
293-
.pred_class = as.vector(pred),
293+
.pred_class = factor(as.vector(pred)),
294294
penalty = rep(penalty, each = nrow(new_data))
295295
)
296296
}

tests/testthat/test_multinom_reg_glmnet.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ test_that('glmnet probabilities, mulitiple lambda', {
121121
multi_predict(xy_fit, iris[rows, 1:4], penalty = lams, type = "prob")$.pred
122122
)
123123

124-
mult_class <- names(mult_probs)[apply(mult_probs, 1, which.max)]
124+
mult_class <- factor(names(mult_probs)[apply(mult_probs, 1, which.max)])
125125
mult_class <- tibble(
126126
.pred_class = mult_class,
127127
penalty = rep(lams, each = 3),

0 commit comments

Comments
 (0)