Skip to content

Commit 0698b6d

Browse files
committed
S3 method to convert hardhat format to numeric
1 parent bc81160 commit 0698b6d

File tree

4 files changed

+79
-0
lines changed

4 files changed

+79
-0
lines changed

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Generated by roxygen2: do not edit by hand
22

33
S3method(augment,model_fit)
4+
S3method(convert_case_weights,default)
5+
S3method(convert_case_weights,hardhat_frequency_weights)
6+
S3method(convert_case_weights,hardhat_importance_weights)
47
S3method(extract_fit_engine,model_fit)
58
S3method(extract_parameter_dials,model_spec)
69
S3method(extract_parameter_set_dials,model_spec)
@@ -178,6 +181,7 @@ export(check_model_doesnt_exist)
178181
export(check_model_exists)
179182
export(contr_one_hot)
180183
export(control_parsnip)
184+
export(convert_case_weights)
181185
export(convert_stan_interval)
182186
export(ctree_train)
183187
export(cubist_rules)

R/case_weights.R

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#' @seealso [frequency_weights()], [importance_weights()], [fit()], [fit_xy]
2929
NULL
3030

31+
# ------------------------------------------------------------------------------
3132

3233
weights_to_numeric <- function(x, spec) {
3334
if (is.null(x)) {
@@ -46,6 +47,51 @@ weights_to_numeric <- function(x, spec) {
4647
x
4748
}
4849

50+
#' Convert case weights to final from
51+
#'
52+
#' tidymodels requires case weights to have special classes. To use them in
53+
#' model fitting or performance evaluation, they need to be converted to
54+
#' numeric.
55+
#' @param x A vector with class `"hardhat_case_weights"`.
56+
#' @param where The location where they will be used: `"parsnip"` or
57+
#' `"yardstick"`.
58+
#' @return A numeric vector or NULL.
59+
#' @export
60+
convert_case_weights <- function(x, where = "parsnip", ...) {
61+
UseMethod("convert_case_weights")
62+
}
63+
64+
#' @export
65+
convert_case_weights.default <- function(x, where = "parsnip", ...) {
66+
where <- rlang::arg_match0(where, c("parsnip", "yardstick"))
67+
if (!inherits(x, "hardhat_case_weights")) {
68+
rlang::abort("'case_weights' should be vector of class 'hardhat_case_weights'")
69+
}
70+
invisible(NULL)
71+
}
72+
73+
#' @export
74+
#' @rdname convert_case_weights
75+
convert_case_weights.hardhat_importance_weights <-
76+
function(x, where = "parsnip", ...) {
77+
if (where == "parsnip") {
78+
x <- as.double(x)
79+
} else {
80+
x <- NULL
81+
}
82+
x
83+
}
84+
85+
#' @export
86+
#' @rdname convert_case_weights
87+
convert_case_weights.hardhat_frequency_weights <-
88+
function(x, where = "parsnip", ...) {
89+
as.integer(x)
90+
}
91+
92+
93+
# ------------------------------------------------------------------------------
94+
4995
case_weights_allowed <- function(spec) {
5096
mod_type <- class(spec)[1]
5197
mod_eng <- spec$engine

man/convert_case_weights.Rd

Lines changed: 28 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

parsnip.Rproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,6 @@ StripTrailingWhitespace: Yes
1717

1818
BuildType: Package
1919
PackageUseDevtools: Yes
20+
PackageCleanBeforeInstall: Yes
2021
PackageInstallArgs: --no-multiarch --with-keep.source
2122
PackageRoxygenize: rd,collate,namespace

0 commit comments

Comments
 (0)