case weights #692


Merged
45 commits merged on Jun 6, 2022
Changes shown are from 12 of the 45 commits.

Commits
b927ffe
initial case weight work for xy models
topepo Mar 23, 2022
3bd5f37
make the env arg of eval_tidy more explicit
topepo Mar 23, 2022
73eae0b
conversion of weights class to some numeric type
topepo Mar 23, 2022
1d8507b
changes for formula call formation
topepo Mar 23, 2022
b13af41
check to see if the model can use case weights
topepo Mar 23, 2022
b72c7ce
change envir vector name to weights
topepo Mar 23, 2022
3f5f6a7
update model defs for non-standard case weight arg names
topepo Mar 23, 2022
ee75da5
could it possibly be this easy?
topepo Mar 24, 2022
0391906
better approach to handling model.frame() issues in lm()
topepo Mar 24, 2022
7f4676b
no case weights for LiblineaR (they are class weights)
topepo Mar 24, 2022
6ebce7b
add and update unit tests
topepo Mar 25, 2022
450dd48
version updates and remotes
topepo Mar 25, 2022
3bcab7a
Apply suggestions from code review
topepo Mar 28, 2022
6e036fe
function for grouped binomial data
topepo Mar 28, 2022
b7b5765
changes based on reviewer feedback
topepo Mar 28, 2022
c7cc287
re-export hardhat functions
topepo Mar 28, 2022
f2d0159
changes based on reviewer feedback
topepo Mar 28, 2022
95cd8cd
test non-standard argument names
topepo Mar 28, 2022
2929d53
temp bypass for r-devel
topepo Mar 28, 2022
2100e8f
pass case weights to xgboost
topepo Mar 29, 2022
99d1bc3
update tests for xgboost/boost_tree args
topepo Mar 29, 2022
fe2b184
add case weight summary to show_model_info()
topepo Mar 29, 2022
a4facca
added glm_grouped to pkgdown
topepo Mar 30, 2022
48c30be
more unit tests
topepo Mar 30, 2022
a2b1c1a
spark support for case weights
topepo Mar 30, 2022
118c09d
updates to documentation for case weights
topepo Mar 30, 2022
ca3e6c8
add missing topic
topepo Mar 30, 2022
ceb9c0b
more engine doc updates
topepo Mar 31, 2022
8a6f61c
added more notes in engine docs
topepo Mar 31, 2022
eb81af5
added more notes in engine docs
topepo Mar 31, 2022
69c7c14
gam weights
topepo Mar 31, 2022
33079a3
Merge branch 'main' into feature/case-weights
topepo Apr 12, 2022
c75ed66
doc update
topepo Apr 13, 2022
bc81160
revert nnet case weights
topepo Apr 13, 2022
0698b6d
S3 method to convert hardhat format to numeric
topepo Apr 14, 2022
2aaee25
Merge branch 'main' into feature/case-weights
topepo Apr 21, 2022
7c70d26
Ensure that `fit_xy()` patches the formula environment with weights (…
DavisVaughan Apr 21, 2022
2f18332
updated for latest roxygen2
topepo Apr 21, 2022
68e5c97
get xgb to stop being so chatty
topepo Apr 21, 2022
284252b
update snapshots
topepo Apr 21, 2022
64634a0
Merge branch 'main' into feature/case-weights
topepo May 19, 2022
4cb16cf
doc update
topepo May 19, 2022
f2f24a0
missing doc entry
topepo May 19, 2022
205af75
Merge branch 'main' into feature/case-weights
topepo Jun 2, 2022
a6e7849
remove convert_case_weights
topepo Jun 6, 2022
32 changes: 9 additions & 23 deletions DESCRIPTION
@@ -21,11 +21,11 @@ Depends:
Imports:
cli,
dplyr (>= 0.8.0.1),
generics (>= 0.1.0.9000),
generics (>= 0.1.2),
ggplot2,
globals,
glue,
hardhat (>= 0.1.6.9001),
hardhat (>= 0.2.0.9000),
lifecycle,
magrittr,
prettyunits,
@@ -40,9 +40,8 @@ Imports:
Suggests:
C50,
covr,
dials (>= 0.0.10.9001),
dials (>= 0.1.0),
earth,
tensorflow,
ggrepel,
keras,
kernlab,
@@ -60,30 +59,17 @@ Suggests:
rpart,
sparklyr (>= 1.0.0),
survival,
tensorflow,
testthat,
xgboost (>= 1.5.0.1)
Remotes:
tidymodels/hardhat
VignetteBuilder:
knitr
ByteCompile: true
Config/Needs/website:
C50,
dbarts,
earth,
glmnet,
keras,
kernlab,
kknn,
LiblineaR,
mgcv,
nnet,
parsnip,
randomForest,
ranger,
rpart,
rstanarm,
tidymodels/tidymodels,
tidyverse/tidytemplate,
rstudio/reticulate,
Config/Needs/website: C50, dbarts, earth, glmnet, keras, kernlab, kknn,
LiblineaR, mgcv, nnet, parsnip, randomForest, ranger, rpart, rstanarm,
tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate,
xgboost
Config/rcmdcheck/ignore-inconsequential-notes: true
Encoding: UTF-8
22 changes: 21 additions & 1 deletion R/arguments.R
@@ -149,10 +149,14 @@ make_call <- function(fun, ns, args, ...) {

make_form_call <- function(object, env = NULL) {
fit_args <- object$method$fit$args
uses_weights <- !is.null(env$weights)

# Get the arguments related to data:
if (is.null(object$method$fit$data)) {
data_args <- c(formula = "formula", data = "data")
if (uses_weights) {
data_args["weights"] <- "weights"
}
} else {
data_args <- object$method$fit$data
}
@@ -165,6 +169,13 @@ make_form_call <- function(object, env = NULL) {
# sub in actual formula
fit_args[[ unname(data_args["formula"]) ]] <- env$formula

# Add in case weights symbol
if (uses_weights) {
fit_args[[ unname(data_args["weights"]) ]] <- rlang::expr(weights)
}


# TODO remove weights col from data?
if (object$engine == "spark") {
env$x <- env$data
}
@@ -177,12 +188,17 @@ make_form_call <- function(object, env = NULL) {
fit_call
}
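For the default formula mapping above, the constructed call takes roughly this shape. The snippet below is only a sketch using rlang directly, with stats::lm() chosen for illustration; it is not generated parsnip output:

library(rlang)

# The formula, data, and weights symbols are later resolved in the
# environment that fit() builds (see R/fit.R below).
fit_call <- call2(
  "lm",
  formula = expr(formula),
  data = expr(data),
  weights = expr(weights),
  .ns = "stats"
)
fit_call
#> stats::lm(formula = formula, data = data, weights = weights)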

make_xy_call <- function(object, target) {
# TODO we need something to indicate that case weights are being used.
make_xy_call <- function(object, target, env) {
fit_args <- object$method$fit$args
uses_weights <- !is.null(env$weights)

# Get the arguments related to data:
if (is.null(object$method$fit$data)) {
data_args <- c(x = "x", y = "y")
if (uses_weights) {
data_args["weights"] <- "weights"
}
} else {
data_args <- object$method$fit$data
}
@@ -196,6 +212,9 @@ make_xy_call <- function(object, target) {
matrix = rlang::expr(maybe_matrix(x)),
rlang::abort(glue::glue("Invalid data type target: {target}."))
)
if (uses_weights) {
object$method$fit$args[[ unname(data_args["weights"]) ]] <- rlang::expr(weights)
}

fit_call <- make_call(
fun = object$method$fit$func["fun"],
@@ -268,3 +287,4 @@ min_rows <- function(num_rows, source, offset = 0) {

as.integer(num_rows)
}
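Likewise for the xy interface, here is a sketch of the call built when weights are supplied, using ranger's argument mapping from this PR (x to x, y to y, weights to case.weights). The maybe_data_frame() helper is assumed to be parsnip's coercion function for the data.frame target, analogous to the maybe_matrix() shown above:

library(rlang)

xy_call <- call2(
  "ranger",
  x = expr(maybe_data_frame(x)),
  y = expr(y),
  case.weights = expr(weights),
  .ns = "ranger"
)
xy_call
#> ranger::ranger(x = maybe_data_frame(x), y = y, case.weights = weights)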

39 changes: 33 additions & 6 deletions R/fit.R
@@ -18,6 +18,8 @@
#' below). A data frame containing all relevant variables (e.g.
#' outcome(s), predictors, case weights, etc). Note: when needed, a
#' \emph{named argument} should be used.
#' @param case_weights A vector of numeric case weights with an underlying class
#' of "`hardhat_case_weights`". See [hardhat::frequency_weights()] for an example.
#' @param control A named list with elements `verbosity` and
#' `catch`. See [control_parsnip()].
#' @param ... Not currently used; values passed here will be
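A minimal usage sketch of the new argument (hypothetical data; assumes hardhat >= 0.2.0, which provides frequency_weights()):

library(parsnip)

# `n` records how many times each row was observed (hypothetical data).
dat <- data.frame(
  y = c(1.1, 2.3, 3.8, 5.2, 6.0),
  x = 1:5,
  n = c(1L, 2L, 1L, 3L, 1L)
)

linear_reg() %>%
  set_engine("lm") %>%
  fit(y ~ x, data = dat, case_weights = hardhat::frequency_weights(dat$n))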
@@ -101,6 +103,7 @@ fit.model_spec <-
function(object,
formula,
data,
case_weights = NULL,
control = control_parsnip(),
...
) {
@@ -110,6 +113,8 @@ fit.model_spec <-
if (!identical(class(control), class(control_parsnip()))) {
rlang::abort("The 'control' argument should have class 'control_parsnip'.")
}
check_case_weights(case_weights, object)

dots <- quos(...)

if (length(possible_engines(object)) == 0) {
@@ -129,15 +134,31 @@
}
}

if (all(c("x", "y") %in% names(dots)))
if (all(c("x", "y") %in% names(dots))) {
rlang::abort("`fit.model_spec()` is for the formula methods. Use `fit_xy()` instead.")
}
cl <- match.call(expand.dots = TRUE)
# Create an environment with the evaluated argument objects. This will be
# used when a model call is made later.
eval_env <- rlang::env()

wts <- weights_to_numeric(case_weights)

# `lm()`, `glm()`, and other functions use the original model call to
# construct a call to `model.frame()`. That call will normally fail because
# the formula carries its own environment (usually the global environment),
# and `model.frame()` will look there for a vector named 'weights'. To make
# the weights (and the data) discoverable, we clone the formula's environment,
# stash those objects in the clone, and reattach it to the formula.
fenv <- rlang::env_clone(environment(formula))
fenv$data <- data
fenv$weights <- wts
environment(formula) <- fenv

eval_env$data <- data
eval_env$formula <- formula
eval_env$weights <- wts

fit_interface <-
check_interface(eval_env$formula, eval_env$data, cl, object)
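The pattern described in the comment above can be illustrated outside of parsnip. A standalone sketch follows, in which fit_with_weights() is a hypothetical wrapper and not part of this PR:

# stats::lm() resolves the `weights` symbol via model.frame(), which looks in
# the data and then in the environment attached to the formula, not in the
# wrapper's own frame. Cloning the formula environment and stashing the
# vector there makes the lookup succeed.
fit_with_weights <- function(formula, data, wts) {
  f_env <- rlang::env_clone(environment(formula))
  f_env$weights <- wts
  environment(formula) <- f_env
  stats::lm(formula, data = data, weights = weights)
}

d <- data.frame(y = rnorm(20), x = rnorm(20))
fit_with_weights(y ~ x, d, runif(20))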

@@ -206,6 +227,7 @@ fit_xy.model_spec <-
function(object,
x,
y,
case_weights = NULL,
control = control_parsnip(),
...
) {
@@ -223,6 +245,8 @@
if (is.null(colnames(x))) {
rlang::abort("'x' should have column names.")
}
check_case_weights(case_weights, object)

object <- check_mode(object, levels(y))
dots <- quos(...)
if (is.null(object$engine)) {
@@ -245,6 +269,9 @@
eval_env <- rlang::env()
eval_env$x <- x
eval_env$y <- y
eval_env$weights <- weights_to_numeric(case_weights)

# TODO case weights: pass in eval_env not individual elements
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)
Review comment on lines +271 to 272 (Member):

Idk, I feel like passing in the minimal amount of information required makes it easier to understand at a glance what check_xy_interface() is actually checking. So, to me, if this function needs to check the case weights in some way, it should get a case_weights argument so that readers of the code can see that it checks them.


if (object$engine == "spark")
@@ -306,18 +333,18 @@ fit_xy.model_spec <-

# ------------------------------------------------------------------------------

eval_mod <- function(e, capture = FALSE, catch = FALSE, ...) {
eval_mod <- function(e, capture = FALSE, catch = FALSE, envir = NULL, ...) {
if (capture) {
if (catch) {
junk <- capture.output(res <- try(eval_tidy(e, ...), silent = TRUE))
junk <- capture.output(res <- try(eval_tidy(e, env = envir, ...), silent = TRUE))
} else {
junk <- capture.output(res <- eval_tidy(e, ...))
junk <- capture.output(res <- eval_tidy(e, env = envir, ...))
}
} else {
if (catch) {
res <- try(eval_tidy(e, ...), silent = TRUE)
res <- try(eval_tidy(e, env = envir, ...), silent = TRUE)
} else {
res <- eval_tidy(e, ...)
res <- eval_tidy(e, env = envir, ...)
Review comment on lines -312 to +344 (Member):

Ok, since we are already touching this I have another comment:

eval_tidy() has a signature of eval_tidy(expr, data = NULL, env = caller_env()), so unless the ... contain data, they will never be used. Which makes me think: can we remove the ... entirely?

  • If we do happen to pass the data through, we should replace the ... with an explicit data argument in eval_mod() and pass that through instead.

  • If we never use the data argument of eval_tidy(), we should just remove the dots.

}
}
res
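A sketch of the dots-free eval_mod() suggested in the review comment above (hypothetical; the PR as shown keeps the dots):

eval_mod <- function(e, capture = FALSE, catch = FALSE, envir = NULL) {
  if (capture) {
    if (catch) {
      junk <- utils::capture.output(
        res <- try(rlang::eval_tidy(e, env = envir), silent = TRUE)
      )
    } else {
      junk <- utils::capture.output(res <- rlang::eval_tidy(e, env = envir))
    }
  } else {
    if (catch) {
      res <- try(rlang::eval_tidy(e, env = envir), silent = TRUE)
    } else {
      res <- rlang::eval_tidy(e, env = envir)
    }
  }
  res
}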
46 changes: 41 additions & 5 deletions R/fit_helpers.R
@@ -39,7 +39,7 @@ form_form <-
fit_call,
capture = control$verbosity == 0,
catch = control$catch,
env = env,
envir = env,
...
Review comment (Member):

If we remove the ... from eval_mod(), that means we shouldn't pass the dots through here, which I think makes sense? I don't currently see any reason why we do this.

(We should check all other uses of eval_mod(); I think we pass the dots through in more places than just this one.)

),
gcFirst = FALSE
@@ -49,7 +49,7 @@
fit_call,
capture = control$verbosity == 0,
catch = control$catch,
env = env,
envir = env,
...
)
elapsed <- list(elapsed = NA_real_)
@@ -88,7 +88,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
# sub in arguments to actual syntax for corresponding engine
object <- translate(object, engine = object$engine)

fit_call <- make_xy_call(object, target)
fit_call <- make_xy_call(object, target, env)

res <- list(lvl = levels(env$y), spec = object)

@@ -98,7 +98,7 @@
fit_call,
capture = control$verbosity == 0,
catch = control$catch,
env = env,
envir = env,
...
),
gcFirst = FALSE
@@ -108,7 +108,7 @@
fit_call,
capture = control$verbosity == 0,
catch = control$catch,
env = env,
envir = env,
...
)
elapsed <- list(elapsed = NA_real_)
@@ -200,3 +200,39 @@ xy_form <- function(object, env, control, ...) {
res
}


weights_to_numeric <- function(x) {
if (is.null(x)) {
return(NULL)
}

to_int <- c("hardhat_frequency_weights")
if (inherits(x, to_int)) {
x <- as.integer(x)
} else {
x <- as.numeric(x)
}
Review comment (Member):

I would either use:

if (hardhat::is_frequency_weights(x)) {
  x <- as.integer(x)
} else if (hardhat::is_importance_weights(x)) {
  x <- as.double(x)
} else {
  abort("Unknown type of case weights.", .internal = TRUE)
}

Or:

if (hardhat::is_frequency_weights(x)) {
  x <- as.integer(x)
} else {
  x <- as.double(x)
}

The first one is if we want to be very selective about the types of case weights that parsnip supports.

The second one is if we generally just want to try to convert any kind of case weights to double, with a special case for frequency weights.

Mainly I want to avoid hardcoding "hardhat_frequency_weights" anywhere; that's what the is_frequency_weights() helper is for.

x
}
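Roughly, the helper's behavior for the two hardhat classes (an illustrative sketch; weights_to_numeric() is internal, so this is written as if evaluated inside the package namespace):

fw <- hardhat::frequency_weights(c(1L, 2L, 3L))
iw <- hardhat::importance_weights(c(0.5, 1, 2))

weights_to_numeric(fw)    # bare integer vector: 1L 2L 3L
weights_to_numeric(iw)    # bare double vector: 0.5 1.0 2.0
weights_to_numeric(NULL)  # NULL, so "no weights" flows through untouched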

case_weights_allowed <- function(spec) {
mod_type <- class(spec)[1]
mod_eng <- spec$engine
mod_mode <- spec$mode

model_info <-
get_from_env(paste0(mod_type, "_fit")) %>%
dplyr::filter(engine == mod_eng & mode == mod_mode)
if (nrow(model_info) != 1) {
rlang::abort(
glue::glue(
"Error in getting model information for model {mod_type} with engine {mod_eng} and mode {mod_mode}."
)
)
}
# If weights are used, they are protected data arguments with the canonical
# name 'weights' (although this may not be the model function's argument name).
data_args <- model_info$value[[1]]$protect
any(data_args == "weights")
}
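A hypothetical check against two of the engines touched in this PR (illustrative only; the results depend on the registered engine metadata):

# lm() exposes a `weights` argument, so the protected args include "weights":
case_weights_allowed(linear_reg() %>% set_engine("lm"))
#> [1] TRUE

# LiblineaR's `wi` argument sets class weights, not case weights, and was
# removed from the protected arguments in this PR:
case_weights_allowed(svm_linear(mode = "regression") %>% set_engine("LiblineaR"))
#> [1] FALSE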

2 changes: 1 addition & 1 deletion R/logistic_reg_data.R
@@ -220,7 +220,7 @@ set_fit(
mode = "classification",
value = list(
interface = "matrix",
protect = c("x", "y", "wi"),
protect = c("x", "y"),
data = c(x = "data", y = "target"),
func = c(pkg = "LiblineaR", fun = "LiblineaR"),
defaults = list(verbose = FALSE)
15 changes: 15 additions & 0 deletions R/misc.R
@@ -385,3 +385,18 @@ stan_conf_int <- function(object, newdata) {

penalty
}


check_case_weights <- function(x, spec) {
if (is.null(x)) {
return(invisible(NULL))
}
if (!inherits(x, "hardhat_case_weights")) {
rlang::abort("'case_weights' should be a single numeric vector of class 'hardhat_case_weights'.")
}
allowed <- case_weights_allowed(spec)
if (!allowed) {
rlang::abort("Case weights are not enabled by the underlying model implementation.")
}
invisible(NULL)
}
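The corresponding failure modes, roughly (a sketch; the messages come from the code above):

# Rejected up front: not a hardhat case-weights vector.
check_case_weights(c(1, 2, 3), linear_reg() %>% set_engine("lm"))

# Valid weights, but the engine's protected arguments have no `weights` entry,
# so this aborts with "Case weights are not enabled by the underlying model
# implementation."
check_case_weights(
  hardhat::frequency_weights(c(1L, 2L, 3L)),
  svm_linear(mode = "regression") %>% set_engine("LiblineaR")
)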
2 changes: 1 addition & 1 deletion R/parsnip-package.R
@@ -41,7 +41,7 @@ utils::globalVariables(
"sub_neighbors", ".pred_class", "x", "y", "predictor_indicators",
"compute_intercept", "remove_intercept", "estimate", "term",
"call_info", "component", "component_id", "func", "tunable", "label",
"pkg", ".order", "item", "tunable", "has_ext"
"pkg", ".order", "item", "tunable", "has_ext", "weights"
)
)

6 changes: 4 additions & 2 deletions R/rand_forest_data.R
@@ -122,7 +122,8 @@ set_fit(
mode = "classification",
value = list(
interface = "data.frame",
protect = c("x", "y", "case.weights"),
data = c(x = "x", y = "y", weights = "case.weights"),
protect = c("x", "y", "weights"),
func = c(pkg = "ranger", fun = "ranger"),
defaults =
list(
@@ -151,7 +152,8 @@ set_fit(
mode = "regression",
value = list(
interface = "data.frame",
protect = c("x", "y", "case.weights"),
data = c(x = "x", y = "y", weights = "case.weights"),
protect = c("x", "y", "weights"),
func = c(pkg = "ranger", fun = "ranger"),
defaults =
list(
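Putting the ranger mapping together, an end-to-end sketch (hypothetical; assumes the ranger and hardhat packages are installed):

library(parsnip)

dat <- mtcars
wts <- hardhat::importance_weights(runif(nrow(dat)))

rand_forest(mode = "regression", trees = 100) %>%
  set_engine("ranger") %>%
  fit(mpg ~ ., data = dat, case_weights = wts)
# parsnip's canonical `weights` is handed to ranger's `case.weights`
# argument via the `data` mapping above.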
5 changes: 3 additions & 2 deletions R/svm_linear_data.R
@@ -33,7 +33,7 @@ set_fit(
mode = "regression",
value = list(
interface = "matrix",
protect = c("x", "y", "wi"),
protect = c("x", "y"),
data = c(x = "data", y = "target"),
func = c(pkg = "LiblineaR", fun = "LiblineaR"),
defaults = list(type = 11)
@@ -47,7 +47,8 @@
value = list(
interface = "matrix",
data = c(x = "data", y = "target"),
protect = c("x", "y", "wi"),
protect = c("x", "y"),
data = c(x = "data", y = "target"),
func = c(pkg = "LiblineaR", fun = "LiblineaR"),
defaults = list(type = 1)
)