Skip to content

Commit 1f0de51

Browse files
committed
augment method for #401
1 parent 0b42c9c commit 1f0de51

File tree

5 files changed

+212
-2
lines changed

5 files changed

+212
-2
lines changed

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(augment,model_fit)
34
S3method(fit,model_spec)
45
S3method(fit_xy,model_spec)
56
S3method(glance,model_fit)
@@ -103,6 +104,7 @@ export(.x)
103104
export(.y)
104105
export(C5.0_train)
105106
export(add_rowindex)
107+
export(augment)
106108
export(boost_tree)
107109
export(check_empty_ellipse)
108110
export(check_final_param)
@@ -208,6 +210,7 @@ importFrom(dplyr,starts_with)
208210
importFrom(dplyr,summarise)
209211
importFrom(dplyr,tally)
210212
importFrom(dplyr,vars)
213+
importFrom(generics,augment)
211214
importFrom(generics,fit)
212215
importFrom(generics,fit_xy)
213216
importFrom(generics,glance)

R/augment.R

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#' Augment data with predictions
2+
#'
3+
#' `augment()` will add column(s) for predictions to the given data.
4+
#'
5+
#' For regression models, a `.pred` column is added. If `x` was created using
6+
#' [fit()] and `new_data` contains the outcome column, a `.resid` column is
7+
#' also added.
8+
#'
9+
#' For classification models, the results include a column called `.pred_class`
10+
#' as well as class probability columns named `.pred_{level}`.
11+
#' @param x A `model_fit` object produced by [fit()] or [fit_xy()].
12+
#' @param new_data A data frame or matrix.
13+
#' @param ... Not currently used.
14+
#' @export
15+
#' @examples
16+
#' reg_form <-
17+
#' linear_reg() %>%
18+
#' set_engine("lm") %>%
19+
#' fit(mpg ~ ., data = mtcars)
20+
#' reg_xy <-
21+
#' linear_reg() %>%
22+
#' set_engine("lm") %>%
23+
#' fit_xy(mtcars[, -1], mtcars$mpg)
24+
#'
25+
#' augment(reg_form, head(mtcars))
26+
#' augment(reg_form, head(mtcars[, -1]))
27+
#'
28+
#' augment(reg_xy, head(mtcars))
29+
#' augment(reg_xy, head(mtcars[, -1]))
30+
#'
31+
#' # ------------------------------------------------------------------------------
32+
#'
33+
#' data(two_class_dat, package = "modeldata")
34+
#'
35+
#' cls_form <-
36+
#' logistic_reg() %>%
37+
#' set_engine("glm") %>%
38+
#' fit(Class ~ ., data = two_class_dat)
39+
#' cls_xy <-
40+
#' logistic_reg() %>%
41+
#' set_engine("glm") %>%
42+
#' fit_xy(two_class_dat[, -3],
43+
#' two_class_dat$Class)
44+
#'
45+
#' augment(cls_form, head(two_class_dat))
46+
#' augment(cls_form, head(two_class_dat[, -3]))
47+
#'
48+
#' augment(cls_xy, head(two_class_dat))
49+
#' augment(cls_xy, head(two_class_dat[, -3]))
50+
#'
51+
augment.model_fit <- function(x, new_data, ...) {
52+
if (x$spec$mode == "regression") {
53+
new_data <-
54+
new_data %>%
55+
dplyr::bind_cols(
56+
predict(x, new_data = new_data)
57+
)
58+
if (length(x$preproc$y_var) > 0) {
59+
y_nm <- x$preproc$y_var
60+
if (any(names(new_data) == y_nm)) {
61+
new_data <- dplyr::mutate(new_data, .resid = !!rlang::sym(y_nm) - .pred)
62+
}
63+
}
64+
} else {
65+
new_data <-
66+
new_data %>%
67+
dplyr::bind_cols(
68+
predict(x, new_data = new_data, type = "class"),
69+
predict(x, new_data = new_data, type = "prob")
70+
)
71+
}
72+
new_data
73+
}

_pkgdown.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ reference:
4949
- set_mode
5050
- tidy.model_fit
5151
- translate
52-
- varying
53-
- varying_args
52+
- augment.model.fit
5453

5554
- title: Developer Tools
5655
contents:

man/augment.model_fit.Rd

Lines changed: 63 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-augment.R

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
2+
context("augment")
3+
4+
# ------------------------------------------------------------------------------
5+
6+
test_that('regression models', {
7+
x <- linear_reg() %>% set_engine("lm")
8+
9+
reg_form <- x %>% fit(mpg ~ ., data = mtcars)
10+
reg_xy <- x %>% fit_xy(mtcars[, -1], mtcars$mpg)
11+
12+
expect_equal(
13+
colnames(augment(reg_form, head(mtcars))),
14+
c("mpg", "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am",
15+
"gear", "carb", ".pred", ".resid")
16+
)
17+
expect_equal(nrow(augment(reg_form, head(mtcars))), 6)
18+
expect_equal(
19+
colnames(augment(reg_form, head(mtcars[, -1]))),
20+
c("cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am",
21+
"gear", "carb", ".pred")
22+
)
23+
expect_equal(nrow(augment(reg_form, head(mtcars[, -1]))), 6)
24+
25+
expect_equal(
26+
colnames(augment(reg_xy, head(mtcars))),
27+
c("mpg", "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am",
28+
"gear", "carb", ".pred")
29+
)
30+
expect_equal(nrow(augment(reg_xy, head(mtcars))), 6)
31+
expect_equal(
32+
colnames(augment(reg_xy, head(mtcars[, -1]))),
33+
c("cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am",
34+
"gear", "carb", ".pred")
35+
)
36+
expect_equal(nrow(augment(reg_xy, head(mtcars[, -1]))), 6)
37+
38+
})
39+
40+
41+
42+
test_that('classification models', {
43+
data(two_class_dat, package = "modeldata")
44+
x <- logistic_reg() %>% set_engine("glm")
45+
46+
cls_form <- x %>% fit(Class ~ ., data = two_class_dat)
47+
cls_xy <- x %>% fit_xy(two_class_dat[, -3], two_class_dat$Class)
48+
49+
expect_equal(
50+
colnames(augment(cls_form, head(two_class_dat))),
51+
c("A", "B", "Class", ".pred_class", ".pred_Class1", ".pred_Class2")
52+
)
53+
expect_equal(nrow(augment(cls_form, head(two_class_dat))), 6)
54+
expect_equal(
55+
colnames(augment(cls_form, head(two_class_dat[, -3]))),
56+
c("A", "B", ".pred_class", ".pred_Class1", ".pred_Class2")
57+
)
58+
expect_equal(nrow(augment(cls_form, head(two_class_dat[, -3]))), 6)
59+
60+
expect_equal(
61+
colnames(augment(cls_xy, head(two_class_dat))),
62+
c("A", "B", "Class", ".pred_class", ".pred_Class1", ".pred_Class2")
63+
)
64+
expect_equal(nrow(augment(cls_xy, head(two_class_dat))), 6)
65+
expect_equal(
66+
colnames(augment(cls_xy, head(two_class_dat[, -3]))),
67+
c("A", "B", ".pred_class", ".pred_Class1", ".pred_Class2")
68+
)
69+
expect_equal(nrow(augment(cls_xy, head(two_class_dat[, -3]))), 6)
70+
71+
})
72+

0 commit comments

Comments
 (0)