Skip to content

Commit 6d6b0a7

Browse files
authored
Merge pull request #487 from tidymodels/try-catch-augment
`augment()` for models without class probabilities
2 parents f364b81 + 426fd3d commit 6d6b0a7

File tree

4 files changed

+38
-10
lines changed

4 files changed

+38
-10
lines changed

R/aaa_models.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,13 @@ check_pred_info <- function(pred_obj, type) {
286286
invisible(NULL)
287287
}
288288

289-
check_spec_pred_type <- function(object, type) {
289+
spec_has_pred_type <- function(object, type) {
290290
possible_preds <- names(object$spec$method$pred)
291-
if (!any(possible_preds == type)) {
291+
any(possible_preds == type)
292+
}
293+
check_spec_pred_type <- function(object, type) {
294+
if (!spec_has_pred_type(object, type)) {
295+
possible_preds <- names(object$spec$method$pred)
292296
rlang::abort(c(
293297
glue::glue("No {type} prediction method available for this model."),
294298
glue::glue("Value for `type` should be one of: ",

R/augment.R

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
#' [fit()] and `new_data` contains the outcome column, a `.resid` column is
77
#' also added.
88
#'
9-
#' For classification models, the results include a column called `.pred_class`
10-
#' as well as class probability columns named `.pred_{level}`.
9+
#' For classification models, the results can include a column called
10+
#' `.pred_class` as well as class probability columns named `.pred_{level}`.
11+
#' This depends on what type of prediction types are available for the model.
1112
#' @param x A `model_fit` object produced by [fit()] or [fit_xy()].
1213
#' @param new_data A data frame or matrix.
1314
#' @param ... Not currently used.
@@ -56,6 +57,7 @@
5657
#'
5758
augment.model_fit <- function(x, new_data, ...) {
5859
if (x$spec$mode == "regression") {
60+
check_spec_pred_type(x, "numeric")
5961
new_data <-
6062
new_data %>%
6163
dplyr::bind_cols(
@@ -68,12 +70,18 @@ augment.model_fit <- function(x, new_data, ...) {
6870
}
6971
}
7072
} else if (x$spec$mode == "classification") {
71-
new_data <-
72-
new_data %>%
73-
dplyr::bind_cols(
74-
predict(x, new_data = new_data, type = "class"),
73+
if (spec_has_pred_type(x, "class")) {
74+
new_data <- dplyr::bind_cols(
75+
new_data,
76+
predict(x, new_data = new_data, type = "class")
77+
)
78+
}
79+
if (spec_has_pred_type(x, "prob")) {
80+
new_data <- dplyr::bind_cols(
81+
new_data,
7582
predict(x, new_data = new_data, type = "prob")
7683
)
84+
}
7785
} else {
7886
rlang::abort(paste("Unknown mode:", x$spec$mode))
7987
}

man/augment.Rd

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-augment.R

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,18 @@ test_that('classification models', {
7676

7777
})
7878

79+
80+
test_that('augment for model without class probabilities', {
81+
skip_if_not_installed("LiblineaR")
82+
83+
data(two_class_dat, package = "modeldata")
84+
x <- svm_linear(mode = "classification") %>% set_engine("LiblineaR")
85+
cls_form <- x %>% fit(Class ~ ., data = two_class_dat)
86+
87+
expect_equal(
88+
colnames(augment(cls_form, head(two_class_dat))),
89+
c("A", "B", "Class", ".pred_class")
90+
)
91+
expect_equal(nrow(augment(cls_form, head(two_class_dat))), 6)
92+
93+
})

0 commit comments

Comments
 (0)