Skip to content

Commit f6b7972

Browse files
authored
export xgb_pred() (#688)
* export `xgb_pred()` * rename function and arg * update news
1 parent 0e9f4ba commit f6b7972

File tree

6 files changed

+38
-24
lines changed

6 files changed

+38
-24
lines changed

DESCRIPTION

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,4 +89,4 @@ Config/rcmdcheck/ignore-inconsequential-notes: true
8989
Encoding: UTF-8
9090
LazyData: true
9191
Roxygen: list(markdown = TRUE)
92-
RoxygenNote: 7.1.2
92+
RoxygenNote: 7.1.2.9000

NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,7 @@ export(update_main_parameters)
294294
export(update_model_info_file)
295295
export(varying)
296296
export(varying_args)
297+
export(xgb_predict)
297298
export(xgb_train)
298299
importFrom(dplyr,arrange)
299300
importFrom(dplyr,bind_cols)

NEWS.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
# parsnip (development version)
22

3+
* Exported `xgb_predict()` which wraps xgboost's `predict()` method for use with parsnip extension packages (#688).
4+
5+
36
# parsnip 0.2.1
47

58
* Fixed a major bug in spark models induced in the previous version (#671).
69

710
* Updated the parsnip add-in with new models and engines.
811

912
* Updated parameter ranges for some `tunable()` methods and added a missing engine argument for brulee models.
13+
1014
* Added information about how to install the mixOmics package for PLS models (#680)
1115

1216

R/boost_tree.R

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ check_args.boost_tree <- function(object) {
217217

218218
#' Boosted trees via xgboost
219219
#'
220-
#' `xgb_train` is a wrapper for `xgboost` tree-based models where all of the
221-
#' model arguments are in the main function.
220+
#' `xgb_train()` and `xgb_predict()` are wrappers for `xgboost` tree-based
221+
#' models where all of the model arguments are in the main function.
222222
#'
223223
#' @param x A data frame or matrix of predictors
224224
#' @param y A vector (factor or numeric) or matrix (numeric) of outcome data.
@@ -251,7 +251,7 @@ check_args.boost_tree <- function(object) {
251251
#' @param event_level For binary classification, this is a single string of either
252252
#' `"first"` or `"second"` to pass along describing which level of the outcome
253253
#' should be considered the "event".
254-
#' @param ... Other options to pass to `xgb.train`.
254+
#' @param ... Other options to pass to `xgb.train()` or xgboost's method for `predict()`.
255255
#' @return A fitted `xgboost` object.
256256
#' @keywords internal
257257
#' @export
@@ -383,13 +383,17 @@ maybe_proportion <- function(x, nm) {
383383
}
384384
}
385385

386-
xgb_pred <- function(object, newdata, ...) {
387-
if (!inherits(newdata, "xgb.DMatrix")) {
388-
newdata <- maybe_matrix(newdata)
389-
newdata <- xgboost::xgb.DMatrix(data = newdata, missing = NA)
386+
#' @rdname xgb_train
387+
#' @param new_data A rectangular data object, such as a data frame.
388+
#' @keywords internal
389+
#' @export
390+
xgb_predict <- function(object, new_data, ...) {
391+
if (!inherits(new_data, "xgb.DMatrix")) {
392+
new_data <- maybe_matrix(new_data)
393+
new_data <- xgboost::xgb.DMatrix(data = new_data, missing = NA)
390394
}
391395

392-
res <- predict(object, newdata, ...)
396+
res <- predict(object, new_data, ...)
393397

394398
x <- switch(
395399
object$params$objective,
@@ -482,9 +486,9 @@ multi_predict._xgb.Booster <-
482486
}
483487

484488
xgb_by_tree <- function(tree, object, new_data, type, ...) {
485-
pred <- xgb_pred(
489+
pred <- xgb_predict(
486490
object$fit,
487-
newdata = new_data,
491+
new_data = new_data,
488492
iterationrange = c(1, tree + 1),
489493
ntreelimit = NULL
490494
)

R/boost_tree_data.R

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ set_pred(
108108
value = list(
109109
pre = NULL,
110110
post = NULL,
111-
func = c(fun = "xgb_pred"),
112-
args = list(object = quote(object$fit), newdata = quote(new_data))
111+
func = c(fun = "xgb_predict"),
112+
args = list(object = quote(object$fit), new_data = quote(new_data))
113113
)
114114
)
115115

@@ -121,8 +121,8 @@ set_pred(
121121
value = list(
122122
pre = NULL,
123123
post = NULL,
124-
func = c(fun = "xgb_pred"),
125-
args = list(object = quote(object$fit), newdata = quote(new_data))
124+
func = c(fun = "xgb_predict"),
125+
args = list(object = quote(object$fit), new_data = quote(new_data))
126126
)
127127
)
128128

@@ -170,8 +170,8 @@ set_pred(
170170
}
171171
x
172172
},
173-
func = c(pkg = NULL, fun = "xgb_pred"),
174-
args = list(object = quote(object$fit), newdata = quote(new_data))
173+
func = c(pkg = NULL, fun = "xgb_predict"),
174+
args = list(object = quote(object$fit), new_data = quote(new_data))
175175
)
176176
)
177177

@@ -196,8 +196,8 @@ set_pred(
196196
colnames(x) <- object$lvl
197197
x
198198
},
199-
func = c(pkg = NULL, fun = "xgb_pred"),
200-
args = list(object = quote(object$fit), newdata = quote(new_data))
199+
func = c(pkg = NULL, fun = "xgb_predict"),
200+
args = list(object = quote(object$fit), new_data = quote(new_data))
201201
)
202202
)
203203

@@ -209,8 +209,8 @@ set_pred(
209209
value = list(
210210
pre = NULL,
211211
post = NULL,
212-
func = c(fun = "xgb_pred"),
213-
args = list(object = quote(object$fit), newdata = quote(new_data))
212+
func = c(fun = "xgb_predict"),
213+
args = list(object = quote(object$fit), new_data = quote(new_data))
214214
)
215215
)
216216

man/xgb_train.Rd

Lines changed: 8 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)