13
13
# ' @param ... Not currently used.
14
14
# ' @export
15
15
# ' @examples
16
+ # ' car_trn <- mtcars[11:32,]
17
+ # ' car_tst <- mtcars[ 1:10,]
18
+ # '
16
19
# ' reg_form <-
17
20
# ' linear_reg() %>%
18
21
# ' set_engine("lm") %>%
19
- # ' fit(mpg ~ ., data = mtcars )
22
+ # ' fit(mpg ~ ., data = car_trn )
20
23
# ' reg_xy <-
21
24
# ' linear_reg() %>%
22
25
# ' set_engine("lm") %>%
23
- # ' fit_xy(mtcars [, -1], mtcars $mpg)
26
+ # ' fit_xy(car_trn [, -1], car_trn $mpg)
24
27
# '
25
- # ' augment(reg_form, head(mtcars) )
26
- # ' augment(reg_form, head(mtcars [, -1]) )
28
+ # ' augment(reg_form, car_tst )
29
+ # ' augment(reg_form, car_tst [, -1])
27
30
# '
28
- # ' augment(reg_xy, head(mtcars) )
29
- # ' augment(reg_xy, head(mtcars [, -1]) )
31
+ # ' augment(reg_xy, car_tst )
32
+ # ' augment(reg_xy, car_tst [, -1])
30
33
# '
31
34
# ' # ------------------------------------------------------------------------------
32
35
# '
33
36
# ' data(two_class_dat, package = "modeldata")
37
+ # ' cls_trn <- two_class_dat[-(1:10), ]
38
+ # ' cls_tst <- two_class_dat[ 1:10 , ]
34
39
# '
35
40
# ' cls_form <-
36
41
# ' logistic_reg() %>%
37
42
# ' set_engine("glm") %>%
38
- # ' fit(Class ~ ., data = two_class_dat )
43
+ # ' fit(Class ~ ., data = cls_trn )
39
44
# ' cls_xy <-
40
45
# ' logistic_reg() %>%
41
46
# ' set_engine("glm") %>%
42
- # ' fit_xy(two_class_dat [, -3],
43
- # ' two_class_dat $Class)
47
+ # ' fit_xy(cls_trn [, -3],
48
+ # ' cls_trn $Class)
44
49
# '
45
- # ' augment(cls_form, head(two_class_dat) )
46
- # ' augment(cls_form, head(two_class_dat [, -3]) )
50
+ # ' augment(cls_form, cls_tst )
51
+ # ' augment(cls_form, cls_tst [, -3])
47
52
# '
48
- # ' augment(cls_xy, head(two_class_dat) )
49
- # ' augment(cls_xy, head(two_class_dat [, -3]) )
53
+ # ' augment(cls_xy, cls_tst )
54
+ # ' augment(cls_xy, cls_tst [, -3])
50
55
# '
51
56
augment.model_fit <- function (x , new_data , ... ) {
52
57
if (x $ spec $ mode == " regression" ) {
@@ -61,13 +66,15 @@ augment.model_fit <- function(x, new_data, ...) {
61
66
new_data <- dplyr :: mutate(new_data , .resid = !! rlang :: sym(y_nm ) - .pred )
62
67
}
63
68
}
64
- } else {
69
+ } else if ( x $ spec $ mode == " classification " ) {
65
70
new_data <-
66
71
new_data %> %
67
72
dplyr :: bind_cols(
68
73
predict(x , new_data = new_data , type = " class" ),
69
74
predict(x , new_data = new_data , type = " prob" )
70
75
)
76
+ } else {
77
+ rlang :: abort(paste(" Unknown mode:" , x $ spec $ mode ))
71
78
}
72
79
new_data
73
80
}
0 commit comments