Skip to content

Commit 8deb255

Browse files
authored
Merge pull request #415 from tidymodels/augment
augment method for model_fit objects
2 parents de52eb6 + 4d70897 commit 8deb255

File tree

9 files changed

+243
-12
lines changed

9 files changed

+243
-12
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ Imports:
2222
purrr,
2323
utils,
2424
tibble (>= 2.1.1),
25-
generics,
25+
generics (>= 0.1.0),
2626
glue,
2727
magrittr,
2828
stats,

NAMESPACE

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(augment,model_fit)
34
S3method(fit,model_spec)
45
S3method(fit_xy,model_spec)
56
S3method(glance,model_fit)
@@ -103,6 +104,7 @@ export(.x)
103104
export(.y)
104105
export(C5.0_train)
105106
export(add_rowindex)
107+
export(augment)
106108
export(boost_tree)
107109
export(check_empty_ellipse)
108110
export(check_final_param)
@@ -208,6 +210,7 @@ importFrom(dplyr,starts_with)
208210
importFrom(dplyr,summarise)
209211
importFrom(dplyr,tally)
210212
importFrom(dplyr,vars)
213+
importFrom(generics,augment)
211214
importFrom(generics,fit)
212215
importFrom(generics,fit_xy)
213216
importFrom(generics,glance)

R/augment.R

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#' Augment data with predictions
2+
#'
3+
#' `augment()` will add column(s) for predictions to the given data.
4+
#'
5+
#' For regression models, a `.pred` column is added. If `x` was created using
6+
#' [fit()] and `new_data` contains the outcome column, a `.resid` column is
7+
#' also added.
8+
#'
9+
#' For classification models, the results include a column called `.pred_class`
10+
#' as well as class probability columns named `.pred_{level}`.
11+
#' @param x A `model_fit` object produced by [fit()] or [fit_xy()].
12+
#' @param new_data A data frame or matrix.
13+
#' @param ... Not currently used.
14+
#' @export
15+
#' @examples
16+
#' car_trn <- mtcars[11:32,]
17+
#' car_tst <- mtcars[ 1:10,]
18+
#'
19+
#' reg_form <-
20+
#' linear_reg() %>%
21+
#' set_engine("lm") %>%
22+
#' fit(mpg ~ ., data = car_trn)
23+
#' reg_xy <-
24+
#' linear_reg() %>%
25+
#' set_engine("lm") %>%
26+
#' fit_xy(car_trn[, -1], car_trn$mpg)
27+
#'
28+
#' augment(reg_form, car_tst)
29+
#' augment(reg_form, car_tst[, -1])
30+
#'
31+
#' augment(reg_xy, car_tst)
32+
#' augment(reg_xy, car_tst[, -1])
33+
#'
34+
#' # ------------------------------------------------------------------------------
35+
#'
36+
#' data(two_class_dat, package = "modeldata")
37+
#' cls_trn <- two_class_dat[-(1:10), ]
38+
#' cls_tst <- two_class_dat[ 1:10 , ]
39+
#'
40+
#' cls_form <-
41+
#' logistic_reg() %>%
42+
#' set_engine("glm") %>%
43+
#' fit(Class ~ ., data = cls_trn)
44+
#' cls_xy <-
45+
#' logistic_reg() %>%
46+
#' set_engine("glm") %>%
47+
#' fit_xy(cls_trn[, -3],
48+
#' cls_trn$Class)
49+
#'
50+
#' augment(cls_form, cls_tst)
51+
#' augment(cls_form, cls_tst[, -3])
52+
#'
53+
#' augment(cls_xy, cls_tst)
54+
#' augment(cls_xy, cls_tst[, -3])
55+
#'
56+
augment.model_fit <- function(x, new_data, ...) {
  # Append prediction columns to `new_data`. Which columns are added depends
  # on the mode stored in the model specification (`x$spec$mode`).
  mode <- x$spec$mode

  if (mode == "regression") {
    new_data <- dplyr::bind_cols(new_data, predict(x, new_data = new_data))

    # Residuals are computed only when exactly one outcome name was recorded
    # at fit time and that column is present in `new_data`. `rlang::sym()`
    # needs a length-one string, and comparing `names(new_data)` against
    # several names with `==` would silently recycle, so fits with zero or
    # multiple recorded outcomes get predictions without a `.resid` column.
    y_nm <- x$preproc$y_var
    if (length(y_nm) == 1 && y_nm %in% names(new_data)) {
      new_data <- dplyr::mutate(new_data, .resid = !!rlang::sym(y_nm) - .pred)
    }
  } else if (mode == "classification") {
    # Hard class predictions first, then the class probability columns.
    new_data <- dplyr::bind_cols(
      new_data,
      predict(x, new_data = new_data, type = "class"),
      predict(x, new_data = new_data, type = "prob")
    )
  } else {
    rlang::abort(paste("Unknown mode:", mode))
  }

  new_data
}

R/reexports.R

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ generics::fit_xy
1515
#' @export
1616
generics::tidy
1717

