Skip to content

Commit 2e37391

Browse files
committed
added fix for probability type predictions for two class GAM models
1 parent 80ebaa8 commit 2e37391

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

R/gen_additive_mod_data.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ set_pred(
166166
value = list(
167167
pre = NULL,
168168
post = function(x, object) {
169+
if (is.array(x)) {
170+
x <- as.vector(x)
171+
}
169172
x <- tibble(v1 = 1 - x, v2 = x)
170173
colnames(x) <- object$lvl
171174
x

tests/testthat/test_gen_additive_model.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ test_that('classification', {
8383
mgcv_pred <- predict(mgcv_mod, head(two_class_dat), type = "response")
8484
expect_equal(names(f_pred), c(".pred_Class1", ".pred_Class2"))
8585
expect_equal(f_pred[[".pred_Class2"]], mgcv_pred, ignore_attr = TRUE)
86+
expect_equal(class(f_pred[[".pred_Class1"]]), "numeric")
87+
expect_equal(class(f_pred[[".pred_Class2"]]), "numeric")
8688

8789
f_cls <- predict(f_res, head(two_class_dat), type = "class")
8890
expect_true(all(f_cls$.pred_class[mgcv_pred < 0.5] == "Class1"))

0 commit comments

Comments
 (0)