Skip to content

Commit aba6b1a

Browse files
Merge pull request #709 from oj713/main
Fixing inconsistency for probability predictions for two class GAM models
2 parents 80ebaa8 + fcfea50 commit aba6b1a

File tree

3 files changed

+9
-0
lines changed

3 files changed

+9
-0
lines changed

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66

77
* Exported `xgb_predict()` which wraps xgboost's `predict()` method for use with parsnip extension packages (#688).
88

9+
## Bug fixes
10+
11+
* An inconsistency for probability type predictions for two-class GAM models was fixed (#708)
12+
913
# parsnip 0.2.1
1014

1115
* Fixed a major bug in spark models induced in the previous version (#671).

R/gen_additive_mod_data.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ set_pred(
166166
value = list(
167167
pre = NULL,
168168
post = function(x, object) {
169+
if (is.array(x)) {
170+
x <- as.vector(x)
171+
}
169172
x <- tibble(v1 = 1 - x, v2 = x)
170173
colnames(x) <- object$lvl
171174
x

tests/testthat/test_gen_additive_model.R

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,8 @@ test_that('classification', {
8383
mgcv_pred <- predict(mgcv_mod, head(two_class_dat), type = "response")
8484
expect_equal(names(f_pred), c(".pred_Class1", ".pred_Class2"))
8585
expect_equal(f_pred[[".pred_Class2"]], mgcv_pred, ignore_attr = TRUE)
86+
expect_equal(class(f_pred[[".pred_Class1"]]), "numeric")
87+
expect_equal(class(f_pred[[".pred_Class2"]]), "numeric")
8688

8789
f_cls <- predict(f_res, head(two_class_dat), type = "class")
8890
expect_true(all(f_cls$.pred_class[mgcv_pred < 0.5] == "Class1"))

0 commit comments

Comments
 (0)