18-
1918
#' @importFrom generics glance
2019
#' @export
2120
generics::glance
21+
22+
#' @importFrom generics augment
23+
#' @export
24+
generics::augment

R/zzz.R

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
# nocov start
22

33
.onLoad <- function(libname, pkgname) {
4-
s3_register("broom::tidy", "model_fit")
5-
s3_register("broom::tidy", "nullmodel")
6-
s3_register("broom::tidy", "_elnet")
7-
s3_register("broom::tidy", "_lognet")
8-
s3_register("broom::tidy", "_multnet")
9-
s3_register("broom::tidy", "_fishnet")
10-
s3_register("broom::glance", "model_fit")
4+
s3_register("generics::tidy", "model_fit")
5+
s3_register("generics::tidy", "nullmodel")
6+
s3_register("generics::tidy", "_elnet")
7+
s3_register("generics::tidy", "_lognet")
8+
s3_register("generics::tidy", "_multnet")
9+
s3_register("generics::tidy", "_fishnet")
10+
s3_register("generics::glance", "model_fit")
11+
s3_register("generics::augment", "model_fit")
1112
}
1213

1314

_pkgdown.yml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ reference:
4949
- set_mode
5050
- tidy.model_fit
5151
- translate
52-
- varying
53-
- varying_args
52+
- augment.model_fit
5453

5554
- title: Developer Tools
5655
contents:

man/augment.model_fit.Rd

Lines changed: 68 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

man/reexports.Rd

Lines changed: 2 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-augment.R

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
2+
context("augment")
3+
4+
# ------------------------------------------------------------------------------
5+
6+
test_that("regression models", {
  # One linear-regression spec, fitted through both interfaces.
  lm_spec <- set_engine(linear_reg(), "lm")
  reg_form <- fit(lm_spec, mpg ~ ., data = mtcars)
  reg_xy <- fit_xy(lm_spec, mtcars[, -1], mtcars$mpg)

  # Shared inputs and the column names expected in the augmented output.
  with_outcome <- head(mtcars)
  predictors_only <- head(mtcars[, -1])
  predictor_names <- c(
    "cyl", "disp", "hp", "drat", "wt", "qsec", "vs", "am", "gear", "carb"
  )

  # Formula fit + outcome present: predictions and residuals.
  form_with_outcome <- augment(reg_form, with_outcome)
  expect_equal(
    colnames(form_with_outcome),
    c("mpg", predictor_names, ".pred", ".resid")
  )
  expect_equal(nrow(form_with_outcome), 6)

  # Formula fit + outcome absent: predictions only.
  form_predictors <- augment(reg_form, predictors_only)
  expect_equal(colnames(form_predictors), c(predictor_names, ".pred"))
  expect_equal(nrow(form_predictors), 6)

  # x/y fit never yields residuals, even if the outcome column is supplied.
  xy_with_outcome <- augment(reg_xy, with_outcome)
  expect_equal(colnames(xy_with_outcome), c("mpg", predictor_names, ".pred"))
  expect_equal(nrow(xy_with_outcome), 6)

  xy_predictors <- augment(reg_xy, predictors_only)
  expect_equal(colnames(xy_predictors), c(predictor_names, ".pred"))
  expect_equal(nrow(xy_predictors), 6)

  # An unrecognised mode must be rejected with an informative error.
  reg_form$spec$mode <- "depeche"
  expect_error(augment(reg_form, predictors_only), "Unknown mode: depeche")
})
43+
44+
45+
46+
test_that("classification models", {
  data(two_class_dat, package = "modeldata")

  # One logistic-regression spec, fitted through both interfaces.
  glm_spec <- set_engine(logistic_reg(), "glm")
  cls_form <- fit(glm_spec, Class ~ ., data = two_class_dat)
  cls_xy <- fit_xy(glm_spec, two_class_dat[, -3], two_class_dat$Class)

  # Shared inputs and the prediction columns expected in the output.
  with_outcome <- head(two_class_dat)
  predictors_only <- head(two_class_dat[, -3])
  pred_cols <- c(".pred_class", ".pred_Class1", ".pred_Class2")

  form_with_outcome <- augment(cls_form, with_outcome)
  expect_equal(colnames(form_with_outcome), c("A", "B", "Class", pred_cols))
  expect_equal(nrow(form_with_outcome), 6)

  form_predictors <- augment(cls_form, predictors_only)
  expect_equal(colnames(form_predictors), c("A", "B", pred_cols))
  expect_equal(nrow(form_predictors), 6)

  xy_with_outcome <- augment(cls_xy, with_outcome)
  expect_equal(colnames(xy_with_outcome), c("A", "B", "Class", pred_cols))
  expect_equal(nrow(xy_with_outcome), 6)

  xy_predictors <- augment(cls_xy, predictors_only)
  expect_equal(colnames(xy_predictors), c("A", "B", pred_cols))
  expect_equal(nrow(xy_predictors), 6)
})
76+

0 commit comments

Comments
 (0)