Skip to content

Commit 33079a3

Browse files
authored
Merge branch 'main' into feature/case-weights
2 parents 69c7c14 + 5e79150 commit 33079a3

20 files changed

+335
-123
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 0.2.1.9000
3+
Version: 0.2.1.9001
44
Authors@R: c(
55
person("Max", "Kuhn", , "[email protected]", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "[email protected]", role = "aut"),

NAMESPACE

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,13 +171,15 @@ export(bag_tree)
171171
export(bart)
172172
export(bartMachine_interval_calc)
173173
export(boost_tree)
174+
export(cforest_train)
174175
export(check_empty_ellipse)
175176
export(check_final_param)
176177
export(check_model_doesnt_exist)
177178
export(check_model_exists)
178179
export(contr_one_hot)
179180
export(control_parsnip)
180181
export(convert_stan_interval)
182+
export(ctree_train)
181183
export(cubist_rules)
182184
export(dbart_predict_calc)
183185
export(decision_tree)
@@ -220,6 +222,7 @@ export(make_classes)
220222
export(make_engine_list)
221223
export(make_seealso_list)
222224
export(mars)
225+
export(max_mtry_formula)
223226
export(maybe_data_frame)
224227
export(maybe_matrix)
225228
export(min_cols)
@@ -297,6 +300,7 @@ export(update_main_parameters)
297300
export(update_model_info_file)
298301
export(varying)
299302
export(varying_args)
303+
export(xgb_predict)
300304
export(xgb_train)
301305
importFrom(dplyr,arrange)
302306
importFrom(dplyr,bind_cols)

NEWS.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
# parsnip (development version)
22

3+
34
* Enable the use of case weights for models that support them.
45

56
* Added a `glm_grouped()` function to convert long data to the grouped format required by `glm()` for logistic regression.
67

78
* `show_model_info()` now indicates which models can utilize case weights.
89

10+
* `xgb_train()` now allows for case weights
11+
12+
* Added `ctree_train()` and `cforest_train()` wrappers for the functions in the partykit package. Engines for these will be added to other parsnip extension packages.
13+
14+
* Exported `xgb_predict()` which wraps xgboost's `predict()` method for use with parsnip extension packages (#688).
15+
16+
917
# parsnip 0.2.1
1018

1119
* Fixed a major bug in spark models induced in the previous version (#671).
1220

1321
* Updated the parsnip add-in with new models and engines.
1422

1523
* Updated parameter ranges for some `tunable()` methods and added a missing engine argument for brulee models.
24+
1625
* Added information about how to install the mixOmics package for PLS models (#680)
1726

1827

R/aaa_models.R

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -260,12 +260,6 @@ check_mode_with_no_engine <- function(cls, mode) {
260260
}
261261
}
262262

263-
check_engine_val <- function(eng) {
264-
if (rlang::is_missing(eng) || length(eng) != 1 || !is.character(eng))
265-
rlang::abort("Please supply a character string for an engine (e.g. `'lm'`).")
266-
invisible(NULL)
267-
}
268-
269263
check_arg_val <- function(arg) {
270264
if (rlang::is_missing(arg) || length(arg) != 1 || !is.character(arg))
271265
rlang::abort("Please supply a character string for the argument.")

R/boost_tree.R

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,13 @@ check_args.boost_tree <- function(object) {
213213
invisible(object)
214214
}
215215

216+
216217
# xgboost helpers --------------------------------------------------------------
217218

218219
#' Boosted trees via xgboost
219220
#'
220-
#' `xgb_train` is a wrapper for `xgboost` tree-based models where all of the
221-
#' model arguments are in the main function.
221+
#' `xgb_train()` and `xgb_predict()` are wrappers for `xgboost` tree-based
222+
#' models where all of the model arguments are in the main function.
222223
#'
223224
#' @param x A data frame or matrix of predictors
224225
#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
@@ -251,16 +252,16 @@ check_args.boost_tree <- function(object) {
251252
#' @param event_level For binary classification, this is a single string of either
252253
#' `"first"` or `"second"` to pass along describing which level of the outcome
253254
#' should be considered the "event".
254-
#' @param ... Other options to pass to `xgb.train`.
255+
#' @param ... Other options to pass to `xgb.train()` or xgboost's method for `predict()`.
255256
#' @return A fitted `xgboost` object.
256257
#' @keywords internal
257258
#' @export
258259
xgb_train <- function(
259-
x, y,
260-
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
261-
colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
262-
validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
263-
event_level = c("first", "second"), weights = NULL, ...) {
260+
x, y, weights = NULL,
261+
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
262+
colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
263+
validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
264+
event_level = c("first", "second"), ...) {
264265

265266
event_level <- rlang::arg_match(event_level, c("first", "second"))
266267
others <- list(...)
@@ -387,13 +388,17 @@ maybe_proportion <- function(x, nm) {
387388
}
388389
}
389390

390-
xgb_pred <- function(object, newdata, ...) {
391-
if (!inherits(newdata, "xgb.DMatrix")) {
392-
newdata <- maybe_matrix(newdata)
393-
newdata <- xgboost::xgb.DMatrix(data = newdata, missing = NA)
391+
#' @rdname xgb_train
392+
#' @param new_data A rectangular data object, such as a data frame.
393+
#' @keywords internal
394+
#' @export
395+
xgb_predict <- function(object, new_data, ...) {
396+
if (!inherits(new_data, "xgb.DMatrix")) {
397+
new_data <- maybe_matrix(new_data)
398+
new_data <- xgboost::xgb.DMatrix(data = new_data, missing = NA)
394399
}
395400

396-
res <- predict(object, newdata, ...)
401+
res <- predict(object, new_data, ...)
397402

398403
x <- switch(
399404
object$params$objective,
@@ -470,6 +475,7 @@ get_event_level <- function(model_spec){
470475
event_level
471476
}
472477

478+
473479
#' @export
474480
#' @rdname multi_predict
475481
#' @param trees An integer vector for the number of trees in the ensemble.
@@ -500,9 +506,9 @@ multi_predict._xgb.Booster <-
500506
}
501507

502508
xgb_by_tree <- function(tree, object, new_data, type, ...) {
503-
pred <- xgb_pred(
509+
pred <- xgb_predict(
504510
object$fit,
505-
newdata = new_data,
511+
new_data = new_data,
506512
iterationrange = c(1, tree + 1),
507513
ntreelimit = NULL
508514
)

R/boost_tree_data.R

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ set_pred(
108108
value = list(
109109
pre = NULL,
110110
post = NULL,
111-
func = c(fun = "xgb_pred"),
112-
args = list(object = quote(object$fit), newdata = quote(new_data))
111+
func = c(fun = "xgb_predict"),
112+
args = list(object = quote(object$fit), new_data = quote(new_data))
113113
)
114114
)
115115

@@ -121,8 +121,8 @@ set_pred(
121121
value = list(
122122
pre = NULL,
123123
post = NULL,
124-
func = c(fun = "xgb_pred"),
125-
args = list(object = quote(object$fit), newdata = quote(new_data))
124+
func = c(fun = "xgb_predict"),
125+
args = list(object = quote(object$fit), new_data = quote(new_data))
126126
)
127127
)
128128

@@ -170,8 +170,8 @@ set_pred(
170170
}
171171
x
172172
},
173-
func = c(pkg = NULL, fun = "xgb_pred"),
174-
args = list(object = quote(object$fit), newdata = quote(new_data))
173+
func = c(pkg = NULL, fun = "xgb_predict"),
174+
args = list(object = quote(object$fit), new_data = quote(new_data))
175175
)
176176
)
177177

@@ -196,8 +196,8 @@ set_pred(
196196
colnames(x) <- object$lvl
197197
x
198198
},
199-
func = c(pkg = NULL, fun = "xgb_pred"),
200-
args = list(object = quote(object$fit), newdata = quote(new_data))
199+
func = c(pkg = NULL, fun = "xgb_predict"),
200+
args = list(object = quote(object$fit), new_data = quote(new_data))
201201
)
202202
)
203203

@@ -209,8 +209,8 @@ set_pred(
209209
value = list(
210210
pre = NULL,
211211
post = NULL,
212-
func = c(fun = "xgb_pred"),
213-
args = list(object = quote(object$fit), newdata = quote(new_data))
212+
func = c(fun = "xgb_predict"),
213+
args = list(object = quote(object$fit), new_data = quote(new_data))
214214
)
215215
)
216216

R/descriptors.R

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -320,18 +320,6 @@ has_exprs <- function(x) {
320320
is_symbolic(x)
321321
}
322322

323-
make_descr <- function(object) {
324-
if (length(object$args) > 0)
325-
expr_main <- map_lgl(object$args, has_exprs)
326-
else
327-
expr_main <- FALSE
328-
if (length(object$eng_args) > 0)
329-
expr_eng_args <- map_lgl(object$eng_args, has_exprs)
330-
else
331-
expr_eng_args <- FALSE
332-
any(expr_main) | any(expr_eng_args)
333-
}
334-
335323
# Locate descriptors -----------------------------------------------------------
336324

337325
# take a model spec, see if any require descriptors

R/engine_docs.R

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -289,22 +289,6 @@ sort_c <- function(x) {
289289
withr::with_collate("C", sort(x))
290290
}
291291

292-
get_sorted_unique_engines <- function(x) {
293-
engines <- x$engine
294-
engines <- unique(engines)
295-
engines <- sort_c(engines)
296-
engines
297-
}
298-
combine_prefix_with_engines <- function(prefix, engines) {
299-
if (length(engines) == 0L) {
300-
engines <- "No engines currently available"
301-
} else {
302-
engines <- glue::glue_collapse(engines, sep = ", ")
303-
}
304-
305-
glue::glue("{prefix} {engines}")
306-
}
307-
308292
# ------------------------------------------------------------------------------
309293

310294
#' Locate and show errors/warnings in engine-specific documentation

R/fit.R

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -390,10 +390,6 @@ inher <- function(x, cls, cl) {
390390

391391
# ------------------------------------------------------------------------------
392392

393-
394-
has_both_or_none <- function(a, b)
395-
(!is.null(a) & is.null(b)) | (is.null(a) & !is.null(b))
396-
397393
check_interface <- function(formula, data, cl, model) {
398394
inher(formula, "formula", cl)
399395
inher(data, c("data.frame", "tbl_spark"), cl)

R/linear_reg_gls.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#' Linear regression via generalized least squares
22
#'
33
#' The `"gls"` engine estimates linear regression for models where the rows of the
4-
#' data are not indpendent.
4+
#' data are not independent.
55
#'
66
#' @includeRmd man/rmd/linear_reg_gls.md details
77
#'

R/misc.R

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,6 @@ levels_from_formula <- function(f, dat) {
119119
res
120120
}
121121

122-
is_spark <- function(x)
123-
isTRUE(unname(x$method$fit$func["pkg"] == "sparklyr"))
124-
125-
126122
#' @export
127123
#' @keywords internal
128124
#' @rdname add_on_exports
@@ -211,17 +207,6 @@ check_outcome <- function(y, spec) {
211207
invisible(NULL)
212208
}
213209

214-
215-
# Get's a character string of varible names used as the outcome
216-
# in a terms object
217-
terms_y <- function(x) {
218-
att <- attributes(x)
219-
resp_ind <- att$response
220-
y_expr <- att$predvars[[resp_ind + 1]]
221-
all.vars(y_expr)
222-
}
223-
224-
225210
# ------------------------------------------------------------------------------
226211

227212
#' @export

0 commit comments

Comments
 (0)