Skip to content

Commit 3fef9c3

Browse files
authored
Merge pull request #225 from patr1ckm/nnet-pred
Standardize nnet class prediction column names
2 parents e039f07 + 04eb1b0 commit 3fef9c3

File tree

3 files changed

+12
-1
lines changed

3 files changed

+12
-1
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66

77
* A bug was fixed related to the column names generated by `multi_predict()`. The top-level tibble will always have a column named `.pred` and this list column contains tibbles across sub-models. The column names for these sub-model tibbles will have names consistent with `predict()` (which was previously incorrect). See [43c15db](https://github.com/tidymodels/parsnip/commit/43c15db377ea9ef27483ff209f6bd0e98cb830d2).
88

9+
# [A bug](https://github.com/tidymodels/parsnip/issues/174) was fixed
10+
standardizing the column names of `nnet` class probability predictions.
11+
912
# parsnip 0.0.3.1
1013

1114
Test case update due to CRAN running extra tests [(#202)](https://github.com/tidymodels/parsnip/issues/202)

R/mlp.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ nnet_softmax <- function(results, object) {
381381

382382
results <- apply(results, 1, function(x) exp(x)/sum(exp(x)))
383383
results <- t(results)
384-
names(results) <- paste0(".pred_", object$lvl)
384+
colnames(results) <- object$lvl
385385
results <- as_tibble(results)
386386
results
387387
}

tests/testthat/test_mlp.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,11 @@ test_that('bad input', {
157157
expect_error(translate(mlp(mode = "regression", formula = y ~ x) %>% set_engine()))
158158
})
159159

160+
test_that("nnet_softmax", {
161+
obj <- mlp(mode = 'classification')
162+
obj$lvls <- c("a", "b")
163+
res <- nnet_softmax(matrix(c(.8, .2)), obj)
164+
expect_equal(names(res), obj$lvls)
165+
expect_equal(res$b, 1 - res$a)
166+
})
167+

0 commit comments

Comments
 (0)