case weights #692
Merged
45 commits
b927ffe initial case weight work for xy models (topepo)
3bd5f37 make the env arg of eval_tidy more explicit (topepo)
73eae0b conversion of weights class to some numeric type (topepo)
1d8507b changes for formula call formation (topepo)
b13af41 check to see if the model can use case weights (topepo)
b72c7ce change envir vector name to weights (topepo)
3f5f6a7 update model defs for non-standard case weight arg names (topepo)
ee75da5 could it possibly be this easy? (topepo)
0391906 better approach to handling model.frame() issues in lm() (topepo)
7f4676b no case weights for LiblinearR (they are class weights) (topepo)
6ebce7b add and update unit tests (topepo)
450dd48 version updates and remotes (topepo)
3bcab7a Apply suggestions from code review (topepo)
6e036fe function for grouped binomial data (topepo)
b7b5765 changes based on reviewer feedback (topepo)
c7cc287 re-export hardhat functions (topepo)
f2d0159 changes based on reviewer feedback (topepo)
95cd8cd test non-standard argument names (topepo)
2929d53 temp bypass for r-devel (topepo)
2100e8f pass case weights to xgboost (topepo)
99d1bc3 update tests for xgboost/boost_tree args (topepo)
fe2b184 add case weight summary to show_model_info() (topepo)
a4facca added glm_grouped to pkgdown (topepo)
48c30be more unit tests (topepo)
a2b1c1a spark support for case weights (topepo)
118c09d updates to documentation for case weights (topepo)
ca3e6c8 add missing topic (topepo)
ceb9c0b more engine doc updates (topepo)
8a6f61c added more notes in engine docs (topepo)
eb81af5 added more notes in engine docs (topepo)
69c7c14 gam weights (topepo)
33079a3 Merge branch 'main' into feature/case-weights (topepo)
c75ed66 doc update (topepo)
bc81160 revert nnet case weights (topepo)
0698b6d S3 method to convert hardhat format to numeric (topepo)
2aaee25 Merge branch 'main' into feature/case-weights (topepo)
7c70d26 Ensure that `fit_xy()` patches the formula environment with weights (… (DavisVaughan)
2f18332 updated for latest roxygen2 (topepo)
68e5c97 get xgb to stop being so chatty (topepo)
284252b update snapshots (topepo)
64634a0 Merge branch 'main' into feature/case-weights (topepo)
4cb16cf doc update (topepo)
f2f24a0 missing doc entry (topepo)
205af75 Merge branch 'main' into feature/case-weights (topepo)
a6e7849 remove convert_case_weights (topepo)
@@ -0,0 +1,93 @@
#' Using case weights with parsnip
#'
#' Case weights are positive numeric values that determine how much influence
#' each data point has during the model fitting process. There are a variety of
#' situations where case weights can be used.
#'
#' tidymodels packages differentiate _how_ different types of case weights
#' should be used during the entire data analysis process, including
#' preprocessing data, model fitting, performance calculations, etc.
#'
#' The tidymodels packages require users to convert their numeric vectors to a
#' vector class that reflects how these should be used. For example, there are
#' some situations where the weights should not affect operations such as
#' centering and scaling or other preprocessing operations.
#'
#' The types of weights allowed in tidymodels are:
#'
#' * Frequency weights via [hardhat::frequency_weights()]
#' * Importance weights via [hardhat::importance_weights()]
#'
#' More types can be added by request.
#'
#' For parsnip, the [fit()] and [fit_xy()] functions contain a `case_weights`
#' argument that takes these data. For Spark models, the argument value should
#' be a character value naming the column that holds the weights.
#'
#' @name case_weights
#' @seealso [frequency_weights()], [importance_weights()], [fit()], [fit_xy()]
NULL

# ------------------------------------------------------------------------------

weights_to_numeric <- function(x, spec) {
  if (is.null(x)) {
    return(NULL)
  } else if (spec$engine == "spark") {
    # Spark wants a column name
    return(x)
  }

  to_int <- c("hardhat_frequency_weights")
  if (inherits(x, to_int)) {
    x <- as.integer(x)
  } else {
    x <- as.numeric(x)
  }
  x
}

patch_formula_environment_with_case_weights <- function(formula,
                                                        data,
                                                        case_weights) {
  # `lm()` and `glm()` and others use the original model function call to
  # construct a call for `model.frame()`. That will normally fail because the
  # formula has its own environment attached (usually the global environment)
  # and it will look there for a vector named 'weights'. To account
  # for this, we create a child of the `formula`'s environment and
  # stash the `weights` there with the expected name and then
  # reassign this as the `formula`'s environment
  environment(formula) <- rlang::new_environment(
    data = list(data = data, weights = case_weights),
    parent = environment(formula)
  )

  formula
}

# ------------------------------------------------------------------------------

case_weights_allowed <- function(spec) {
  mod_type <- class(spec)[1]
  mod_eng <- spec$engine
  mod_mode <- spec$mode

  model_info <-
    get_from_env(paste0(mod_type, "_fit")) %>%
    dplyr::filter(engine == mod_eng & mode == mod_mode)
  if (nrow(model_info) != 1) {
    rlang::abort(
      glue::glue(
        "Error in getting model information for model {mod_type} with engine {mod_eng} and mode {mod_mode}."
      )
    )
  }
  # If weights are used, they are protected data arguments with the canonical
  # name 'weights' (although this may not be the model function's argument name).
  data_args <- model_info$value[[1]]$protect
  any(data_args == "weights")
}

has_weights <- function(env) {
  !is.null(env$weights)
}
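
The comment in `patch_formula_environment_with_case_weights()` says that `lm()` and similar functions look for the weights through the formula's environment. The following is a minimal base-R sketch of that behaviour, not the PR's code: `fit_naive()` and `fit_patched()` are hypothetical names, `new.env()` stands in for `rlang::new_environment()`, and only the weights (not `data`) are stashed.

```r
# Illustration only: hypothetical helpers, base R, no parsnip or rlang.
set.seed(1)
w <- runif(nrow(mtcars))
f <- mpg ~ wt + hp  # environment(f) is the environment where `f` was created

# Naive wrapper: `case_weights` exists only in this function's frame.
# model.frame() evaluates the weights expression in `data` and then in
# environment(formula), so the symbol `case_weights` cannot be found.
fit_naive <- function(formula, data, case_weights) {
  lm(formula, data = data, weights = case_weights)
}
naive <- tryCatch(fit_naive(f, mtcars, w), error = function(e) e)

# Patched wrapper: stash the vector under the name `weights` in a child of
# the formula's environment, then refer to it by that name in the call.
fit_patched <- function(formula, data, case_weights) {
  env <- new.env(parent = environment(formula))
  env$weights <- case_weights
  environment(formula) <- env
  lm(formula, data = data, weights = weights)
}
patched <- fit_patched(f, mtcars, w)

# Reference fit where `w` is visible from the formula's own environment.
direct <- lm(mpg ~ wt + hp, data = mtcars, weights = w)

inherits(naive, "error")
all.equal(coef(patched), coef(direct))
```

Under these assumptions, `naive` should hold an error from the failed lookup, while `patched` should reproduce the coefficients of the direct `lm()` call. Using a child environment leaves the original formula environment untouched, so other objects the formula refers to remain reachable through the parent.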
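
The three branches of `weights_to_numeric()` can be exercised without loading hardhat or parsnip. This is a rough sketch under stated assumptions: the function body is copied from the diff, the weight vectors are hand-built stand-ins that only carry the class names (real hardhat weights are richer objects), and the `spec` values are plain lists with an `engine` element rather than real model specifications.

```r
# Copied from the diff so the example is self-contained.
weights_to_numeric <- function(x, spec) {
  if (is.null(x)) {
    return(NULL)
  } else if (spec$engine == "spark") {
    # Spark wants a column name
    return(x)
  }

  to_int <- c("hardhat_frequency_weights")
  if (inherits(x, to_int)) {
    x <- as.integer(x)
  } else {
    x <- as.numeric(x)
  }
  x
}

# Stand-ins (assumptions): bare doubles tagged with the class names only.
freq <- structure(c(2, 1, 3), class = "hardhat_frequency_weights")
imp <- structure(c(0.5, 1.5), class = "hardhat_importance_weights")
lm_spec <- list(engine = "lm")
spark_spec <- list(engine = "spark")

res_null <- weights_to_numeric(NULL, lm_spec)     # no weights supplied
res_spark <- weights_to_numeric("wts", spark_spec) # column name passed through
res_freq <- weights_to_numeric(freq, lm_spec)     # frequency weights -> integer
res_imp <- weights_to_numeric(imp, lm_spec)       # anything else -> double

is.null(res_null)
is.character(res_spark)
is.integer(res_freq)
is.double(res_imp)
```

The integer conversion matters for engines that treat frequency weights as counts; importance weights stay as doubles.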