Skip to content

Commit cb0613f

Browse files
authored
Merge branch 'master' into multinom_reg_pred
2 parents 9001b00 + ae66258 commit cb0613f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+627
-39
lines changed

DESCRIPTION

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
Package: parsnip
2-
Version: 0.0.3.9001
2+
Version: 0.0.3.9002
33
Title: A Common API to Modeling and Analysis Functions
44
Description: A common interface is provided to allow users to specify a model without having to remember the different argument names across different functions or computational engines (e.g. 'R', 'Spark', 'Stan', etc).
55
Authors@R: c(
66
person(given = "Max", family = "Kuhn", email = "[email protected]", role = c("aut", "cre")),
77
person(given = "Davis", family = "Vaughan", email = "[email protected]", role = c("aut")),
88
person("RStudio", role = "cph"))
99
Maintainer: Max Kuhn <[email protected]>
10-
URL: https://tidymodels.github.io/parsnip
10+
URL: https://tidymodels.github.io/parsnip, https://github.com/tidymodels/parsnip
1111
BugReports: https://github.com/tidymodels/parsnip/issues
1212
License: GPL-2
1313
Encoding: UTF-8
@@ -28,6 +28,7 @@ Imports:
2828
stats,
2929
tidyr,
3030
globals,
31+
prettyunits,
3132
vctrs (>= 0.2.0)
3233
Roxygen: list(markdown = TRUE)
3334
RoxygenNote: 6.1.99.9001

NEWS.md

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
# parsnip 0.0.3.9001
22

