Skip to content

Commit 04eb1b0

Browse files
committed
TST only test nnet_softmax
1 parent 94dbc2a commit 04eb1b0

File tree

3 files changed

+9
-11
lines changed

3 files changed

+9
-11
lines changed

.travis.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ before_install:
5252
- sudo apt-get -y install python3
5353
- mkdir -p ~/.R && echo "CXX14=g++-6" > ~/.R/Makevars
5454
- echo "CXX14FLAGS += -fPIC" >> ~/.R/Makevars
55-
- Rscript -e 'install.packages("nnet")'
5655

5756

5857
after_success:

tests/testthat/test_mlp.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,11 @@ test_that('bad input', {
157157
expect_error(translate(mlp(mode = "regression", formula = y ~ x) %>% set_engine()))
158158
})
159159

160+
test_that("nnet_softmax", {
161+
obj <- mlp(mode = 'classification')
162+
obj$lvls <- c("a", "b")
163+
res <- nnet_softmax(matrix(c(.8, .2)), obj)
164+
expect_equal(names(res), obj$lvls)
165+
expect_equal(res$b, 1 - res$a)
166+
})
167+

tests/testthat/test_predict_formats.R

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,6 @@ lr_fit_2 <-
2828
set_engine("glm") %>%
2929
fit(Ozone ~ ., data = class_dat2)
3030

31-
lr_fit_3 <-
32-
mlp(mode = 'classification') %>%
33-
set_engine("nnet") %>%
34-
fit(Ozone ~ ., data = class_dat2[1:5, ])
35-
36-
3731
# ------------------------------------------------------------------------------
3832

3933
test_that('regression predictions', {
@@ -60,11 +54,8 @@ test_that('non-standard levels', {
6054

6155
expect_true(is_tibble(predict(lr_fit_2, new_data = class_dat2[1:5,-1], type = "prob")))
6256
expect_true(is_tibble(parsnip:::predict_classprob.model_fit(lr_fit_2, new_data = class_dat2[1:5,-1])))
63-
final_colnames <- c(".pred_2low", ".pred_high+values")
6457
expect_equal(names(predict(lr_fit_2, new_data = class_dat2[1:5,-1], type = "prob")),
65-
final_colnames)
66-
expect_equal(names(predict(lr_fit_3, new_data = class_dat2, type = 'prob')),
67-
final_colnames)
58+
c(".pred_2low", ".pred_high+values"))
6859
expect_equal(names(parsnip:::predict_classprob.model_fit(lr_fit_2, new_data = class_dat2[1:5,-1])),
6960
c("2low", "high+values"))
7061
})

0 commit comments

Comments
 (0)