Skip to content

Commit ae66258

Browse files
authored
Merge pull request #229 from tidymodels/Roxygen-dev
Misc updates
2 parents 3a6d11f + ee97f9e commit ae66258

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

52 files changed

+567
-26
lines changed

DESCRIPTION

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
Package: parsnip
2-
Version: 0.0.3.9001
2+
Version: 0.0.3.9002
33
Title: A Common API to Modeling and Analysis Functions
44
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc).
55
Authors@R: c(
66
person(given = "Max", family = "Kuhn", email = "[email protected]", role = c("aut", "cre")),
77
person(given = "Davis", family = "Vaughan", email = "[email protected]", role = c("aut")),
88
person("RStudio", role = "cph"))
99
Maintainer: Max Kuhn <[email protected]>
10-
URL: https://tidymodels.github.io/parsnip
10+
URL: https://tidymodels.github.io/parsnip, https://github.com/tidymodels/parsnip
1111
BugReports: https://github.com/tidymodels/parsnip/issues
1212
License: GPL-2
1313
Encoding: UTF-8

NEWS.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@ parsnip model object, and is printed when the model object is printed.
1313

1414
* A bug was fixed related to the column names generated by `multi_predict()`. The top-level tibble will always have a column named `.pred` and this list column contains tibbles across sub-models. The column names for these sub-model tibbles will have names consistent with `predict()` (which was previously incorrect). See [43c15db](https://github.com/tidymodels/parsnip/commit/43c15db377ea9ef27483ff209f6bd0e98cb830d2).
1515

16-
# [A bug](https://github.com/tidymodels/parsnip/issues/174) was fixed
17-
standardizing the column names of `nnet` class probability predictions.
16+
* The model `udpate()` methods gained a `parameters` argument for cases when the parameters are contained in a tibble or list.
17+
18+
# [A bug](https://github.com/tidymodels/parsnip/issues/174) was fixed standardizing the column names of `nnet` class probability predictions.
19+
1820

1921
# parsnip 0.0.3.1
2022

R/boost_tree.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ print.boost_tree <- function(x, ...) {
148148

149149
#' @export
150150
#' @param object A boosted tree model specification.
151+
#' @param parameters A 1-row tibble or named list with _main_
152+
#' parameters to update. If the individual arguments are used,
153+
#' these will supersede the values in `parameters`. Also, using
154+
#' engine arguments in this object will result in an error.
151155
#' @param ... Not used for `update()`.
152156
#' @param fresh A logical for whether the arguments should be
153157
#' modified in-place of or replaced wholesale.
@@ -157,17 +161,31 @@ print.boost_tree <- function(x, ...) {
157161
#' model
158162
#' update(model, mtry = 1)
159163
#' update(model, mtry = 1, fresh = TRUE)
164+
#'
165+
#' param_values <- tibble::tibble(mtry = 10, tree_depth = 5)
166+
#'
167+
#' model %>% update(param_values)
168+
#' model %>% update(param_values, mtry = 3)
169+
#'
170+
#' param_values$verbose <- 0
171+
#' # Fails due to engine argument
172+
#' # model %>% update(param_values)
160173
#' @method update boost_tree
161174
#' @rdname boost_tree
162175
#' @export
163176
update.boost_tree <-
164177
function(object,
178+
parameters = NULL,
165179
mtry = NULL, trees = NULL, min_n = NULL,
166180
tree_depth = NULL, learn_rate = NULL,
167181
loss_reduction = NULL, sample_size = NULL,
168182
fresh = FALSE, ...) {
169183
update_dot_check(...)
170184

185+
if (!is.null(parameters)) {
186+
parameters <- check_final_param(parameters)
187+
}
188+
171189
args <- list(
172190
mtry = enquo(mtry),
173191
trees = enquo(trees),
@@ -178,6 +196,8 @@ update.boost_tree <-
178196
sample_size = enquo(sample_size)
179197
)
180198

199+
args <- update_main_parameters(args, parameters)
200+
181201
# TODO make these blocks into a function and document well
182202
if (fresh) {
183203
object$args <- args

R/decision_tree.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,22 @@ print.decision_tree <- function(x, ...) {
137137
#' @export
138138
update.decision_tree <-
139139
function(object,
140+
parameters = NULL,
140141
cost_complexity = NULL, tree_depth = NULL, min_n = NULL,
141142
fresh = FALSE, ...) {
142143
update_dot_check(...)
144+
145+
if (!is.null(parameters)) {
146+
parameters <- check_final_param(parameters)
147+
}
143148
args <- list(
144149
cost_complexity = enquo(cost_complexity),
145150
tree_depth = enquo(tree_depth),
146151
min_n = enquo(min_n)
147152
)
148153

154+
args <- update_main_parameters(args, parameters)
155+
149156
if (fresh) {
150157
object$args <- args
151158
} else {

R/linear_reg.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,21 @@ translate.linear_reg <- function(x, engine = x$engine, ...) {
169169
#' @export
170170
update.linear_reg <-
171171
function(object,
172+
parameters = NULL,
172173
penalty = NULL, mixture = NULL,
173174
fresh = FALSE, ...) {
174175
update_dot_check(...)
176+
177+
if (!is.null(parameters)) {
178+
parameters <- check_final_param(parameters)
179+
}
175180
args <- list(
176181
penalty = enquo(penalty),
177182
mixture = enquo(mixture)
178183
)
179184

185+
args <- update_main_parameters(args, parameters)
186+
180187
if (fresh) {
181188
object$args <- args
182189
} else {

R/logistic_reg.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,21 @@ translate.logistic_reg <- translate.linear_reg
154154
#' @export
155155
update.logistic_reg <-
156156
function(object,
157+
parameters = NULL,
157158
penalty = NULL, mixture = NULL,
158159
fresh = FALSE, ...) {
159160
update_dot_check(...)
161+
162+
if (!is.null(parameters)) {
163+
parameters <- check_final_param(parameters)
164+
}
160165
args <- list(
161166
penalty = enquo(penalty),
162167
mixture = enquo(mixture)
163168
)
164169

170+
args <- update_main_parameters(args, parameters)
171+
165172
if (fresh) {
166173
object$args <- args
167174
} else {

R/mars.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,23 @@ print.mars <- function(x, ...) {
106106
#' @export
107107
update.mars <-
108108
function(object,
109+
parameters = NULL,
109110
num_terms = NULL, prod_degree = NULL, prune_method = NULL,
110111
fresh = FALSE, ...) {
111112
update_dot_check(...)
112113

114+
if (!is.null(parameters)) {
115+
parameters <- check_final_param(parameters)
116+
}
117+
113118
args <- list(
114119
num_terms = enquo(num_terms),
115120
prod_degree = enquo(prod_degree),
116121
prune_method = enquo(prune_method)
117122
)
118123

124+
args <- update_main_parameters(args, parameters)
125+
119126
if (fresh) {
120127
object$args <- args
121128
} else {

R/misc.R

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,51 @@ terms_y <- function(x) {
232232
y_expr <- att$predvars[[resp_ind + 1]]
233233
all.vars(y_expr)
234234
}
235+
236+
237+
# ------------------------------------------------------------------------------
238+
239+
check_final_param <- function(x) {
240+
if (is.null(x)) {
241+
return(invisible(x))
242+
}
243+
if (!is.list(x) & !tibble::is_tibble(x)) {
244+
rlang::abort("The parameter object should be a list or tibble")
245+
}
246+
if (tibble::is_tibble(x) && nrow(x) > 1) {
247+
rlang::abort("The parameter tibble should have a single row.")
248+
}
249+
if (tibble::is_tibble(x)) {
250+
x <- as.list(x)
251+
}
252+
if (length(names) == 0 || any(names(x) == "")) {
253+
rlang::abort("All values in `parameters` should have a name.")
254+
}
255+
256+
invisible(x)
257+
}
258+
259+
update_main_parameters <- function(args, param) {
260+
261+
if (length(param) == 0) {
262+
return(args)
263+
}
264+
if (length(args) == 0) {
265+
return(param)
266+
}
267+
268+
# In case an engine argument is included:
269+
has_extra_args <- !(names(param) %in% names(args))
270+
extra_args <- names(param)[has_extra_args]
271+
if (any(has_extra_args)) {
272+
rlang::abort(
273+
paste("At least one argument is not a main argument:",
274+
paste0("`", extra_args, "`", collapse = ", "))
275+
)
276+
}
277+
param <- param[!has_extra_args]
278+
279+
280+
281+
args <- utils::modifyList(args, param)
282+
}

R/mlp.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,16 @@ print.mlp <- function(x, ...) {
139139
#' @export
140140
update.mlp <-
141141
function(object,
142+
parameters = NULL,
142143
hidden_units = NULL, penalty = NULL, dropout = NULL,
143144
epochs = NULL, activation = NULL,
144145
fresh = FALSE, ...) {
145146
update_dot_check(...)
147+
148+
if (!is.null(parameters)) {
149+
parameters <- check_final_param(parameters)
150+
}
151+
146152
args <- list(
147153
hidden_units = enquo(hidden_units),
148154
penalty = enquo(penalty),
@@ -151,6 +157,8 @@ update.mlp <-
151157
activation = enquo(activation)
152158
)
153159

160+
args <- update_main_parameters(args, parameters)
161+
154162
# TODO make these blocks into a function and document well
155163
if (fresh) {
156164
object$args <- args

R/multinom_reg.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,21 @@ translate.multinom_reg <- translate.linear_reg
137137
#' @export
138138
update.multinom_reg <-
139139
function(object,
140+
parameters = NULL,
140141
penalty = NULL, mixture = NULL,
141142
fresh = FALSE, ...) {
142143
update_dot_check(...)
144+
145+
if (!is.null(parameters)) {
146+
parameters <- check_final_param(parameters)
147+
}
143148
args <- list(
144149
penalty = enquo(penalty),
145150
mixture = enquo(mixture)
146151
)
147152

153+
args <- update_main_parameters(args, parameters)
154+
148155
if (fresh) {
149156
object$args <- args
150157
} else {

R/nearest_neighbor.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,17 +106,25 @@ print.nearest_neighbor <- function(x, ...) {
106106
#' @export
107107
#' @inheritParams update.boost_tree
108108
update.nearest_neighbor <- function(object,
109+
parameters = NULL,
109110
neighbors = NULL,
110111
weight_func = NULL,
111112
dist_power = NULL,
112113
fresh = FALSE, ...) {
113114
update_dot_check(...)
115+
116+
if (!is.null(parameters)) {
117+
parameters <- check_final_param(parameters)
118+
}
119+
114120
args <- list(
115121
neighbors = enquo(neighbors),
116122
weight_func = enquo(weight_func),
117123
dist_power = enquo(dist_power)
118124
)
119125

126+
args <- update_main_parameters(args, parameters)
127+
120128
if (fresh) {
121129
object$args <- args
122130
} else {

R/rand_forest.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,15 +140,22 @@ print.rand_forest <- function(x, ...) {
140140
#' @export
141141
update.rand_forest <-
142142
function(object,
143+
parameters = NULL,
143144
mtry = NULL, trees = NULL, min_n = NULL,
144145
fresh = FALSE, ...) {
145146
update_dot_check(...)
147+
148+
if (!is.null(parameters)) {
149+
parameters <- check_final_param(parameters)
150+
}
146151
args <- list(
147152
mtry = enquo(mtry),
148153
trees = enquo(trees),
149154
min_n = enquo(min_n)
150155
)
151156

157+
args <- update_main_parameters(args, parameters)
158+
152159
# TODO make these blocks into a function and document well
153160
if (fresh) {
154161
object$args <- args

R/surv_reg.R

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,19 @@ print.surv_reg <- function(x, ...) {
111111
#' @method update surv_reg
112112
#' @rdname surv_reg
113113
#' @export
114-
update.surv_reg <- function(object, dist = NULL, fresh = FALSE, ...) {
114+
update.surv_reg <- function(object, parameters = NULL, dist = NULL, fresh = FALSE, ...) {
115115
update_dot_check(...)
116+
117+
if (!is.null(parameters)) {
118+
parameters <- check_final_param(parameters)
119+
}
120+
116121
args <- list(
117122
dist = enquo(dist)
118123
)
119124

125+
args <- update_main_parameters(args, parameters)
126+
120127
if (fresh) {
121128
object$args <- args
122129
} else {

R/svm_poly.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,18 +106,25 @@ print.svm_poly <- function(x, ...) {
106106
#' @export
107107
update.svm_poly <-
108108
function(object,
109+
parameters = NULL,
109110
cost = NULL, degree = NULL, scale_factor = NULL, margin = NULL,
110111
fresh = FALSE,
111112
...) {
112113
update_dot_check(...)
113114

115+
if (!is.null(parameters)) {
116+
parameters <- check_final_param(parameters)
117+
}
118+
114119
args <- list(
115120
cost = enquo(cost),
116121
degree = enquo(degree),
117122
scale_factor = enquo(scale_factor),
118123
margin = enquo(margin)
119124
)
120125

126+
args <- update_main_parameters(args, parameters)
127+
121128
if (fresh) {
122129
object$args <- args
123130
} else {

R/svm_rbf.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,17 +104,24 @@ print.svm_rbf <- function(x, ...) {
104104
#' @export
105105
update.svm_rbf <-
106106
function(object,
107+
parameters = NULL,
107108
cost = NULL, rbf_sigma = NULL, margin = NULL,
108109
fresh = FALSE,
109110
...) {
110111
update_dot_check(...)
111112

113+
if (!is.null(parameters)) {
114+
parameters <- check_final_param(parameters)
115+
}
116+
112117
args <- list(
113118
cost = enquo(cost),
114119
rbf_sigma = enquo(rbf_sigma),
115120
margin = enquo(margin)
116121
)
117122

123+
args <- update_main_parameters(args, parameters)
124+
118125
if (fresh) {
119126
object$args <- args
120127
} else {

0 commit comments

Comments
 (0)