case weights #692

Merged
merged 45 commits into from
Jun 6, 2022
Commits
b927ffe
initial case weight work for xy models
topepo Mar 23, 2022
3bd5f37
make the env arg of eval_tidy more explicit
topepo Mar 23, 2022
73eae0b
conversion of weights class to some numeric type
topepo Mar 23, 2022
1d8507b
changes for formula call formation
topepo Mar 23, 2022
b13af41
check to see if the model can use case weights
topepo Mar 23, 2022
b72c7ce
change envir vector name to weights
topepo Mar 23, 2022
3f5f6a7
update model defs for non-standard case weight arg names
topepo Mar 23, 2022
ee75da5
could it possibly be this easy?
topepo Mar 24, 2022
0391906
better approach to handling model.frame() issues in lm()
topepo Mar 24, 2022
7f4676b
no case weights for LiblinearR (they are class weights)
topepo Mar 24, 2022
6ebce7b
add and update unit tests
topepo Mar 25, 2022
450dd48
version updates and remotes
topepo Mar 25, 2022
3bcab7a
Apply suggestions from code review
topepo Mar 28, 2022
6e036fe
function for grouped binomial data
topepo Mar 28, 2022
b7b5765
changes based on reviewer feedback
topepo Mar 28, 2022
c7cc287
re-export hardhat functions
topepo Mar 28, 2022
f2d0159
changes based on reviewer feedback
topepo Mar 28, 2022
95cd8cd
test non-standard argument names
topepo Mar 28, 2022
2929d53
temp bypass for r-devel
topepo Mar 28, 2022
2100e8f
pass case weights to xgboost
topepo Mar 29, 2022
99d1bc3
update tests for xgboost/boost_tree args
topepo Mar 29, 2022
fe2b184
add case weight summary to show_model_info()
topepo Mar 29, 2022
a4facca
added glm_grouped to pkgdown
topepo Mar 30, 2022
48c30be
more unit tests
topepo Mar 30, 2022
a2b1c1a
spark support for case weights
topepo Mar 30, 2022
118c09d
updates to documentation for case weights
topepo Mar 30, 2022
ca3e6c8
add missing topic
topepo Mar 30, 2022
ceb9c0b
more engine doc updates
topepo Mar 31, 2022
8a6f61c
added more notes in engine docs
topepo Mar 31, 2022
eb81af5
added more notes in engine docs
topepo Mar 31, 2022
69c7c14
gam weights
topepo Mar 31, 2022
33079a3
Merge branch 'main' into feature/case-weights
topepo Apr 12, 2022
c75ed66
doc update
topepo Apr 13, 2022
bc81160
revert nnet case weights
topepo Apr 13, 2022
0698b6d
S3 method to convert hardhat format to numeric
topepo Apr 14, 2022
2aaee25
Merge branch 'main' into feature/case-weights
topepo Apr 21, 2022
7c70d26
Ensure that `fit_xy()` patches the formula environment with weights (…
DavisVaughan Apr 21, 2022
2f18332
updated for latest roxygen2
topepo Apr 21, 2022
68e5c97
get xgb to stop being so chatty
topepo Apr 21, 2022
284252b
update snapshots
topepo Apr 21, 2022
64634a0
Merge branch 'main' into feature/case-weights
topepo May 19, 2022
4cb16cf
doc update
topepo May 19, 2022
f2f24a0
missing doc entry
topepo May 19, 2022
205af75
Merge branch 'main' into feature/case-weights
topepo Jun 2, 2022
a6e7849
remove convert_case_weights
topepo Jun 6, 2022
32 changes: 9 additions & 23 deletions DESCRIPTION
@@ -21,11 +21,11 @@ Depends:
Imports:
cli,
dplyr (>= 0.8.0.1),
generics (>= 0.1.0.9000),
generics (>= 0.1.2),
ggplot2,
globals,
glue,
hardhat (>= 0.1.6.9001),
hardhat (>= 0.2.0.9000),
lifecycle,
magrittr,
prettyunits,
@@ -40,9 +40,8 @@ Imports:
Suggests:
C50,
covr,
dials (>= 0.0.10.9001),
dials (>= 0.1.0),
earth,
tensorflow,
ggrepel,
keras,
kernlab,
@@ -60,30 +59,17 @@ Suggests:
rpart,
sparklyr (>= 1.0.0),
survival,
tensorflow,
testthat (>= 3.0.0),
xgboost (>= 1.5.0.1)
Remotes:
tidymodels/hardhat
VignetteBuilder:
knitr
ByteCompile: true
Config/Needs/website:
C50,
dbarts,
earth,
glmnet,
keras,
kernlab,
kknn,
LiblineaR,
mgcv,
nnet,
parsnip,
randomForest,
ranger,
rpart,
rstanarm,
tidymodels/tidymodels,
tidyverse/tidytemplate,
rstudio/reticulate,
Config/Needs/website: C50, dbarts, earth, glmnet, keras, kernlab, kknn,
LiblineaR, mgcv, nnet, parsnip, randomForest, ranger, rpart, rstanarm,
tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate,
xgboost
Config/rcmdcheck/ignore-inconsequential-notes: true
Encoding: UTF-8
5 changes: 5 additions & 0 deletions NAMESPACE
@@ -205,6 +205,7 @@ export(format_linear_pred)
export(format_num)
export(format_survival)
export(format_time)
export(frequency_weights)
export(gen_additive_mod)
export(get_dependency)
export(get_encoding)
@@ -213,7 +214,9 @@ export(get_from_env)
export(get_model_env)
export(get_pred_type)
export(glance)
export(glm_grouped)
export(has_multi_predict)
export(importance_weights)
export(is_varying)
export(keras_mlp)
export(keras_predict_classes)
@@ -333,6 +336,8 @@ importFrom(hardhat,extract_fit_engine)
importFrom(hardhat,extract_parameter_dials)
importFrom(hardhat,extract_parameter_set_dials)
importFrom(hardhat,extract_spec_parsnip)
importFrom(hardhat,frequency_weights)
importFrom(hardhat,importance_weights)
importFrom(hardhat,tune)
importFrom(magrittr,"%>%")
importFrom(purrr,"%||%")
8 changes: 8 additions & 0 deletions NEWS.md
@@ -1,5 +1,12 @@
# parsnip (development version)


* Enable the use of case weights for models that support them.

* Added a `glm_grouped()` function to convert long data to the grouped format required by `glm()` for logistic regression.

* `show_model_info()` now indicates which models can utilize case weights.

* `xgb_train()` now allows for case weights.

* Added `ctree_train()` and `cforest_train()` wrappers for the functions in the partykit package. Engines for these will be added to other parsnip extension packages.
@@ -14,6 +21,7 @@

* Model type functions will now message informatively if a needed parsnip extension package is not loaded (#731).
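
Review note on the `glm_grouped()` entry: the idea of moving from long data with integer counts to the grouped `cbind()` format can be sketched in base R. This is not the code of `glm_grouped()`. The data frame, the column names (`dose`, `outcome`, `count`, `is_yes`), and the use of `xtabs()` are assumptions made for illustration only.

```r
# Long format: one row per covariate pattern and outcome level,
# with an integer count acting as a frequency weight (made-up data).
long <- data.frame(
  dose    = c(1, 1, 2, 2, 3, 3),
  outcome = factor(rep(c("yes", "no"), times = 3), levels = c("yes", "no")),
  count   = c(2L, 8L, 5L, 5L, 9L, 1L)
)

# Grouped format: one row per covariate pattern with one count column
# per outcome level.
counts <- xtabs(count ~ dose + outcome, data = long)
grouped <- data.frame(
  dose = as.numeric(rownames(counts)),
  yes  = as.vector(counts[, "yes"]),
  no   = as.vector(counts[, "no"])
)

# Binomial fit on the grouped data, using a two-column outcome.
fit_grouped <- glm(cbind(yes, no) ~ dose, data = grouped, family = binomial)

# For comparison: the same model on the long data, with the counts
# passed as weights to a 0/1 outcome.
long$is_yes <- as.integer(long$outcome == "yes")
fit_long <- glm(is_yes ~ dose, data = long, family = binomial, weights = count)
```

Both fits maximize the same likelihood up to a constant, so the coefficient estimates should agree to numerical tolerance, while quantities such as the residual degrees of freedom differ between the two layouts.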


# parsnip 0.2.1

* Fixed a major bug in spark models induced in the previous version (#671).
20 changes: 18 additions & 2 deletions R/aaa_models.R
@@ -927,8 +927,24 @@ show_model_info <- function(model) {
engines <- get_from_env(model)
if (nrow(engines) > 0) {
cat(" engines: \n")
engines %>%

weight_info <-
purrr::map_df(
model,
~ get_from_env(paste0(.x, "_fit")) %>% mutate(model = .x)
) %>%
dplyr::mutate(protect = map(value, ~ .x$protect)) %>%
dplyr::select(-value) %>%
dplyr::mutate(
has_wts = purrr::map_lgl(protect, ~ any(grepl("^weight", .x))),
has_wts = ifelse(has_wts, cli::symbol$sup_1, "")
) %>%
dplyr::select(engine, mode, has_wts)

engines %>%
dplyr::left_join(weight_info, by = c("engine", "mode")) %>%
dplyr::mutate(
engine = paste0(engine, has_wts),
mode = format(paste0(mode, ": "))
) %>%
dplyr::group_by(mode) %>%
@@ -941,7 +957,7 @@ show_model_info <- function(model) {
dplyr::ungroup() %>%
dplyr::pull(lab) %>%
cat(sep = "")
cat("\n")
cat("\n", cli::symbol$sup_1, "The model can use case weights.\n\n", sep = "")
} else {
cat(" no registered engines.\n\n")
}
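
Review note on the hunk above: the weight flag comes from checking each engine's `protect` vector for an entry that starts with "weight". A minimal base R sketch of that check follows. The `protect_by_engine` list is a hypothetical stand-in for what `get_from_env()` returns, and `"*"` replaces `cli::symbol$sup_1` so the sketch has no dependencies.

```r
# Hypothetical 'protect' vectors keyed by engine name (illustrative values,
# not read from parsnip's model environment).
protect_by_engine <- list(
  lm        = c("formula", "data", "weights"),
  LiblineaR = c("x", "y"),
  xgboost   = c("x", "y", "weights")
)

# An engine can use case weights if any protected argument starts with "weight".
has_wts <- vapply(
  protect_by_engine,
  function(p) any(grepl("^weight", p)),
  logical(1)
)

# Append a marker to the engines that can use case weights.
engine_labels <- paste0(names(has_wts), ifelse(has_wts, "*", ""))

cat(engine_labels, sep = "\n")
cat("\n* The model can use case weights.\n")
```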
47 changes: 42 additions & 5 deletions R/arguments.R
@@ -150,11 +150,35 @@ make_call <- function(fun, ns, args, ...) {

make_form_call <- function(object, env = NULL) {
fit_args <- object$method$fit$args
uses_weights <- has_weights(env)

# Get the arguments related to data:
# In model specification code using `set_fit()`, there are two main arguments
# that dictate the data-related model arguments (e.g. 'formula', 'data', 'x',
# etc).
# The 'protect' element specifies which data arguments should not be modifiable
# by the user (as an engine argument). These have standardized names that
# follow the usual R conventions. For example, `foo(formula, data, weights)`
# and so on.
# However, some packages do not follow these naming conventions. The 'data'
# element in `set_fit()` allows us to have non-standard argument names by
# providing a named list. If function `bar(f, dat, wts)` was being used, the
# 'data' element would be `c(formula = "f", data = "dat", weights = "wts")`.
# If conventional names are used, there is no 'data' element since the values
# in 'protect' suffice.

# Get the data-related arguments to insert into the model call

# Do we have conventional argument names?
if (is.null(object$method$fit$data)) {
data_args <- c(formula = "formula", data = "data")
# Set the minimum arguments for formula methods.
data_args <- object$method$fit$protect
names(data_args) <- data_args
# Case weights _could_ be used but remove the arg if they are not given:
if (!uses_weights) {
data_args <- data_args[data_args != "weights"]
}
} else {
# What are the non-conventional names?
data_args <- object$method$fit$data
}

@@ -166,6 +190,7 @@ make_form_call <- function(object, env = NULL) {
# sub in actual formula
fit_args[[ unname(data_args["formula"]) ]] <- env$formula

# TODO remove weights col from data?
if (object$engine == "spark") {
env$x <- env$data
}
@@ -178,12 +203,20 @@ make_form_call <- function(object, env = NULL) {
fit_call
}
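
Review note on `make_form_call()`: the selection of data arguments described in the comments can be sketched on its own. `resolve_data_args()` is a hypothetical helper written for this note and is not part of parsnip. It mirrors the branch logic shown in the diff: conventional names come from `protect`, and `weights` is dropped when no case weights were supplied.

```r
# Hypothetical helper mirroring the branch logic in make_form_call().
resolve_data_args <- function(protect, data = NULL, uses_weights = FALSE) {
  if (is.null(data)) {
    # Conventional names: the canonical name and the engine's argument
    # name are the same, so name the vector by itself.
    data_args <- stats::setNames(protect, protect)
    # Weights could be used, but drop the argument if none were given.
    if (!uses_weights) {
      data_args <- data_args[data_args != "weights"]
    }
  } else {
    # Non-conventional names: canonical name -> engine argument name.
    data_args <- data
  }
  data_args
}

protect <- c("formula", "data", "weights")

no_weights   <- resolve_data_args(protect)
with_weights <- resolve_data_args(protect, uses_weights = TRUE)
custom       <- resolve_data_args(
  protect,
  data = c(formula = "f", data = "dat", weights = "wts"),
  uses_weights = TRUE
)
```

With the `bar(f, dat, wts)` example from the comments, `unname(custom["weights"])` gives the engine's own argument name, which is how the call is built against non-standard signatures.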

make_xy_call <- function(object, target) {
# TODO we need something to indicate that case weights are being used.
make_xy_call <- function(object, target, env) {
fit_args <- object$method$fit$args
uses_weights <- has_weights(env)

# See the comments above in make_form_call()

# Get the arguments related to data:
if (is.null(object$method$fit$data)) {
data_args <- c(x = "x", y = "y")
data_args <- object$method$fit$protect
names(data_args) <- data_args
# Case weights _could_ be used but remove the arg if they are not given:
if (!uses_weights) {
data_args <- data_args[data_args != "weights"]
}
} else {
data_args <- object$method$fit$data
}
@@ -197,6 +230,9 @@ make_xy_call <- function(object, target) {
matrix = rlang::expr(maybe_matrix(x)),
rlang::abort(glue::glue("Invalid data type target: {target}."))
)
if (uses_weights) {
object$method$fit$args[[ unname(data_args["weights"]) ]] <- rlang::expr(weights)
}

fit_call <- make_call(
fun = object$method$fit$func["fun"],
@@ -269,3 +305,4 @@ min_rows <- function(num_rows, source, offset = 0) {

as.integer(num_rows)
}

4 changes: 2 additions & 2 deletions R/boost_tree_data.R
@@ -82,7 +82,7 @@ set_fit(
mode = "regression",
value = list(
interface = "matrix",
protect = c("x", "y"),
protect = c("x", "y", "weights"),
func = c(pkg = "parsnip", fun = "xgb_train"),
defaults = list(nthread = 1, verbose = 0)
)
@@ -132,7 +132,7 @@ set_fit(
mode = "classification",
value = list(
interface = "matrix",
protect = c("x", "y"),
protect = c("x", "y", "weights"),
func = c(pkg = "parsnip", fun = "xgb_train"),
defaults = list(nthread = 1, verbose = 0)
)
93 changes: 93 additions & 0 deletions R/case_weights.R
@@ -0,0 +1,93 @@
#' Using case weights with parsnip
#'
#' Case weights are positive numeric values that determine how much influence
#' each data point has during the model fitting process. There are a variety of
#' situations where case weights can be used.
#'
#' tidymodels packages differentiate _how_ different types of case weights
#' should be used during the entire data analysis process, including
#' preprocessing data, model fitting, performance calculations, etc.
#'
#' The tidymodels packages require users to convert their numeric vectors to a
#' vector class that reflects how these should be used. For example, there are
#' some situations where the weights should not affect operations such as
#' centering and scaling or other preprocessing operations.
#'
#' The types of weights allowed in tidymodels are:
#'
#' * Frequency weights via [hardhat::frequency_weights()]
#' * Importance weights via [hardhat::importance_weights()]
#'
#' More types can be added by request.
#'
#' For parsnip, the [fit()] and [fit_xy] functions contain a `case_weights`
#' argument that takes these data. For Spark models, the argument value should
#' be a character value naming the column that holds the weights.
#'
#' @name case_weights
#' @seealso [frequency_weights()], [importance_weights()], [fit()], [fit_xy]
NULL

# ------------------------------------------------------------------------------

weights_to_numeric <- function(x, spec) {
if (is.null(x)) {
return(NULL)
} else if (spec$engine == "spark") {
# Spark wants a column name
return(x)
}

to_int <- c("hardhat_frequency_weights")
if (inherits(x, to_int)) {
x <- as.integer(x)
} else {
x <- as.numeric(x)
}
x
}

patch_formula_environment_with_case_weights <- function(formula,
data,
case_weights) {
# `lm()` and `glm()` and others use the original model function call to
# construct a call for `model.frame()`. That will normally fail because the
# formula has its own environment attached (usually the global environment)
# and it will look there for a vector named 'weights'. To account
# for this, we create a child of the `formula`'s environment and
# stash the `weights` there with the expected name and then
# reassign this as the `formula`'s environment
environment(formula) <- rlang::new_environment(
data = list(data = data, weights = case_weights),
parent = environment(formula)
)

formula
}
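
Review note on `patch_formula_environment_with_case_weights()`: the lookup problem described in the comments can be reproduced with base R alone. The sketch below is an illustration, not the package code. `patch_env_sketch()` is a hypothetical helper that uses `new.env()` instead of `rlang::new_environment()` and stashes only the weights, and the data are simulated.

```r
# Hypothetical base R version of the environment patch: create a child of
# the formula's environment that holds a vector named `weights`.
patch_env_sketch <- function(formula, case_weights) {
  env <- new.env(parent = environment(formula))
  assign("weights", case_weights, envir = env)
  environment(formula) <- env
  formula
}

set.seed(1)
d <- data.frame(x = 1:10, y = (1:10) + rnorm(10))
w <- c(rep(1, 5), rep(3, 5))

f <- patch_env_sketch(y ~ x, w)

# model.frame() looks for `weights` in `d` first and then in
# environment(f), where the patched environment supplies it.
fit_patched <- lm(f, data = d, weights = weights)

# Reference fit with the weights stored as a column of the data.
d_ref <- cbind(d, w = w)
fit_ref <- lm(y ~ x, data = d_ref, weights = w)

# Without the patch, `weights` resolves to the stats::weights() function
# and model.frame() fails.
unpatched <- try(lm(y ~ x, data = d, weights = weights), silent = TRUE)
```

The patched and reference fits should give the same coefficients, which is the behavior the helper is meant to guarantee for `lm()`, `glm()`, and similar functions.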

# ------------------------------------------------------------------------------

case_weights_allowed <- function(spec) {
mod_type <- class(spec)[1]
mod_eng <- spec$engine
mod_mode <- spec$mode

model_info <-
get_from_env(paste0(mod_type, "_fit")) %>%
dplyr::filter(engine == mod_eng & mode == mod_mode)
if (nrow(model_info) != 1) {
rlang::abort(
glue::glue(
"Error in getting model information for model {mod_type} with engine {mod_eng} and mode {mod_mode}."
)
)
}
# If weights are used, they are protected data arguments with the canonical
# name 'weights' (although this may not be the model function's argument name).
data_args <- model_info$value[[1]]$protect
any(data_args == "weights")
}

has_weights <- function(env) {
!is.null(env$weights)
}
6 changes: 6 additions & 0 deletions R/convert_data.R
@@ -252,6 +252,12 @@
if (length(weights) != nrow(x)) {
rlang::abort(glue::glue("`weights` should have {nrow(x)} elements"))
}

form <- patch_formula_environment_with_case_weights(
formula = form,
data = x,
case_weights = weights
)
}

res <- list(