Skip to content

Commit 7c70d26

Browse files
DavisVaughantopepo
andauthored
Ensure that fit_xy() patches the formula environment with weights (#705)
* Prefix everywhere we use `new_quosure()` or `empty_env()` We don't import these, so we have to do this. Tests were only working by chance because we have `library(rlang)` in some of the test files! * Ensure that `fit_xy()` patches the formula environment with weights * missing roxygen tag * avoid deprecated tests Co-authored-by: Max Kuhn <[email protected]>
1 parent 2aaee25 commit 7c70d26

File tree

7 files changed

+94
-11
lines changed

7 files changed

+94
-11
lines changed

R/case_weights.R

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,24 @@ weights_to_numeric <- function(x, spec) {
4747
x
4848
}
4949

50+
patch_formula_environment_with_case_weights <- function(formula,
51+
data,
52+
case_weights) {
53+
# `lm()` and `glm()` and others use the original model function call to
54+
# construct a call for `model.frame()`. That will normally fail because the
55+
# formula has its own environment attached (usually the global environment)
56+
# and it will look there for a vector named 'weights'. To account
57+
# for this, we create a child of the `formula`'s environment and
58+
# stash the `weights` there with the expected name and then
59+
# reassign this as the `formula`'s environment
60+
environment(formula) <- rlang::new_environment(
61+
data = list(data = data, weights = case_weights),
62+
parent = environment(formula)
63+
)
64+
65+
formula
66+
}
67+
5068
#' Convert case weights to final from
5169
#'
5270
#' tidymodels requires case weights to have special classes. To use them in
@@ -55,6 +73,7 @@ weights_to_numeric <- function(x, spec) {
5573
#' @param x A vector with class `"hardhat_case_weights"`.
5674
#' @param where The location where they will be used: `"parsnip"` or
5775
#' `"yardstick"`.
76+
#' @param ... Additional options (not currently used).
5877
#' @return A numeric vector or NULL.
5978
#' @export
6079
convert_case_weights <- function(x, where = "parsnip", ...) {

R/convert_data.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,12 @@
252252
if (length(weights) != nrow(x)) {
253253
rlang::abort(glue::glue("`weights` should have {nrow(x)} elements"))
254254
}
255+
256+
form <- patch_formula_environment_with_case_weights(
257+
formula = form,
258+
data = x,
259+
case_weights = weights
260+
)
255261
}
256262

257263
res <- list(

R/fit.R

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -146,16 +146,10 @@ fit.model_spec <-
146146

147147
wts <- weights_to_numeric(case_weights, object)
148148

149-
# `lm()` and `glm()` and others use the original model function call to
150-
# construct a call for `model.frame()`. That will normally fail because the
151-
# formula has its own environment attached (usually the global environment)
152-
# and it will look there for a vector named 'weights'. To account
153-
# for this, we create a child of the `formula`'s environment and
154-
# stash the `weights` there with the expected name and then
155-
# reassign this as the `formula`'s environment
156-
environment(formula) <- rlang::new_environment(
157-
data = list(data = data, weights = wts),
158-
parent = environment(formula)
149+
formula <- patch_formula_environment_with_case_weights(
150+
formula = formula,
151+
data = data,
152+
case_weights = wts
159153
)
160154

161155
eval_env$data <- data

R/fit_helpers.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ xy_form <- function(object, env, control, ...) {
177177
.convert_xy_to_form_fit(
178178
x = env$x,
179179
y = env$y,
180-
weights = NULL,
180+
weights = env$weights,
181181
y_name = "..y",
182182
remove_intercept = remove_intercept
183183
)

man/convert_case_weights.Rd

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test-case-weights.R

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,25 @@ test_that('case weights with xy method', {
2424
print(C5_bst_wt_fit$fit$call),
2525
"weights = weights"
2626
)
27+
28+
expect_error({
29+
set.seed(1)
30+
C5_bst_wt_fit <-
31+
boost_tree(trees = 5) %>%
32+
set_engine("C5.0") %>%
33+
set_mode("classification") %>%
34+
fit_xy(
35+
x = two_class_dat[c("A", "B")],
36+
y = two_class_dat$Class,
37+
case_weights = wts
38+
)
39+
},
40+
regexp = NA)
41+
42+
expect_output(
43+
print(C5_bst_wt_fit$fit$call),
44+
"weights = weights"
45+
)
2746
})
2847

2948

@@ -51,6 +70,19 @@ test_that('case weights with xy method - non-standard argument names', {
5170
# print(rf_wt_fit$fit$call),
5271
# "case\\.weights = weights"
5372
# )
73+
74+
expect_error({
75+
set.seed(1)
76+
rf_wt_fit <-
77+
rand_forest(trees = 5) %>%
78+
set_mode("classification") %>%
79+
fit_xy(
80+
x = two_class_dat[c("A", "B")],
81+
y = two_class_dat$Class,
82+
case_weights = wts
83+
)
84+
},
85+
regexp = NA)
5486
})
5587

5688
test_that('case weights with formula method', {
@@ -78,5 +110,34 @@ test_that('case weights with formula method', {
78110
expect_equal(coef(lm_wt_fit$fit), coef(lm_sub_fit$fit))
79111
})
80112

113+
test_that('case weights with formula method that goes through `fit_xy()`', {
114+
115+
skip_if_not_installed("modeldata")
116+
data("ames", package = "modeldata")
117+
ames$Sale_Price <- log10(ames$Sale_Price)
118+
119+
set.seed(1)
120+
wts <- runif(nrow(ames))
121+
wts <- ifelse(wts < 1/5, 0L, 1L)
122+
ames_subset <- ames[wts != 0, ]
123+
wts <- frequency_weights(wts)
124+
125+
expect_error(
126+
lm_wt_fit <-
127+
linear_reg() %>%
128+
fit_xy(
129+
x = ames[c("Longitude", "Latitude")],
130+
y = ames$Sale_Price,
131+
case_weights = wts
132+
),
133+
regexp = NA)
81134

135+
lm_sub_fit <-
136+
linear_reg() %>%
137+
fit_xy(
138+
x = ames_subset[c("Longitude", "Latitude")],
139+
y = ames_subset$Sale_Price
140+
)
82141

142+
expect_equal(coef(lm_wt_fit$fit), coef(lm_sub_fit$fit))
143+
})

tests/testthat/test_mlp.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
test_that('updating', {
23
expr1 <- mlp(mode = "regression") %>% set_engine("nnet", Hess = FALSE, abstol = tune())
34
expr2 <- mlp(mode = "regression") %>% set_engine("nnet", Hess = tune())

0 commit comments

Comments
 (0)