3-
* [A bug](https://github.com/tidymodels/parsnip/issues/222) was fixed standardizing
4-
the output column types of `multi_predict` and `predict` for `multinom_reg`.
3+
## New Features
4+
5+
* The time elapsed during model fitting is stored in the `$elapsed` slot of the
6+
parsnip model object, and is printed when the model object is printed.
57

68
* Some default parameter ranges were updated for SVM, KNN, and MARS models.
79

10+
## Fixes
11+
* [A bug](https://github.com/tidymodels/parsnip/issues/222) was fixed standardizing
12+
the output column types of `multi_predict` and `predict` for `multinom_reg`.
13+
814
* [A bug](https://github.com/tidymodels/parsnip/issues/208) was fixed related to using data descriptors and `fit_xy()`.
915

1016
* A bug was fixed related to the column names generated by `multi_predict()`. The top-level tibble will always have a column named `.pred` and this list column contains tibbles across sub-models. The column names for these sub-model tibbles will have names consistent with `predict()` (which was previously incorrect). See [43c15db](https://github.com/tidymodels/parsnip/commit/43c15db377ea9ef27483ff209f6bd0e98cb830d2).
1117

18+
* The model `udpate()` methods gained a `parameters` argument for cases when the parameters are contained in a tibble or list.
19+
20+
# [A bug](https://github.com/tidymodels/parsnip/issues/174) was fixed standardizing the column names of `nnet` class probability predictions.
21+
22+
1223
# parsnip 0.0.3.1
1324

1425
Test case update due to CRAN running extra tests [(#202)](https://github.com/tidymodels/parsnip/issues/202)

R/boost_tree.R

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ print.boost_tree <- function(x, ...) {
148148

149149
#' @export
150150
#' @param object A boosted tree model specification.
151+
#' @param parameters A 1-row tibble or named list with _main_
152+
#' parameters to update. If the individual arguments are used,
153+
#' these will supersede the values in `parameters`. Also, using
154+
#' engine arguments in this object will result in an error.
151155
#' @param ... Not used for `update()`.
152156
#' @param fresh A logical for whether the arguments should be
153157
#' modified in-place of or replaced wholesale.
@@ -157,17 +161,31 @@ print.boost_tree <- function(x, ...) {
157161
#' model
158162
#' update(model, mtry = 1)
159163
#' update(model, mtry = 1, fresh = TRUE)
164+
#'
165+
#' param_values <- tibble::tibble(mtry = 10, tree_depth = 5)
166+
#'
167+
#' model %>% update(param_values)
168+
#' model %>% update(param_values, mtry = 3)
169+
#'
170+
#' param_values$verbose <- 0
171+
#' # Fails due to engine argument
172+
#' # model %>% update(param_values)
160173
#' @method update boost_tree
161174
#' @rdname boost_tree
162175
#' @export
163176
update.boost_tree <-
164177
function(object,
178+
parameters = NULL,
165179
mtry = NULL, trees = NULL, min_n = NULL,
166180
tree_depth = NULL, learn_rate = NULL,
167181
loss_reduction = NULL, sample_size = NULL,
168182
fresh = FALSE, ...) {
169183
update_dot_check(...)
170184

185+
if (!is.null(parameters)) {
186+
parameters <- check_final_param(parameters)
187+
}
188+
171189
args <- list(
172190
mtry = enquo(mtry),
173191
trees = enquo(trees),
@@ -178,6 +196,8 @@ update.boost_tree <-
178196
sample_size = enquo(sample_size)
179197
)
180198

199+
args <- update_main_parameters(args, parameters)
200+
181201
# TODO make these blocks into a function and document well
182202
if (fresh) {
183203
object$args <- args

R/decision_tree.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,22 @@ print.decision_tree <- function(x, ...) {
137137
#' @export
138138
update.decision_tree <-
139139
function(object,
140+
parameters = NULL,
140141
cost_complexity = NULL, tree_depth = NULL, min_n = NULL,
141142
fresh = FALSE, ...) {
142143
update_dot_check(...)
144+
145+
if (!is.null(parameters)) {
146+
parameters <- check_final_param(parameters)
147+
}
143148
args <- list(
144149
cost_complexity = enquo(cost_complexity),
145150
tree_depth = enquo(tree_depth),
146151
min_n = enquo(min_n)
147152
)
148153

154+
args <- update_main_parameters(args, parameters)
155+
149156
if (fresh) {
150157
object$args <- args
151158
} else {

R/fit.R

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ check_xy_interface <- function(x, y, cl, model) {
367367
#' @export
368368
print.model_fit <- function(x, ...) {
369369
cat("parsnip model object\n\n")
370+
cat("Fit in: ", prettyunits::pretty_sec(x$elapsed[["elapsed"]]))
370371

371372
if (inherits(x$fit, "try-error")) {
372373
cat("Model fit failed with error:\n", x$fit, "\n")

R/fit_helpers.R

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,17 @@ form_form <-
5252
spec = object
5353
)
5454

55-
res$fit <- eval_mod(
56-
fit_call,
57-
capture = control$verbosity == 0,
58-
catch = control$catch,
59-
env = env,
60-
...
55+
elapsed <- system.time(
56+
res$fit <- eval_mod(
57+
fit_call,
58+
capture = control$verbosity == 0,
59+
catch = control$catch,
60+
env = env,
61+
...
62+
)
6163
)
6264
res$preproc <- list(y_var = all.vars(env$formula[[2]]))
65+
res$elapsed <- elapsed
6366
res
6467
}
6568

@@ -107,19 +110,24 @@ xy_xy <- function(object, env, control, target = "none", ...) {
107110

108111
res <- list(lvl = levels(env$y), spec = object)
109112

110-
res$fit <- eval_mod(
111-
fit_call,
112-
capture = control$verbosity == 0,
113-
catch = control$catch,
114-
env = env,
115-
...
113+
114+
elapsed <- system.time(
115+
res$fit <- eval_mod(
116+
fit_call,
117+
capture = control$verbosity == 0,
118+
catch = control$catch,
119+
env = env,
120+
...
121+
)
116122
)
123+
117124
if (is.vector(env$y)) {
118125
y_name <- character(0)
119126
} else {
120127
y_name <- colnames(env$y)
121128
}
122129
res$preproc <- list(y_var = y_name)
130+
res$elapsed <- elapsed
123131
res
124132
}
125133

R/linear_reg.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,14 +169,21 @@ translate.linear_reg <- function(x, engine = x$engine, ...) {
169169
#' @export
170170
update.linear_reg <-
171171
function(object,
172+
parameters = NULL,
172173
penalty = NULL, mixture = NULL,
173174
fresh = FALSE, ...) {
174175
update_dot_check(...)
176+
177+
if (!is.null(parameters)) {
178+
parameters <- check_final_param(parameters)
179+
}
175180
args <- list(
176181
penalty = enquo(penalty),
177182
mixture = enquo(mixture)
178183
)
179184

185+
args <- update_main_parameters(args, parameters)
186+
180187
if (fresh) {
181188
object$args <- args
182189
} else {

R/logistic_reg.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,21 @@ translate.logistic_reg <- translate.linear_reg
154154
#' @export
155155
update.logistic_reg <-
156156
function(object,
157+
parameters = NULL,
157158
penalty = NULL, mixture = NULL,
158159
fresh = FALSE, ...) {
159160
update_dot_check(...)
161+
162+
if (!is.null(parameters)) {
163+
parameters <- check_final_param(parameters)
164+
}
160165
args <- list(
161166
penalty = enquo(penalty),
162167
mixture = enquo(mixture)
163168
)
164169

170+
args <- update_main_parameters(args, parameters)
171+
165172
if (fresh) {
166173
object$args <- args
167174
} else {

R/mars.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,16 +106,23 @@ print.mars <- function(x, ...) {
106106
#' @export
107107
update.mars <-
108108
function(object,
109+
parameters = NULL,
109110
num_terms = NULL, prod_degree = NULL, prune_method = NULL,
110111
fresh = FALSE, ...) {
111112
update_dot_check(...)
112113

114+
if (!is.null(parameters)) {
115+
parameters <- check_final_param(parameters)
116+
}
117+
113118
args <- list(
114119
num_terms = enquo(num_terms),
115120
prod_degree = enquo(prod_degree),
116121
prune_method = enquo(prune_method)
117122
)
118123

124+
args <- update_main_parameters(args, parameters)
125+
119126
if (fresh) {
120127
object$args <- args
121128
} else {

R/misc.R

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,51 @@ terms_y <- function(x) {
232232
y_expr <- att$predvars[[resp_ind + 1]]
233233
all.vars(y_expr)
234234
}
235+
236+
237+
# ------------------------------------------------------------------------------
238+
239+
check_final_param <- function(x) {
240+
if (is.null(x)) {
241+
return(invisible(x))
242+
}
243+
if (!is.list(x) & !tibble::is_tibble(x)) {
244+
rlang::abort("The parameter object should be a list or tibble")
245+
}
246+
if (tibble::is_tibble(x) && nrow(x) > 1) {
247+
rlang::abort("The parameter tibble should have a single row.")
248+
}
249+
if (tibble::is_tibble(x)) {
250+
x <- as.list(x)
251+
}
252+
if (length(names) == 0 || any(names(x) == "")) {
253+
rlang::abort("All values in `parameters` should have a name.")
254+
}
255+
256+
invisible(x)
257+
}
258+
259+
update_main_parameters <- function(args, param) {
260+
261+
if (length(param) == 0) {
262+
return(args)
263+
}
264+
if (length(args) == 0) {
265+
return(param)
266+
}
267+
268+
# In case an engine argument is included:
269+
has_extra_args <- !(names(param) %in% names(args))
270+
extra_args <- names(param)[has_extra_args]
271+
if (any(has_extra_args)) {
272+
rlang::abort(
273+
paste("At least one argument is not a main argument:",
274+
paste0("`", extra_args, "`", collapse = ", "))
275+
)
276+
}
277+
param <- param[!has_extra_args]
278+
279+
280+
281+
args <- utils::modifyList(args, param)
282+
}

R/mlp.R

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,16 @@ print.mlp <- function(x, ...) {
139139
#' @export
140140
update.mlp <-
141141
function(object,
142+
parameters = NULL,
142143
hidden_units = NULL, penalty = NULL, dropout = NULL,
143144
epochs = NULL, activation = NULL,
144145
fresh = FALSE, ...) {
145146
update_dot_check(...)
147+
148+
if (!is.null(parameters)) {
149+
parameters <- check_final_param(parameters)
150+
}
151+
146152
args <- list(
147153
hidden_units = enquo(hidden_units),
148154
penalty = enquo(penalty),
@@ -151,6 +157,8 @@ update.mlp <-
151157
activation = enquo(activation)
152158
)
153159

160+
args <- update_main_parameters(args, parameters)
161+
154162
# TODO make these blocks into a function and document well
155163
if (fresh) {
156164
object$args <- args
@@ -381,7 +389,7 @@ nnet_softmax <- function(results, object) {
381389

382390
results <- apply(results, 1, function(x) exp(x)/sum(exp(x)))
383391
results <- t(results)
384-
names(results) <- paste0(".pred_", object$lvl)
392+
colnames(results) <- object$lvl
385393
results <- as_tibble(results)
386394
results
387395
}

R/multinom_reg.R

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,21 @@ translate.multinom_reg <- translate.linear_reg
137137
#' @export
138138
update.multinom_reg <-
139139
function(object,
140+
parameters = NULL,
140141
penalty = NULL, mixture = NULL,
141142
fresh = FALSE, ...) {
142143
update_dot_check(...)
144+
145+
if (!is.null(parameters)) {
146+
parameters <- check_final_param(parameters)
147+
}
143148
args <- list(
144149
penalty = enquo(penalty),
145150
mixture = enquo(mixture)
146151
)
147152

153+
args <- update_main_parameters(args, parameters)
154+
148155
if (fresh) {
149156
object$args <- args
150157
} else {

R/nearest_neighbor.R

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,17 +106,25 @@ print.nearest_neighbor <- function(x, ...) {
106106
#' @export
107107
#' @inheritParams update.boost_tree
108108
update.nearest_neighbor <- function(object,
109+
parameters = NULL,
109110
neighbors = NULL,
110111
weight_func = NULL,
111112
dist_power = NULL,
112113
fresh = FALSE, ...) {
113114
update_dot_check(...)
115+
116+
if (!is.null(parameters)) {
117+
parameters <- check_final_param(parameters)
118+
}
119+
114120
args <- list(
115121
neighbors = enquo(neighbors),
116122
weight_func = enquo(weight_func),
117123
dist_power = enquo(dist_power)
118124
)
119125

126+
args <- update_main_parameters(args, parameters)
127+
120128
if (fresh) {
121129
object$args <- args
122130
} else {

0 commit comments

Comments
 (0)