Skip to content

Commit 2afbef6

Browse files
topepoDavisVaughan
andauthored
case weights (#692)
* initial case weight work for xy models * make the env arg of eval_tidy more explicit * conversion of weights class to some numeric type * changes for formula call formation * check to see if the model can use case weights * change envir vector name to weights * update model defs for non-standard case weight arg names * could it possibly be this easy? * better approach to handling model.frame() issues in lm() * no case weights for LiblinearR (they are class weights) * add and update unit tests * version updates and remotes * Apply suggestions from code review Co-authored-by: Davis Vaughan <[email protected]> * function for grouped binomial data * changes based on reviewer feedback * re-export hardhat functions * changes based on reviewer feedback * test non-standard argument names * temp bypass for r-devel * pass case weights to xgboost * update tests for xgboost/boost_tree args * add case weight summary to show_model_info() * added glm_grouped to pkgdown * more unit tests * spark support for case weights * updates to documentation for case weights * add missing topic * more engine doc updates * added more notes in engine docs * added more notes in engine docs * gam weights * doc update * revert nnet case weights * S3 method to convert hardhat format to numeric * Ensure that `fit_xy()` patches the formula environment with weights (#705) * Prefix everywhere we use `new_quosure()` or `empty_env()` We don't import these, so we have to do this. Tests were only working by chance because we have `library(rlang)` in some of the test files! * Ensure that `fit_xy()` patches the formula environment with weights * missing roxygen tag * avoid deprecated tests Co-authored-by: Max Kuhn <[email protected]> * updated for latest roxygen2 * get xgb to stop being so chatty * update snapshots * doc update * missing doc entry * remove convert_case_weights Co-authored-by: Davis Vaughan <[email protected]>
1 parent 02107d0 commit 2afbef6

File tree

256 files changed

+2662
-631
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

256 files changed

+2662
-631
lines changed

DESCRIPTION

Lines changed: 9 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ Depends:
2121
Imports:
2222
cli,
2323
dplyr (>= 0.8.0.1),
24-
generics (>= 0.1.0.9000),
24+
generics (>= 0.1.2),
2525
ggplot2,
2626
globals,
2727
glue,
28-
hardhat (>= 0.1.6.9001),
28+
hardhat (>= 0.2.0.9000),
2929
lifecycle,
3030
magrittr,
3131
prettyunits,
@@ -40,9 +40,8 @@ Imports:
4040
Suggests:
4141
C50,
4242
covr,
43-
dials (>= 0.0.10.9001),
43+
dials (>= 0.1.0),
4444
earth,
45-
tensorflow,
4645
ggrepel,
4746
keras,
4847
kernlab,
@@ -60,30 +59,17 @@ Suggests:
6059
rpart,
6160
sparklyr (>= 1.0.0),
6261
survival,
62+
tensorflow,
6363
testthat (>= 3.0.0),
6464
xgboost (>= 1.5.0.1)
65+
Remotes:
66+
tidymodels/hardhat
6567
VignetteBuilder:
6668
knitr
6769
ByteCompile: true
68-
Config/Needs/website:
69-
C50,
70-
dbarts,
71-
earth,
72-
glmnet,
73-
keras,
74-
kernlab,
75-
kknn,
76-
LiblineaR,
77-
mgcv,
78-
nnet,
79-
parsnip,
80-
randomForest,
81-
ranger,
82-
rpart,
83-
rstanarm,
84-
tidymodels/tidymodels,
85-
tidyverse/tidytemplate,
86-
rstudio/reticulate,
70+
Config/Needs/website: C50, dbarts, earth, glmnet, keras, kernlab, kknn,
71+
LiblineaR, mgcv, nnet, parsnip, randomForest, ranger, rpart, rstanarm,
72+
tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate,
8773
xgboost
8874
Config/rcmdcheck/ignore-inconsequential-notes: true
8975
Encoding: UTF-8

NAMESPACE

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ export(format_linear_pred)
205205
export(format_num)
206206
export(format_survival)
207207
export(format_time)
208+
export(frequency_weights)
208209
export(gen_additive_mod)
209210
export(get_dependency)
210211
export(get_encoding)
@@ -213,7 +214,9 @@ export(get_from_env)
213214
export(get_model_env)
214215
export(get_pred_type)
215216
export(glance)
217+
export(glm_grouped)
216218
export(has_multi_predict)
219+
export(importance_weights)
217220
export(is_varying)
218221
export(keras_mlp)
219222
export(keras_predict_classes)
@@ -333,6 +336,8 @@ importFrom(hardhat,extract_fit_engine)
333336
importFrom(hardhat,extract_parameter_dials)
334337
importFrom(hardhat,extract_parameter_set_dials)
335338
importFrom(hardhat,extract_spec_parsnip)
339+
importFrom(hardhat,frequency_weights)
340+
importFrom(hardhat,importance_weights)
336341
importFrom(hardhat,tune)
337342
importFrom(magrittr,"%>%")
338343
importFrom(purrr,"%||%")

NEWS.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# parsnip (development version)
22

3+
4+
* Enable the use of case weights for models that support them.
5+
6+
* Added a `glm_grouped()` function to convert long data to the grouped format required by `glm()` for logistic regression.
7+
8+
* `show_model_info()` now indicates which models can utilize case weights.
9+
310
* `xgb_train()` now allows for case weights
411

512
* Added `ctree_train()` and `cforest_train()` wrappers for the functions in the partykit package. Engines for these will be added to other parsnip extension packages.
@@ -14,6 +21,7 @@
1421

1522
* Model type functions will now message informatively if a needed parsnip extension package is not loaded (#731).
1623

24+
1725
# parsnip 0.2.1
1826

1927
* Fixed a major bug in spark models induced in the previous version (#671).

R/aaa_models.R

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -927,8 +927,24 @@ show_model_info <- function(model) {
927927
engines <- get_from_env(model)
928928
if (nrow(engines) > 0) {
929929
cat(" engines: \n")
930-
engines %>%
930+
931+
weight_info <-
932+
purrr::map_df(
933+
model,
934+
~ get_from_env(paste0(.x, "_fit")) %>% mutate(model = .x)
935+
) %>%
936+
dplyr::mutate(protect = map(value, ~ .x$protect)) %>%
937+
dplyr::select(-value) %>%
931938
dplyr::mutate(
939+
has_wts = purrr::map_lgl(protect, ~ any(grepl("^weight", .x))),
940+
has_wts = ifelse(has_wts, cli::symbol$sup_1, "")
941+
) %>%
942+
dplyr::select(engine, mode, has_wts)
943+
944+
engines %>%
945+
dplyr::left_join(weight_info, by = c("engine", "mode")) %>%
946+
dplyr::mutate(
947+
engine = paste0(engine, has_wts),
932948
mode = format(paste0(mode, ": "))
933949
) %>%
934950
dplyr::group_by(mode) %>%
@@ -941,7 +957,7 @@ show_model_info <- function(model) {
941957
dplyr::ungroup() %>%
942958
dplyr::pull(lab) %>%
943959
cat(sep = "")
944-
cat("\n")
960+
cat("\n", cli::symbol$sup_1, "The model can use case weights.\n\n", sep = "")
945961
} else {
946962
cat(" no registered engines.\n\n")
947963
}

R/arguments.R

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -150,11 +150,35 @@ make_call <- function(fun, ns, args, ...) {
150150

151151
make_form_call <- function(object, env = NULL) {
152152
fit_args <- object$method$fit$args
153+
uses_weights <- has_weights(env)
153154

154-
# Get the arguments related to data:
155+
# In model specification code using `set_fit()`, there are two main arguments
156+
# that dictate the data-related model arguments (e.g. 'formula', 'data', 'x',
157+
# etc).
158+
# The 'protect' element specifies which data arguments should not be modifiable
159+
# by the user (as an engine argument). These have standardized names that
160+
# follow the usual R conventions. For example, `foo(formula, data, weights)`
161+
# and so on.
162+
# However, some packages do not follow these naming conventions. The 'data'
163+
# element in `set_fit()` allows use to have non-standard argument names by
164+
# providing a named list. If function `bar(f, dat, wts)` was being used, the
165+
# 'data' element would be `c(formula = "f", data = "dat", weights = "wts)`.
166+
# If conventional names are used, there is no 'data' element since the values
167+
# in 'protect' suffice.
168+
169+
# Get the arguments related to data arguments to insert into the model call
170+
171+
# Do we have conventional argument names?
155172
if (is.null(object$method$fit$data)) {
156-
data_args <- c(formula = "formula", data = "data")
173+
# Set the minimum arguments for formula methods.
174+
data_args <- object$method$fit$protect
175+
names(data_args) <- data_args
176+
# Case weights _could_ be used but remove the arg if they are not given:
177+
if (!uses_weights) {
178+
data_args <- data_args[data_args != "weights"]
179+
}
157180
} else {
181+
# What are the non-conventional names?
158182
data_args <- object$method$fit$data
159183
}
160184

@@ -166,6 +190,7 @@ make_form_call <- function(object, env = NULL) {
166190
# sub in actual formula
167191
fit_args[[ unname(data_args["formula"]) ]] <- env$formula
168192

193+
# TODO remove weights col from data?
169194
if (object$engine == "spark") {
170195
env$x <- env$data
171196
}
@@ -178,12 +203,20 @@ make_form_call <- function(object, env = NULL) {
178203
fit_call
179204
}
180205

181-
make_xy_call <- function(object, target) {
206+
# TODO we need something to indicate that case weights are being used.
207+
make_xy_call <- function(object, target, env) {
182208
fit_args <- object$method$fit$args
209+
uses_weights <- has_weights(env)
210+
211+
# See the comments above in make_form_call()
183212

184-
# Get the arguments related to data:
185213
if (is.null(object$method$fit$data)) {
186-
data_args <- c(x = "x", y = "y")
214+
data_args <- object$method$fit$protect
215+
names(data_args) <- data_args
216+
# Case weights _could_ be used but remove the arg if they are not given:
217+
if (!uses_weights) {
218+
data_args <- data_args[data_args != "weights"]
219+
}
187220
} else {
188221
data_args <- object$method$fit$data
189222
}
@@ -197,6 +230,9 @@ make_xy_call <- function(object, target) {
197230
matrix = rlang::expr(maybe_matrix(x)),
198231
rlang::abort(glue::glue("Invalid data type target: {target}."))
199232
)
233+
if (uses_weights) {
234+
object$method$fit$args[[ unname(data_args["weights"]) ]] <- rlang::expr(weights)
235+
}
200236

201237
fit_call <- make_call(
202238
fun = object$method$fit$func["fun"],
@@ -269,3 +305,4 @@ min_rows <- function(num_rows, source, offset = 0) {
269305

270306
as.integer(num_rows)
271307
}
308+

R/boost_tree_data.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ set_fit(
8282
mode = "regression",
8383
value = list(
8484
interface = "matrix",
85-
protect = c("x", "y"),
85+
protect = c("x", "y", "weights"),
8686
func = c(pkg = "parsnip", fun = "xgb_train"),
8787
defaults = list(nthread = 1, verbose = 0)
8888
)
@@ -132,7 +132,7 @@ set_fit(
132132
mode = "classification",
133133
value = list(
134134
interface = "matrix",
135-
protect = c("x", "y"),
135+
protect = c("x", "y", "weights"),
136136
func = c(pkg = "parsnip", fun = "xgb_train"),
137137
defaults = list(nthread = 1, verbose = 0)
138138
)

R/case_weights.R

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
#' Using case weights with parsnip
2+
#'
3+
#' Case weights are positive numeric values that influence how much each data
4+
#' point has during the model fitting process. There are a variety of situations
5+
#' where case weights can be used.
6+
#'
7+
#' tidymodels packages differentiate _how_ different types of case weights
8+
#' should be used during the entire data analysis process, including
9+
#' preprocessing data, model fitting, performance calculations, etc.
10+
#'
11+
#' The tidymodels packages require users to convert their numeric vectors to a
12+
#' vector class that reflects how these should be used. For example, there are
13+
#' some situations where the weights should not affect operations such as
14+
#' centering and scaling or other preprocessing operations.
15+
#'
16+
#' The types of weights allowed in tidymodels are:
17+
#'
18+
#' * Frequency weights via [hardhat::frequency_weights()]
19+
#' * Importance weights via [hardhat::importance_weights()]
20+
#'
21+
#' More types can be added by request.
22+
#'
23+
#' For parsnip, the [fit()] and [fit_xy] functions contain a `case_weight`
24+
#' argument that takes these data. For Spark models, the argument value should
25+
#' be a character value.
26+
#'
27+
#' @name case_weights
28+
#' @seealso [frequency_weights()], [importance_weights()], [fit()], [fit_xy]
29+
NULL
30+
31+
# ------------------------------------------------------------------------------
32+
33+
weights_to_numeric <- function(x, spec) {
34+
if (is.null(x)) {
35+
return(NULL)
36+
} else if (spec$engine == "spark") {
37+
# Spark wants a column name
38+
return(x)
39+
}
40+
41+
to_int <- c("hardhat_frequency_weights")
42+
if (inherits(x, to_int)) {
43+
x <- as.integer(x)
44+
} else {
45+
x <- as.numeric(x)
46+
}
47+
x
48+
}
49+
50+
patch_formula_environment_with_case_weights <- function(formula,
51+
data,
52+
case_weights) {
53+
# `lm()` and `glm()` and others use the original model function call to
54+
# construct a call for `model.frame()`. That will normally fail because the
55+
# formula has its own environment attached (usually the global environment)
56+
# and it will look there for a vector named 'weights'. To account
57+
# for this, we create a child of the `formula`'s environment and
58+
# stash the `weights` there with the expected name and then
59+
# reassign this as the `formula`'s environment
60+
environment(formula) <- rlang::new_environment(
61+
data = list(data = data, weights = case_weights),
62+
parent = environment(formula)
63+
)
64+
65+
formula
66+
}
67+
68+
# ------------------------------------------------------------------------------
69+
70+
case_weights_allowed <- function(spec) {
71+
mod_type <- class(spec)[1]
72+
mod_eng <- spec$engine
73+
mod_mode <- spec$mode
74+
75+
model_info <-
76+
get_from_env(paste0(mod_type, "_fit")) %>%
77+
dplyr::filter(engine == mod_eng & mode == mod_mode)
78+
if (nrow(model_info) != 1) {
79+
rlang::abort(
80+
glue::glue(
81+
"Error in geting model information for model {mod_type} with engine {mod_eng} and mode {mod_mode}."
82+
)
83+
)
84+
}
85+
# If weights are used, they are protected data arguments with the canonical
86+
# name 'weights' (although this may not be the model function's argument name).
87+
data_args <- model_info$value[[1]]$protect
88+
any(data_args == "weights")
89+
}
90+
91+
has_weights <- function(env) {
92+
!is.null(env$weights)
93+
}

R/convert_data.R

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,12 @@
252252
if (length(weights) != nrow(x)) {
253253
rlang::abort(glue::glue("`weights` should have {nrow(x)} elements"))
254254
}
255+
256+
form <- patch_formula_environment_with_case_weights(
257+
formula = form,
258+
data = x,
259+
case_weights = weights
260+
)
255261
}
256262

257263
res <- list(

0 commit comments

Comments
 (0)