Skip to content

Commit 4d70897

Browse files
committed
changes for PR comments by Davis and Julia
1 parent 1f0de51 commit 4d70897

File tree

3 files changed

+43
-27
lines changed

3 files changed

+43
-27
lines changed

R/augment.R

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,40 +13,45 @@
1313
#' @param ... Not currently used.
1414
#' @export
1515
#' @examples
16+
#' car_trn <- mtcars[11:32,]
17+
#' car_tst <- mtcars[ 1:10,]
18+
#'
1619
#' reg_form <-
1720
#' linear_reg() %>%
1821
#' set_engine("lm") %>%
19-
#' fit(mpg ~ ., data = mtcars)
22+
#' fit(mpg ~ ., data = car_trn)
2023
#' reg_xy <-
2124
#' linear_reg() %>%
2225
#' set_engine("lm") %>%
23-
#' fit_xy(mtcars[, -1], mtcars$mpg)
26+
#' fit_xy(car_trn[, -1], car_trn$mpg)
2427
#'
25-
#' augment(reg_form, head(mtcars))
26-
#' augment(reg_form, head(mtcars[, -1]))
28+
#' augment(reg_form, car_tst)
29+
#' augment(reg_form, car_tst[, -1])
2730
#'
28-
#' augment(reg_xy, head(mtcars))
29-
#' augment(reg_xy, head(mtcars[, -1]))
31+
#' augment(reg_xy, car_tst)
32+
#' augment(reg_xy, car_tst[, -1])
3033
#'
3134
#' # ------------------------------------------------------------------------------
3235
#'
3336
#' data(two_class_dat, package = "modeldata")
37+
#' cls_trn <- two_class_dat[-(1:10), ]
38+
#' cls_tst <- two_class_dat[ 1:10 , ]
3439
#'
3540
#' cls_form <-
3641
#' logistic_reg() %>%
3742
#' set_engine("glm") %>%
38-
#' fit(Class ~ ., data = two_class_dat)
43+
#' fit(Class ~ ., data = cls_trn)
3944
#' cls_xy <-
4045
#' logistic_reg() %>%
4146
#' set_engine("glm") %>%
42-
#' fit_xy(two_class_dat[, -3],
43-
#' two_class_dat$Class)
47+
#' fit_xy(cls_trn[, -3],
48+
#' cls_trn$Class)
4449
#'
45-
#' augment(cls_form, head(two_class_dat))
46-
#' augment(cls_form, head(two_class_dat[, -3]))
50+
#' augment(cls_form, cls_tst)
51+
#' augment(cls_form, cls_tst[, -3])
4752
#'
48-
#' augment(cls_xy, head(two_class_dat))
49-
#' augment(cls_xy, head(two_class_dat[, -3]))
53+
#' augment(cls_xy, cls_tst)
54+
#' augment(cls_xy, cls_tst[, -3])
5055
#'
5156
augment.model_fit <- function(x, new_data, ...) {
5257
if (x$spec$mode == "regression") {
@@ -61,13 +66,15 @@ augment.model_fit <- function(x, new_data, ...) {
6166
new_data <- dplyr::mutate(new_data, .resid = !!rlang::sym(y_nm) - .pred)
6267
}
6368
}
64-
} else {
69+
} else if (x$spec$mode == "classification") {
6570
new_data <-
6671
new_data %>%
6772
dplyr::bind_cols(
6873
predict(x, new_data = new_data, type = "class"),
6974
predict(x, new_data = new_data, type = "prob")
7075
)
76+
} else {
77+
rlang::abort(paste("Unknown mode:", x$spec$mode))
7178
}
7279
new_data
7380
}

man/augment.model_fit.Rd

Lines changed: 18 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-augment.R

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ test_that('regression models', {
3535
)
3636
expect_equal(nrow(augment(reg_xy, head(mtcars[, -1]))), 6)
3737

38+
reg_form$spec$mode <- "depeche"
39+
40+
expect_error(augment(reg_form, head(mtcars[, -1])), "Unknown mode: depeche")
41+
3842
})
3943

4044

0 commit comments

Comments
 (0)