Skip to content

Commit 976567f

Browse files
authored
Merge pull request #226 from patr1ckm/multinom_reg_pred
Standardized output types for `multi_predict` and `predict` in `multinom_reg`
2 parents ae66258 + ef56340 commit 976567f

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

NEWS.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,16 @@ parsnip model object, and is printed when the model object is printed.
88
* Some default parameter ranges were updated for SVM, KNN, and MARS models.
99

1010
## Fixes
11+
* [A bug](https://github.com/tidymodels/parsnip/issues/222) was fixed standardizing
12+
the output column types of `multi_predict` and `predict` for `multinom_reg`.
1113

1214
* [A bug](https://github.com/tidymodels/parsnip/issues/208) was fixed related to using data descriptors and `fit_xy()`.
1315

1416
* A bug was fixed related to the column names generated by `multi_predict()`. The top-level tibble will always have a column named `.pred` and this list column contains tibbles across sub-models. The column names for these sub-model tibbles will have names consistent with `predict()` (which was previously incorrect). See [43c15db](https://github.com/tidymodels/parsnip/commit/43c15db377ea9ef27483ff209f6bd0e98cb830d2).
1517

1618
* The model `udpate()` methods gained a `parameters` argument for cases when the parameters are contained in a tibble or list.
1719

18-
# [A bug](https://github.com/tidymodels/parsnip/issues/174) was fixed standardizing the column names of `nnet` class probability predictions.
20+
* [A bug](https://github.com/tidymodels/parsnip/issues/174) was fixed standardizing the column names of `nnet` class probability predictions.
1921

2022

2123
# parsnip 0.0.3.1

R/multinom_reg.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ multi_predict._multnet <-
297297
pred <-
298298
tibble(
299299
.row = rep(1:nrow(new_data), length(penalty)),
300-
.pred_class = as.vector(pred),
300+
.pred_class = factor(as.vector(pred), levels = object$lvl),
301301
penalty = rep(penalty, each = nrow(new_data))
302302
)
303303
}

tests/testthat/test_multinom_reg_glmnet.R

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ test_that('glmnet probabilities, mulitiple lambda', {
121121
multi_predict(xy_fit, iris[rows, 1:4], penalty = lams, type = "prob")$.pred
122122
)
123123

124-
mult_class <- names(mult_probs)[apply(mult_probs, 1, which.max)]
124+
mult_class <- factor(names(mult_probs)[apply(mult_probs, 1, which.max)],
125+
levels = xy_fit$lvl)
125126
mult_class <- tibble(
126127
.pred_class = mult_class,
127128
penalty = rep(lams, each = 3),
@@ -149,3 +150,12 @@ test_that('glmnet probabilities, mulitiple lambda', {
149150
)
150151

151152
})
153+
154+
test_that("class predictions are factors with all levels", {
155+
basic <- multinom_reg() %>% set_engine("glmnet") %>% fit(Species ~ ., data = iris)
156+
nd <- iris[iris$Species == "setosa", ]
157+
yhat <- predict(basic, new_data = nd, penalty = .1)
158+
yhat_multi <- multi_predict(basic, new_data = nd, penalty = .1)$.pred
159+
expect_is(yhat_multi[[1]]$.pred_class, "factor")
160+
expect_equal(levels(yhat_multi[[1]]$.pred_class), levels(iris$Species))
161+
})

0 commit comments

Comments
 (0)