6
6
# ' [fit()] and `new_data` contains the outcome column, a `.resid` column is
7
7
# ' also added.
8
8
# '
9
- # ' For classification models, the results include a column called `.pred_class`
10
- # ' as well as class probability columns named `.pred_{level}`.
9
+ # ' For classification models, the results can include a column called
10
+ # ' `.pred_class` as well as class probability columns named `.pred_{level}`.
11
+ # ' This depends on what type of prediction types are available for the model.
11
12
# ' @param x A `model_fit` object produced by [fit()] or [fit_xy()].
12
13
# ' @param new_data A data frame or matrix.
13
14
# ' @param ... Not currently used.
56
57
# '
57
58
augment.model_fit <- function (x , new_data , ... ) {
58
59
if (x $ spec $ mode == " regression" ) {
60
+ check_spec_pred_type(x , " numeric" )
59
61
new_data <-
60
62
new_data %> %
61
63
dplyr :: bind_cols(
@@ -68,13 +70,13 @@ augment.model_fit <- function(x, new_data, ...) {
68
70
}
69
71
}
70
72
} else if (x $ spec $ mode == " classification" ) {
71
- if (has_class_preds( x )) {
73
+ if (spec_has_pred_type( x , " class " )) {
72
74
new_data <- dplyr :: bind_cols(
73
75
new_data ,
74
76
predict(x , new_data = new_data , type = " class" )
75
77
)
76
78
}
77
- if (has_class_probs( x )) {
79
+ if (spec_has_pred_type( x , " prob " )) {
78
80
new_data <- dplyr :: bind_cols(
79
81
new_data ,
80
82
predict(x , new_data = new_data , type = " prob" )
@@ -85,11 +87,3 @@ augment.model_fit <- function(x, new_data, ...) {
85
87
}
86
88
as_tibble(new_data )
87
89
}
88
-
89
- has_class_preds <- function (x ) {
90
- any(names(x $ spec $ method $ pred ) == " class" )
91
- }
92
-
93
- has_class_probs <- function (x ) {
94
- any(names(x $ spec $ method $ pred ) == " prob" )
95
- }
0 commit comments