Skip to content

Commit 426fd3d

Browse files
committed
modified checking method for correct pred type
1 parent cc3f892 commit 426fd3d

File tree

3 files changed

+15
-16
lines changed

3 files changed

+15
-16
lines changed

R/aaa_models.R

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -286,9 +286,13 @@ check_pred_info <- function(pred_obj, type) {
286286
invisible(NULL)
287287
}
288288

289-
check_spec_pred_type <- function(object, type) {
289+
spec_has_pred_type <- function(object, type) {
290290
possible_preds <- names(object$spec$method$pred)
291-
if (!any(possible_preds == type)) {
291+
any(possible_preds == type)
292+
}
293+
check_spec_pred_type <- function(object, type) {
294+
if (!spec_has_pred_type(object, type)) {
295+
possible_preds <- names(object$spec$method$pred)
292296
rlang::abort(c(
293297
glue::glue("No {type} prediction method available for this model."),
294298
glue::glue("Value for `type` should be one of: ",

R/augment.R

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
#' [fit()] and `new_data` contains the outcome column, a `.resid` column is
77
#' also added.
88
#'
9-
#' For classification models, the results include a column called `.pred_class`
10-
#' as well as class probability columns named `.pred_{level}`.
9+
#' For classification models, the results can include a column called
10+
#' `.pred_class` as well as class probability columns named `.pred_{level}`.
11+
#' This depends on what type of prediction types are available for the model.
1112
#' @param x A `model_fit` object produced by [fit()] or [fit_xy()].
1213
#' @param new_data A data frame or matrix.
1314
#' @param ... Not currently used.
@@ -56,6 +57,7 @@
5657
#'
5758
augment.model_fit <- function(x, new_data, ...) {
5859
if (x$spec$mode == "regression") {
60+
check_spec_pred_type(x, "numeric")
5961
new_data <-
6062
new_data %>%
6163
dplyr::bind_cols(
@@ -68,13 +70,13 @@ augment.model_fit <- function(x, new_data, ...) {
6870
}
6971
}
7072
} else if (x$spec$mode == "classification") {
71-
if (has_class_preds(x)) {
73+
if (spec_has_pred_type(x, "class")) {
7274
new_data <- dplyr::bind_cols(
7375
new_data,
7476
predict(x, new_data = new_data, type = "class")
7577
)
7678
}
77-
if (has_class_probs(x)) {
79+
if (spec_has_pred_type(x, "prob")) {
7880
new_data <- dplyr::bind_cols(
7981
new_data,
8082
predict(x, new_data = new_data, type = "prob")
@@ -85,11 +87,3 @@ augment.model_fit <- function(x, new_data, ...) {
8587
}
8688
as_tibble(new_data)
8789
}
88-
89-
has_class_preds <- function(x) {
90-
any(names(x$spec$method$pred) == "class")
91-
}
92-
93-
has_class_probs <- function(x) {
94-
any(names(x$spec$method$pred) == "prob")
95-
}

man/augment.Rd

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)