Skip to content

Commit 1d668ff

Browse files
committed
TST that predictions from predict and multi_predict have same format and all levels
1 parent 86330e6 commit 1d668ff

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

tests/testthat/test_multinom_reg.R

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,14 @@ test_that('bad input', {
104104
expect_error(translate(multinom_reg() %>% set_engine()))
105105
expect_warning(translate(multinom_reg() %>% set_engine("glmnet", x = iris[,1:3], y = iris$Species)))
106106
})
107+
108+
test_that("predictions are factors with all levels", {
109+
basic <- multinom_reg() %>% set_engine("glmnet") %>% fit(Species ~ ., data = iris)
110+
nd <- iris[iris$Species == "setosa", ]
111+
yhat <- predict(basic, new_data = nd, penalty = .1)
112+
expect_is(yhat$.pred_class, "factor")
113+
expect_equal(levels(yhat$.pred_class), levels(iris$Species))
114+
yhat_multi <- multi_predict(basic, new_data = nd, penalty = .1)$.pred
115+
expect_is(yhat_multi[[1]]$.pred_class, "factor")
116+
expect_equal(levels(yhat_multi[[1]]$.pred_class), levels(iris$Species))
117+
})

0 commit comments

Comments
 (0)