Skip to content

Commit b2b054d

Browse files
committed
Use helper functions to check whether model supports class / prob predictions
1 parent 1e525bb commit b2b054d

File tree

1 file changed

+17
-8
lines changed

1 file changed

+17
-8
lines changed

R/augment.R

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,19 +68,28 @@ augment.model_fit <- function(x, new_data, ...) {
6868
}
6969
}
7070
} else if (x$spec$mode == "classification") {
71-
new_data <- dplyr::bind_cols(
72-
new_data,
73-
predict(x, new_data = new_data, type = "class")
74-
)
75-
tryCatch(
71+
if (has_class_preds(x)) {
72+
new_data <- dplyr::bind_cols(
73+
new_data,
74+
predict(x, new_data = new_data, type = "class")
75+
)
76+
}
77+
if (has_class_probs(x)) {
7678
new_data <- dplyr::bind_cols(
7779
new_data,
7880
predict(x, new_data = new_data, type = "prob")
79-
),
80-
error = function(cnd) cnd
81-
)
81+
)
82+
}
8283
} else {
8384
rlang::abort(paste("Unknown mode:", x$spec$mode))
8485
}
8586
as_tibble(new_data)
8687
}
88+
89+
has_class_preds <- function(x) {
90+
any(names(x$spec$method$pred) == "class")
91+
}
92+
93+
has_class_probs <- function(x) {
94+
any(names(x$spec$method$pred) == "prob")
95+
}

0 commit comments

Comments
 (0)