Skip to content

Commit 2f386d2

Browse files
Merge pull request #1093 from tidymodels/cli-check_args
switch to {cli} in check_args() functions
2 parents 7b7e118 + 21c0e91 commit 2f386d2

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

48 files changed

+578
-201
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.2.1.9000
3+
Version: 1.2.1.9001
44
Authors@R: c(
55
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "[email protected]", role = "aut"),

R/bag_tree.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,7 @@ update.bag_tree <-
8585
# ------------------------------------------------------------------------------
8686

8787
#' @export
88-
check_args.bag_tree <- function(object) {
89-
if (object$engine == "C5.0" && object$mode == "regression")
90-
stop("C5.0 is classification only.", call. = FALSE)
88+
check_args.bag_tree <- function(object, call = rlang::caller_env()) {
9189
invisible(object)
9290
}
9391

R/boost_tree.R

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -164,23 +164,15 @@ translate.boost_tree <- function(x, engine = x$engine, ...) {
164164
# ------------------------------------------------------------------------------
165165

166166
#' @export
167-
check_args.boost_tree <- function(object) {
167+
check_args.boost_tree <- function(object, call = rlang::caller_env()) {
168168

169169
args <- lapply(object$args, rlang::eval_tidy)
170170

171-
if (is.numeric(args$trees) && args$trees < 0) {
172-
rlang::abort("`trees` should be >= 1.")
173-
}
174-
if (is.numeric(args$sample_size) && (args$sample_size < 0 | args$sample_size > 1)) {
175-
rlang::abort("`sample_size` should be within [0,1].")
176-
}
177-
if (is.numeric(args$tree_depth) && args$tree_depth < 0) {
178-
rlang::abort("`tree_depth` should be >= 1.")
179-
}
180-
if (is.numeric(args$min_n) && args$min_n < 0) {
181-
rlang::abort("`min_n` should be >= 1.")
182-
}
183-
171+
check_number_whole(args$trees, min = 0, allow_null = TRUE, call = call, arg = "trees")
172+
check_number_decimal(args$sample_size, min = 0, max = 1, allow_null = TRUE, call = call, arg = "sample_size")
173+
check_number_whole(args$tree_depth, min = 0, allow_null = TRUE, call = call, arg = "tree_depth")
174+
check_number_whole(args$min_n, min = 0, allow_null = TRUE, call = call, arg = "min_n")
175+
184176
invisible(object)
185177
}
186178

R/c5_rules.R

Lines changed: 11 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -111,32 +111,23 @@ update.C5_rules <-
111111
# make work in different places
112112

113113
#' @export
114-
check_args.C5_rules <- function(object) {
114+
check_args.C5_rules <- function(object, call = rlang::caller_env()) {
115115

116116
args <- lapply(object$args, rlang::eval_tidy)
117117

118-
if (is.numeric(args$trees)) {
119-
if (length(args$trees) > 1) {
120-
rlang::abort("Only a single value of `trees` is used.")
121-
}
122-
msg <- "The number of trees should be >= 1 and <= 100. Truncating the value."
123-
if (args$trees > 100) {
124-
object$args$trees <-
125-
rlang::new_quosure(100L, env = rlang::empty_env())
126-
rlang::warn(msg)
127-
}
128-
if (args$trees < 1) {
129-
object$args$trees <-
130-
rlang::new_quosure(1L, env = rlang::empty_env())
131-
rlang::warn(msg)
132-
}
118+
check_number_whole(args$min_n, allow_null = TRUE, call = call, arg = "min_n")
119+
check_number_whole(args$tree, allow_null = TRUE, call = call, arg = "tree")
133120

121+
msg <- "The number of trees should be {.code >= 1} and {.code <= 100}"
122+
if (!(is.null(args$trees)) && args$trees > 100) {
123+
object$args$trees <- rlang::new_quosure(100L, env = rlang::empty_env())
124+
cli::cli_warn(c(msg, "Truncating to 100."))
134125
}
135-
if (is.numeric(args$min_n)) {
136-
if (length(args$min_n) > 1) {
137-
rlang::abort("Only a single `min_n`` value is used.")
138-
}
126+
if (!(is.null(args$trees)) && args$trees < 1) {
127+
object$args$trees <- rlang::new_quosure(1L, env = rlang::empty_env())
128+
cli::cli_warn(c(msg, "Truncating to 1."))
139129
}
130+
140131
invisible(object)
141132
}
142133

R/cubist_rules.R

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -135,44 +135,36 @@ update.cubist_rules <-
135135
# make work in different places
136136

137137
#' @export
138-
check_args.cubist_rules <- function(object) {
138+
check_args.cubist_rules <- function(object, call = rlang::caller_env()) {
139139

140140
args <- lapply(object$args, rlang::eval_tidy)
141141

142-
if (is.numeric(args$committees)) {
143-
if (length(args$committees) > 1) {
144-
rlang::abort("Only a single committee member is used.")
145-
}
146-
msg <- "The number of committees should be >= 1 and <= 100. Truncating the value."
147-
if (args$committees > 100) {
148-
object$args$committees <-
149-
rlang::new_quosure(100L, env = rlang::empty_env())
150-
rlang::warn(msg)
151-
}
152-
if (args$committees < 1) {
153-
object$args$committees <-
154-
rlang::new_quosure(1L, env = rlang::empty_env())
155-
rlang::warn(msg)
156-
}
142+
check_number_whole(args$committees, allow_null = TRUE, call = call, arg = "committees")
157143

158-
}
159-
if (is.numeric(args$neighbors)) {
160-
if (length(args$neighbors) > 1) {
161-
rlang::abort("Only a single neighbors value is used.")
162-
}
163-
msg <- "The number of neighbors should be >= 0 and <= 9. Truncating the value."
164-
if (args$neighbors > 9) {
165-
object$args$neighbors <-
166-
rlang::new_quosure(9L, env = rlang::empty_env())
167-
rlang::warn(msg)
168-
}
169-
if (args$neighbors < 0) {
170-
object$args$neighbors <-
171-
rlang::new_quosure(0L, env = rlang::empty_env())
172-
rlang::warn(msg)
144+
msg <- "The number of committees should be {.code >= 1} and {.code <= 100}."
145+
if (!(is.null(args$committees)) && args$committees > 100) {
146+
object$args$committees <-
147+
rlang::new_quosure(100L, env = rlang::empty_env())
148+
cli::cli_warn(c(msg, "Truncating to 100."))
173149
}
150+
if (!(is.null(args$committees)) && args$committees < 1) {
151+
object$args$committees <-
152+
rlang::new_quosure(1L, env = rlang::empty_env())
153+
cli::cli_warn(c(msg, "Truncating to 1."))
154+
}
155+
156+
check_number_whole(args$neighbors, allow_null = TRUE, call = call, arg = "neighbors")
174157

158+
msg <- "The number of neighbors should be {.code >= 0} and {.code <= 9}."
159+
if (!(is.null(args$neighbors)) && args$neighbors > 9) {
160+
object$args$neighbors <- rlang::new_quosure(9L, env = rlang::empty_env())
161+
cli::cli_warn(c(msg, "Truncating to 9."))
175162
}
163+
if (!(is.null(args$neighbors)) && args$neighbors < 0) {
164+
object$args$neighbors <- rlang::new_quosure(0L, env = rlang::empty_env())
165+
cli::cli_warn(c(msg, "Truncating to 0."))
166+
}
167+
176168
invisible(object)
177169
}
178170

R/decision_tree.R

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,7 @@ translate.decision_tree <- function(x, engine = x$engine, ...) {
128128
# ------------------------------------------------------------------------------
129129

130130
#' @export
131-
check_args.decision_tree <- function(object) {
132-
if (object$engine == "C5.0" && object$mode == "regression")
133-
rlang::abort("C5.0 is classification only.")
131+
check_args.decision_tree <- function(object, call = rlang::caller_env()) {
134132
invisible(object)
135133
}
136134

R/discrim_flexible.R

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,21 +85,14 @@ update.discrim_flexible <-
8585
# ------------------------------------------------------------------------------
8686

8787
#' @export
88-
check_args.discrim_flexible <- function(object) {
88+
check_args.discrim_flexible <- function(object, call = rlang::caller_env()) {
8989

9090
args <- lapply(object$args, rlang::eval_tidy)
9191

92-
if (is.numeric(args$prod_degree) && args$prod_degree < 0)
93-
stop("`prod_degree` should be >= 1", call. = FALSE)
94-
95-
if (is.numeric(args$num_terms) && args$num_terms < 0)
96-
stop("`num_terms` should be >= 1", call. = FALSE)
97-
98-
if (!is.character(args$prune_method) &&
99-
!is.null(args$prune_method) &&
100-
!is.character(args$prune_method))
101-
stop("`prune_method` should be a single string value", call. = FALSE)
102-
92+
check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree")
93+
check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms")
94+
check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method")
95+
10396
invisible(object)
10497
}
10598

R/discrim_linear.R

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,11 @@ update.discrim_linear <-
8080
# ------------------------------------------------------------------------------
8181

8282
#' @export
83-
check_args.discrim_linear <- function(object) {
83+
check_args.discrim_linear <- function(object, call = rlang::caller_env()) {
8484

8585
args <- lapply(object$args, rlang::eval_tidy)
8686

87-
if (all(is.numeric(args$penalty)) && any(args$penalty < 0)) {
88-
stop("The amount of regularization should be >= 0", call. = FALSE)
89-
}
87+
check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty")
9088

9189
invisible(object)
9290
}

R/discrim_regularized.R

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,18 +95,13 @@ update.discrim_regularized <-
9595
# ------------------------------------------------------------------------------
9696

9797
#' @export
98-
check_args.discrim_regularized <- function(object) {
98+
check_args.discrim_regularized <- function(object, call = rlang::caller_env()) {
9999

100100
args <- lapply(object$args, rlang::eval_tidy)
101101

102-
if (is.numeric(args$frac_common_cov) &&
103-
(args$frac_common_cov < 0 | args$frac_common_cov > 1)) {
104-
stop("The common covariance fraction should be between zero and one", call. = FALSE)
105-
}
106-
if (is.numeric(args$frac_identity) &&
107-
(args$frac_identity < 0 | args$frac_identity > 1)) {
108-
stop("The identity matrix fraction should be between zero and one", call. = FALSE)
109-
}
102+
check_number_decimal(args$frac_common_cov, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_common_cov")
103+
check_number_decimal(args$frac_identity, min = 0, max = 1, allow_null = TRUE, call = call, arg = "frac_identity")
104+
110105
invisible(object)
111106
}
112107

R/fit_helpers.R

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# data to formula/data objects and so on.
55

66
form_form <-
7-
function(object, control, env, ...) {
7+
function(object, control, env, ..., call = rlang::caller_env()) {
88

99
if (inherits(env$data, "data.frame")) {
1010
check_outcome(eval_tidy(rlang::f_lhs(env$formula), env$data), object)
@@ -32,7 +32,7 @@ form_form <-
3232
}
3333

3434
# evaluate quoted args once here to check them
35-
object <- check_args(object)
35+
object <- check_args(object, call = call)
3636

3737
# sub in arguments to actual syntax for corresponding engine
3838
object <- translate(object, engine = object$engine)
@@ -60,7 +60,12 @@ form_form <-
6060
res
6161
}
6262

63-
xy_xy <- function(object, env, control, target = "none", ...) {
63+
xy_xy <- function(object,
64+
env,
65+
control,
66+
target = "none",
67+
...,
68+
call = rlang::caller_env()) {
6469

6570
if (inherits(env$x, "tbl_spark") | inherits(env$y, "tbl_spark"))
6671
rlang::abort("spark objects can only be used with the formula interface to `fit()`")
@@ -83,7 +88,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
8388
}
8489

8590
# evaluate quoted args once here to check them
86-
object <- check_args(object)
91+
object <- check_args(object, call = call)
8792

8893
# sub in arguments to actual syntax for corresponding engine
8994
object <- translate(object, engine = object$engine)
@@ -114,7 +119,7 @@ xy_xy <- function(object, env, control, target = "none", ...) {
114119
}
115120

116121
form_xy <- function(object, control, env,
117-
target = "none", ...) {
122+
target = "none", ..., call = rlang::caller_env()) {
118123

119124
encoding_info <-
120125
get_encoding(class(object)[1]) %>%
@@ -138,7 +143,8 @@ form_xy <- function(object, control, env,
138143
object = object,
139144
env = env, #weights!
140145
control = control,
141-
target = target
146+
target = target,
147+
call = call
142148
)
143149
data_obj$y_var <- all.vars(rlang::f_lhs(env$formula))
144150
data_obj$x <- NULL

R/linear_reg.R

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,12 @@ update.linear_reg <-
106106
# ------------------------------------------------------------------------------
107107

108108
#' @export
109-
check_args.linear_reg <- function(object) {
109+
check_args.linear_reg <- function(object, call = rlang::caller_env()) {
110110

111111
args <- lapply(object$args, rlang::eval_tidy)
112112

113-
if (all(is.numeric(args$penalty)) && any(args$penalty < 0))
114-
rlang::abort("The amount of regularization should be >= 0.")
115-
if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1))
116-
rlang::abort("The mixture proportion should be within [0,1].")
117-
if (is.numeric(args$mixture) && length(args$mixture) > 1)
118-
rlang::abort("Only one value of `mixture` is allowed.")
113+
check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture")
114+
check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty")
119115

120116
invisible(object)
121117
}

R/logistic_reg.R

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -135,25 +135,32 @@ update.logistic_reg <-
135135
# ------------------------------------------------------------------------------
136136

137137
#' @export
138-
check_args.logistic_reg <- function(object) {
138+
check_args.logistic_reg <- function(object, call = rlang::caller_env()) {
139139

140140
args <- lapply(object$args, rlang::eval_tidy)
141141

142-
if (all(is.numeric(args$penalty)) && any(args$penalty < 0))
143-
rlang::abort("The amount of regularization should be >= 0.")
144-
if (is.numeric(args$mixture) && (args$mixture < 0 | args$mixture > 1))
145-
rlang::abort("The mixture proportion should be within [0,1].")
146-
if (is.numeric(args$mixture) && length(args$mixture) > 1)
147-
rlang::abort("Only one value of `mixture` is allowed.")
142+
check_number_decimal(args$mixture, min = 0, max = 1, allow_null = TRUE, call = call, arg = "mixture")
143+
check_number_decimal(args$penalty, min = 0, allow_null = TRUE, call = call, arg = "penalty")
148144

149145
if (object$engine == "LiblineaR") {
150-
if(is.numeric(args$mixture) && !args$mixture %in% 0:1)
151-
rlang::abort(c("For the LiblineaR engine, mixture must be 0 or 1.",
152-
"Choose a pure ridge model with `mixture = 0`.",
153-
"Choose a pure lasso model with `mixture = 1`.",
154-
"The Liblinear engine does not support other values."))
155-
if(all(is.numeric(args$penalty)) && !all(args$penalty > 0))
156-
rlang::abort("For the LiblineaR engine, penalty must be > 0.")
146+
if (is.numeric(args$mixture) && !args$mixture %in% 0:1) {
147+
cli::cli_abort(
148+
c("x" = "For the {.pkg LiblineaR} engine, mixture must be 0 or 1, \\
149+
not {args$mixture}.",
150+
"i" = "Choose a pure ridge model with {.code mixture = 0} or \\
151+
a pure lasso model with {.code mixture = 1}.",
152+
"!" = "The {.pkg Liblinear} engine does not support other values."),
153+
call = call
154+
)
155+
}
156+
157+
if ((!is.null(args$penalty)) && args$penalty == 0) {
158+
cli::cli_abort(
159+
"For the {.pkg LiblineaR} engine, {.arg penalty} must be {.code > 0}, \\
160+
not 0.",
161+
call = call
162+
)
163+
}
157164
}
158165

159166
invisible(object)

R/mars.R

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,20 +105,13 @@ translate.mars <- function(x, engine = x$engine, ...) {
105105
# ------------------------------------------------------------------------------
106106

107107
#' @export
108-
check_args.mars <- function(object) {
108+
check_args.mars <- function(object, call = rlang::caller_env()) {
109109

110110
args <- lapply(object$args, rlang::eval_tidy)
111111

112-
if (is.numeric(args$prod_degree) && args$prod_degree < 0)
113-
rlang::abort("`prod_degree` should be >= 1.")
114-
115-
if (is.numeric(args$num_terms) && args$num_terms < 0)
116-
rlang::abort("`num_terms` should be >= 1.")
117-
118-
if (!is_varying(args$prune_method) &&
119-
!is.null(args$prune_method) &&
120-
!is.character(args$prune_method))
121-
rlang::abort("`prune_method` should be a single string value.")
112+
check_number_whole(args$prod_degree, min = 1, allow_null = TRUE, call = call, arg = "prod_degree")
113+
check_number_whole(args$num_terms, min = 1, allow_null = TRUE, call = call, arg = "num_terms")
114+
check_string(args$prune_method, allow_empty = FALSE, allow_null = TRUE, call = call, arg = "prune_method")
122115

123116
invisible(object)
124117
}

0 commit comments

Comments
 (0)