Skip to content

Commit 2100e8f

Browse files
committed
pass case weights to xgboost
1 parent 2929d53 commit 2100e8f

File tree

2 files changed

+30
-12
lines changed

2 files changed

+30
-12
lines changed

R/boost_tree.R

Lines changed: 28 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -260,7 +260,7 @@ xgb_train <- function(
260260
max_depth = 6, nrounds = 15, eta = 0.3, colsample_bynode = NULL,
261261
colsample_bytree = NULL, min_child_weight = 1, gamma = 0, subsample = 1,
262262
validation = 0, early_stop = NULL, objective = NULL, counts = TRUE,
263-
event_level = c("first", "second"), ...) {
263+
event_level = c("first", "second"), weights = NULL, ...) {
264264

265265
event_level <- rlang::arg_match(event_level, c("first", "second"))
266266
others <- list(...)
@@ -295,7 +295,11 @@ xgb_train <- function(
295295
n <- nrow(x)
296296
p <- ncol(x)
297297

298-
x <- as_xgb_data(x, y, validation, event_level)
298+
x <-
299+
as_xgb_data(x, y,
300+
validation = validation,
301+
event_level = event_level,
302+
weights = weights)
299303

300304

301305
if (!is.numeric(subsample) || subsample < 0 || subsample > 1) {
@@ -401,7 +405,7 @@ xgb_pred <- function(object, newdata, ...) {
401405
}
402406

403407

404-
as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {
408+
as_xgb_data <- function(x, y, validation = 0, weights = NULL, event_level = "first", ...) {
405409
lvls <- levels(y)
406410
n <- nrow(x)
407411

@@ -424,22 +428,36 @@ as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {
424428

425429
if (!inherits(x, "xgb.DMatrix")) {
426430
if (validation > 0) {
431+
# Split data
427432
m <- floor(n * (1 - validation)) + 1
428433
trn_index <- sample(1:n, size = max(m, 2))
429-
wlist <-
430-
list(validation = xgboost::xgb.DMatrix(x[-trn_index, ], label = y[-trn_index], missing = NA))
431-
dat <- xgboost::xgb.DMatrix(x[trn_index, ], label = y[trn_index], missing = NA)
434+
val_data <- xgboost::xgb.DMatrix(x[-trn_index,], label = y[-trn_index], missing = NA)
435+
watch_list <- list(validation = val_data)
436+
437+
info_list <- list(label = y[trn_index])
438+
if (!is.null(weights)) {
439+
info_list$weight <- weights[trn_index]
440+
}
441+
dat <- xgboost::xgb.DMatrix(x[trn_index,], missing = NA, info = info_list)
442+
432443

433444
} else {
434-
dat <- xgboost::xgb.DMatrix(x, label = y, missing = NA)
435-
wlist <- list(training = dat)
445+
info_list <- list(label = y)
446+
if (!is.null(weights)) {
447+
info_list$weight <- weights
448+
}
449+
dat <- xgboost::xgb.DMatrix(x, missing = NA, info = info_list)
450+
watch_list <- list(training = dat)
436451
}
437452
} else {
438453
dat <- xgboost::setinfo(x, "label", y)
439-
wlist <- list(training = dat)
454+
if (!is.null(weights)) {
455+
dat <- xgboost::setinfo(x, "weight", weights)
456+
}
457+
watch_list <- list(training = dat)
440458
}
441459

442-
list(data = dat, watchlist = wlist)
460+
list(data = dat, watchlist = watch_list)
443461
}
444462

445463
get_event_level <- function(model_spec){

R/boost_tree_data.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ set_fit(
8282
mode = "regression",
8383
value = list(
8484
interface = "matrix",
85-
protect = c("x", "y"),
85+
protect = c("x", "y", "weights"),
8686
func = c(pkg = "parsnip", fun = "xgb_train"),
8787
defaults = list(nthread = 1, verbose = 0)
8888
)
@@ -132,7 +132,7 @@ set_fit(
132132
mode = "classification",
133133
value = list(
134134
interface = "matrix",
135-
protect = c("x", "y"),
135+
protect = c("x", "y", "weights"),
136136
func = c(pkg = "parsnip", fun = "xgb_train"),
137137
defaults = list(nthread = 1, verbose = 0)
138138
)

0 commit comments

Comments
 (0)