-
Notifications
You must be signed in to change notification settings - Fork 92
case weights #692
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
case weights #692
Changes from 12 commits
b927ffe
3bd5f37
73eae0b
1d8507b
b13af41
b72c7ce
3f5f6a7
ee75da5
0391906
7f4676b
6ebce7b
450dd48
3bcab7a
6e036fe
b7b5765
c7cc287
f2d0159
95cd8cd
2929d53
2100e8f
99d1bc3
fe2b184
a4facca
48c30be
a2b1c1a
118c09d
ca3e6c8
ceb9c0b
8a6f61c
eb81af5
69c7c14
33079a3
c75ed66
bc81160
0698b6d
2aaee25
7c70d26
2f18332
68e5c97
284252b
64634a0
4cb16cf
f2f24a0
205af75
a6e7849
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,8 @@ | |
#' below). A data frame containing all relevant variables (e.g. | ||
#' outcome(s), predictors, case weights, etc). Note: when needed, a | ||
#' \emph{named argument} should be used. | ||
#' @param case_weights An optional vector of numeric case weights with underlying class of | ||
#' "`hardhat_case_weights`". See [hardhat::frequency_weights()] for an example. | ||
topepo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#' @param control A named list with elements `verbosity` and | ||
#' `catch`. See [control_parsnip()]. | ||
#' @param ... Not currently used; values passed here will be | ||
|
@@ -101,6 +103,7 @@ fit.model_spec <- | |
function(object, | ||
formula, | ||
data, | ||
case_weights = NULL, | ||
control = control_parsnip(), | ||
... | ||
) { | ||
|
@@ -110,6 +113,8 @@ fit.model_spec <- | |
if (!identical(class(control), class(control_parsnip()))) { | ||
rlang::abort("The 'control' argument should have class 'control_parsnip'.") | ||
} | ||
check_case_weights(case_weights, object) | ||
|
||
dots <- quos(...) | ||
|
||
if (length(possible_engines(object)) == 0) { | ||
|
@@ -129,15 +134,31 @@ fit.model_spec <- | |
} | ||
} | ||
|
||
if (all(c("x", "y") %in% names(dots))) | ||
if (all(c("x", "y") %in% names(dots))) { | ||
rlang::abort("`fit.model_spec()` is for the formula methods. Use `fit_xy()` instead.") | ||
} | ||
cl <- match.call(expand.dots = TRUE) | ||
# Create an environment with the evaluated argument objects. This will be | ||
# used when a model call is made later. | ||
eval_env <- rlang::env() | ||
|
||
wts <- weights_to_numeric(case_weights) | ||
|
||
# `lm()` and `glm()` and others use the original model function call to | ||
# construct a call for `model.frame()`. That will normally fail because the | ||
# formula has its own environment attached (usually the global environment) | ||
# and it will look there for a vector named 'weights'. We've stashed that | ||
# vector in the environment 'env' so we reset the reference environment in | ||
# the formula to have our data objects so they can be found. | ||
topepo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
fenv <- rlang::env_clone(environment(formula)) | ||
fenv$data <- data | ||
fenv$weights <- wts | ||
environment(formula) <- fenv | ||
topepo marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
eval_env$data <- data | ||
eval_env$formula <- formula | ||
eval_env$weights <- wts | ||
|
||
fit_interface <- | ||
check_interface(eval_env$formula, eval_env$data, cl, object) | ||
|
||
|
@@ -206,6 +227,7 @@ fit_xy.model_spec <- | |
function(object, | ||
x, | ||
y, | ||
case_weights = NULL, | ||
control = control_parsnip(), | ||
... | ||
) { | ||
|
@@ -223,6 +245,8 @@ fit_xy.model_spec <- | |
if (is.null(colnames(x))) { | ||
rlang::abort("'x' should have column names.") | ||
} | ||
check_case_weights(case_weights, object) | ||
|
||
object <- check_mode(object, levels(y)) | ||
dots <- quos(...) | ||
if (is.null(object$engine)) { | ||
|
@@ -245,6 +269,9 @@ fit_xy.model_spec <- | |
eval_env <- rlang::env() | ||
eval_env$x <- x | ||
eval_env$y <- y | ||
eval_env$weights <- weights_to_numeric(case_weights) | ||
|
||
# TODO case weights: pass in eval_env not individual elements | ||
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object) | ||
Comment on lines
+271
to
272
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Idk, I feel like passing in the minimal amount of information required makes it easier to understand at a glance what |
||
|
||
if (object$engine == "spark") | ||
|
@@ -306,18 +333,18 @@ fit_xy.model_spec <- | |
|
||
# ------------------------------------------------------------------------------ | ||
|
||
eval_mod <- function(e, capture = FALSE, catch = FALSE, ...) { | ||
eval_mod <- function(e, capture = FALSE, catch = FALSE, envir = NULL, ...) { | ||
if (capture) { | ||
if (catch) { | ||
junk <- capture.output(res <- try(eval_tidy(e, ...), silent = TRUE)) | ||
junk <- capture.output(res <- try(eval_tidy(e, env = envir, ...), silent = TRUE)) | ||
} else { | ||
junk <- capture.output(res <- eval_tidy(e, ...)) | ||
junk <- capture.output(res <- eval_tidy(e, env = envir, ...)) | ||
} | ||
} else { | ||
if (catch) { | ||
res <- try(eval_tidy(e, ...), silent = TRUE) | ||
res <- try(eval_tidy(e, env = envir, ...), silent = TRUE) | ||
} else { | ||
res <- eval_tidy(e, ...) | ||
res <- eval_tidy(e, env = envir, ...) | ||
Comment on lines
-312
to
+344
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ok, since we are already touching this, I have another comment:
-If we never use the |
||
} | ||
} | ||
res | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -39,7 +39,7 @@ form_form <- | |
fit_call, | ||
capture = control$verbosity == 0, | ||
catch = control$catch, | ||
env = env, | ||
envir = env, | ||
... | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we remove the (We should check all other uses of |
||
), | ||
gcFirst = FALSE | ||
|
@@ -49,7 +49,7 @@ form_form <- | |
fit_call, | ||
capture = control$verbosity == 0, | ||
catch = control$catch, | ||
env = env, | ||
envir = env, | ||
... | ||
) | ||
elapsed <- list(elapsed = NA_real_) | ||
|
@@ -88,7 +88,7 @@ xy_xy <- function(object, env, control, target = "none", ...) { | |
# sub in arguments to actual syntax for corresponding engine | ||
object <- translate(object, engine = object$engine) | ||
|
||
fit_call <- make_xy_call(object, target) | ||
fit_call <- make_xy_call(object, target, env) | ||
|
||
res <- list(lvl = levels(env$y), spec = object) | ||
|
||
|
@@ -98,7 +98,7 @@ xy_xy <- function(object, env, control, target = "none", ...) { | |
fit_call, | ||
capture = control$verbosity == 0, | ||
catch = control$catch, | ||
env = env, | ||
envir = env, | ||
... | ||
), | ||
gcFirst = FALSE | ||
|
@@ -108,7 +108,7 @@ xy_xy <- function(object, env, control, target = "none", ...) { | |
fit_call, | ||
capture = control$verbosity == 0, | ||
catch = control$catch, | ||
env = env, | ||
envir = env, | ||
... | ||
) | ||
elapsed <- list(elapsed = NA_real_) | ||
|
@@ -200,3 +200,39 @@ xy_form <- function(object, env, control, ...) { | |
res | ||
} | ||
|
||
|
||
weights_to_numeric <- function(x) { | ||
if (is.null(x)) { | ||
return(NULL) | ||
} | ||
|
||
to_int <- c("hardhat_frequency_weights") | ||
if (inherits(x, to_int)) { | ||
x <- as.integer(x) | ||
} else { | ||
x <- as.numeric(x) | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.

I would either use:

```r
if (hardhat::is_frequency_weights(x)) {
  x <- as.integer(x)
} else if (hardhat::is_importance_weights(x)) {
  x <- as.double(x)
} else {
  abort("Unknown type of case weights.", .internal = TRUE)
}
```

Or:

```r
if (hardhat::is_frequency_weights(x)) {
  x <- as.integer(x)
} else {
  x <- as.double(x)
}
```

The first one is if we want to be very selective about the types of case weights that parsnip supports. The second one is if we generally just want to try and convert any kind of case weights to double, with a special case for frequency-weights. Mainly I want to avoid hardcoding |
||
x | ||
} | ||
|
||
# Report whether the engine/mode combination of a model specification accepts
# case weights.
#
# The model type is taken from the first class of `spec`. The fit information
# stored under "<model type>_fit" is looked up with `get_from_env()` and
# narrowed to the row matching `spec$engine` and `spec$mode`.
#
# Arguments:
#   spec: a model specification; `class(spec)[1]`, `spec$engine`, and
#     `spec$mode` are read.
#
# Returns a single logical: `TRUE` when "weights" is listed among the protected
# arguments of the matching fit entry, `FALSE` otherwise.
#
# Aborts when the lookup does not yield exactly one matching row.
case_weights_allowed <- function(spec) {
  mod_type <- class(spec)[1]
  mod_eng <- spec$engine
  mod_mode <- spec$mode

  model_info <-
    get_from_env(paste0(mod_type, "_fit")) %>%
    dplyr::filter(engine == mod_eng & mode == mod_mode)

  # Zero rows or several rows both mean the engine/mode pair cannot be
  # resolved to a single fit entry.
  if (nrow(model_info) != 1) {
    rlang::abort(
      glue::glue(
        "Error in getting model information for model {mod_type} with engine {mod_eng} and mode {mod_mode}."
      )
    )
  }

  # If weights are used, they are protected data arguments with the canonical
  # name 'weights' (although this may not be the model function's argument name).
  data_args <- model_info$value[[1]]$protect

  # `%in%` never returns NA, so the result is always a plain TRUE/FALSE even
  # when the protected arguments are empty or contain missing values.
  "weights" %in% data_args
}
|
Uh oh!
There was an error while loading. Please reload this page.