Skip to content

Commit 846bcea

Browse files
committed
Check the pred type in unified way across types
1 parent 408bedc commit 846bcea

10 files changed

+12
-22
lines changed

R/predict_class.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ predict_class.model_fit <- function(object, new_data, ...) {
1212
if (object$spec$mode != "classification")
1313
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
1414

15-
if (!any(names(object$spec$method$pred) == "class"))
16-
rlang::abort("No class prediction module defined for this model.")
15+
check_spec_pred_type(object, "class")
1716

1817
if (inherits(object$fit, "try-error")) {
1918
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_classprob.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ predict_classprob.model_fit <- function(object, new_data, ...) {
99
if (object$spec$mode != "classification")
1010
rlang::abort("`predict.model_fit()` is for predicting factor outcomes.")
1111

12-
if (!any(names(object$spec$method$pred) == "prob"))
13-
rlang::abort("No class probability module defined for this model.")
12+
check_spec_pred_type(object, "prob")
13+
1414

1515
if (inherits(object$fit, "try-error")) {
1616
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_hazard.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
predict_hazard.model_fit <-
88
function(object, new_data, .time, ...) {
99

10-
if (is.null(object$spec$method$pred$hazard))
11-
rlang::abort("No hazard prediction method defined for this engine.")
10+
check_spec_pred_type(object, "hazard")
1211

1312
if (inherits(object$fit, "try-error")) {
1413
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_interval.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@
1010
#' @export
1111
predict_confint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) {
1212

13-
if (is.null(object$spec$method$pred$conf_int))
14-
rlang::abort("No confidence interval method defined for this engine.")
13+
check_spec_pred_type(object, "conf_int")
1514

1615
if (inherits(object$fit, "try-error")) {
1716
rlang::warn("Model fit failed; cannot make predictions.")
@@ -58,8 +57,7 @@ predict_confint <- function(object, ...)
5857
# @export
5958
predict_predint.model_fit <- function(object, new_data, level = 0.95, std_error = FALSE, ...) {
6059

61-
if (is.null(object$spec$method$pred$pred_int))
62-
rlang::abort("No prediction interval method defined for this engine.")
60+
check_spec_pred_type(object, "pred_int")
6361

6462
if (inherits(object$fit, "try-error")) {
6563
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_linear_pred.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
#' @export
77
predict_linear_pred.model_fit <- function(object, new_data, ...) {
88

9-
if (!any(names(object$spec$method$pred) == "linear_pred"))
10-
rlang::abort("No prediction module defined for this model.")
9+
check_spec_pred_type(object, "linear_pred")
1110

1211
if (inherits(object$fit, "try-error")) {
1312
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_numeric.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ predict_numeric.model_fit <- function(object, new_data, ...) {
1010
"Use `predict_class()` or `predict_classprob()` for ",
1111
"classification models."))
1212

13-
if (!any(names(object$spec$method$pred) == "numeric"))
14-
rlang::abort("No prediction module defined for this model.")
13+
check_spec_pred_type(object, "numeric")
1514

1615
if (inherits(object$fit, "try-error")) {
1716
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_quantile.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
predict_quantile.model_fit <-
1010
function(object, new_data, quantile = (1:9)/10, ...) {
1111

12-
if (is.null(object$spec$method$pred$quantile))
13-
rlang::abort("No quantile prediction method defined for this engine.")
12+
check_spec_pred_type(object, "quantile")
1413

1514
if (inherits(object$fit, "try-error")) {
1615
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_raw.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ predict_raw.model_fit <- function(object, new_data, opts = list(), ...) {
1313
c(object$spec$method$pred$raw$args, opts)
1414
}
1515

16-
if (!any(names(object$spec$method$pred) == "raw"))
17-
rlang::abort("No raw prediction module defined for this model.")
16+
check_spec_pred_type(object, "raw")
1817

1918
if (inherits(object$fit, "try-error")) {
2019
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_survival.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
predict_survival.model_fit <-
88
function(object, new_data, .time, ...) {
99

10-
if (is.null(object$spec$method$pred$survival))
11-
rlang::abort("No survival prediction method defined for this engine.")
10+
check_spec_pred_type(object, "survival")
1211

1312
if (inherits(object$fit, "try-error")) {
1413
rlang::warn("Model fit failed; cannot make predictions.")

R/predict_time.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ predict_time.model_fit <- function(object, new_data, ...) {
1010
"Use `predict_class()` or `predict_classprob()` for ",
1111
"classification models."))
1212

13-
if (!any(names(object$spec$method$pred) == "time"))
14-
rlang::abort("No prediction module defined for this model.")
13+
check_spec_pred_type(object, "time")
1514

1615
if (inherits(object$fit, "try-error")) {
1716
rlang::warn("Model fit failed; cannot make predictions.")

0 commit comments

Comments
 (0)