Skip to content

Commit 1288124

Browse files
authored
changes for case weights and tidymodels/censored#163 (#696)
1 parent c8765c7 commit 1288124

File tree

3 files changed

+37
-15
lines changed

3 files changed

+37
-15
lines changed

NEWS.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
# parsnip (development version)
22

3+
* `xgb_train()` now allows for case weights
4+
35
* Added `ctree_train()` and `cforest_train()` wrappers for the functions in the partykit package. Engines for these will be added to other parsnip extension packages.
46

57
* Exported `xgb_predict()` which wraps xgboost's `predict()` method for use with parsnip extension packages (#688).
68

7-
89
# parsnip 0.2.1
910

1011
* Fixed a major bug in spark models induced in the previous version (#671).

R/boost_tree.R

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ check_args.boost_tree <- function(object) {
213213
invisible(object)
214214
}
215215

216+
216217
# xgboost helpers --------------------------------------------------------------
217218

218219
#' Boosted trees via xgboost
@@ -256,11 +257,11 @@ check_args.boost_tree <- function(object) {
256257
#' @keywords internal
257258
#' @export
258259
xgb_train <- function(
259-
x, y,
260-
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
261-
colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
262-
validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
263-
event_level = c("first", "second"), ...) {
260+
x, y,
261+
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
262+
colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
263+
validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
264+
event_level = c("first", "second"), weights = NULL, ...) {
264265

265266
event_level <- rlang::arg_match(event_level, c("first", "second"))
266267
others <- list(...)
@@ -295,7 +296,11 @@ xgb_train <- function(
295296
n <- nrow(x)
296297
p <- ncol(x)
297298

298-
x <- as_xgb_data(x, y, validation, event_level)
299+
x <-
300+
as_xgb_data(x, y,
301+
validation = validation,
302+
event_level = event_level,
303+
weights = weights)
299304

300305

301306
if (!is.numeric(subsample) || subsample < 0 || subsample > 1) {
@@ -405,7 +410,7 @@ xgb_predict <- function(object, new_data, ...) {
405410
}
406411

407412

408-
as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {
413+
as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "first", ...) {
409414
lvls <- levels(y)
410415
n <- nrow(x)
411416

@@ -428,22 +433,36 @@ as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {
428433

429434
if (!inherits(x, "xgb.DMatrix")) {
430435
if (validation > 0) {
436+
# Split data
431437
m <- floor(n * (1 - validation)) + 1
432438
trn_index <- sample(1:n, size = max(m, 2))
433-
wlist <-
434-
list(validation = xgboost::xgb.DMatrix(x[-trn_index, ], label = y[-trn_index], missing = NA))
435-
dat <- xgboost::xgb.DMatrix(x[trn_index, ], label = y[trn_index], missing = NA)
439+
val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA)
440+
watch_list <- list(validation = val_data)
441+
442+
info_list <- list(label = y[trn_index])
443+
if (!is.null(weights)) {
444+
info_list$weight <- weights[trn_index]
445+
}
446+
dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list)
447+
436448

437449
} else {
438-
dat <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
439-
wlist <- list(training = dat)
450+
info_list <- list(label = y)
451+
if (!is.null(weights)) {
452+
info_list$weight <- weights
453+
}
454+
dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
455+
watch_list <- list(training = dat)
440456
}
441457
} else {
442458
dat <- xgboost::setinfo(x, "label", y)
443-
wlist <- list(training = dat)
459+
if (!is.null(weights)) {
460+
dat <- xgboost::setinfo(x, "weight", weights)
461+
}
462+
watch_list <- list(training = dat)
444463
}
445464

446-
list(data = dat, watchlist = wlist)
465+
list(data = dat, watchlist = watch_list)
447466
}
448467

449468
get_event_level <- function(model_spec){
@@ -456,6 +475,7 @@ get_event_level <- function(model_spec){
456475
event_level
457476
}
458477

478+
459479
#' @export
460480
#' @rdname multi_predict
461481
#' @param trees An integer vector for the number of trees in the ensemble.

man/xgb_train.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)