@@ -260,7 +260,7 @@ xgb_train <- function(
260
260
max_depth = 6 , nrounds = 15 , eta = 0.3 , colsample_bynode = NULL ,
261
261
colsample_bytree = NULL , min_child_weight = 1 , gamma = 0 , subsample = 1 ,
262
262
validation = 0 , early_stop = NULL , objective = NULL , counts = TRUE ,
263
- event_level = c(" first" , " second" ), ... ) {
263
+ event_level = c(" first" , " second" ), weights = NULL , ... ) {
264
264
265
265
event_level <- rlang :: arg_match(event_level , c(" first" , " second" ))
266
266
others <- list (... )
@@ -295,7 +295,11 @@ xgb_train <- function(
295
295
n <- nrow(x )
296
296
p <- ncol(x )
297
297
298
- x <- as_xgb_data(x , y , validation , event_level )
298
+ x <-
299
+ as_xgb_data(x , y ,
300
+ validation = validation ,
301
+ event_level = event_level ,
302
+ weights = weights )
299
303
300
304
301
305
if (! is.numeric(subsample ) || subsample < 0 || subsample > 1 ) {
@@ -401,7 +405,7 @@ xgb_pred <- function(object, newdata, ...) {
401
405
}
402
406
403
407
404
- as_xgb_data <- function (x , y , validation = 0 , event_level = " first" , ... ) {
408
+ as_xgb_data <- function (x , y , validation = 0 , weights = NULL , event_level = " first" , ... ) {
405
409
lvls <- levels(y )
406
410
n <- nrow(x )
407
411
@@ -424,22 +428,36 @@ as_xgb_data <- function(x, y, validation = 0, event_level = "first", ...) {
424
428
425
429
if (! inherits(x , " xgb.DMatrix" )) {
426
430
if (validation > 0 ) {
431
+ # Split data
427
432
m <- floor(n * (1 - validation )) + 1
428
433
trn_index <- sample(1 : n , size = max(m , 2 ))
429
- wlist <-
430
- list (validation = xgboost :: xgb.DMatrix(x [- trn_index , ], label = y [- trn_index ], missing = NA ))
431
- dat <- xgboost :: xgb.DMatrix(x [trn_index , ], label = y [trn_index ], missing = NA )
434
+ val_data <- xgboost :: xgb.DMatrix(x [- trn_index ,], label = y [- trn_index ], missing = NA )
435
+ watch_list <- list (validation = val_data )
436
+
437
+ info_list <- list (label = y [trn_index ])
438
+ if (! is.null(weights )) {
439
+ info_list $ weight <- weights [trn_index ]
440
+ }
441
+ dat <- xgboost :: xgb.DMatrix(x [trn_index ,], missing = NA , info = info_list )
442
+
432
443
433
444
} else {
434
- dat <- xgboost :: xgb.DMatrix(x , label = y , missing = NA )
435
- wlist <- list (training = dat )
445
+ info_list <- list (label = y )
446
+ if (! is.null(weights )) {
447
+ info_list $ weight <- weights
448
+ }
449
+ dat <- xgboost :: xgb.DMatrix(x , missing = NA , info = info_list )
450
+ watch_list <- list (training = dat )
436
451
}
437
452
} else {
438
453
dat <- xgboost :: setinfo(x , " label" , y )
439
- wlist <- list (training = dat )
454
+ if (! is.null(weights )) {
455
+ dat <- xgboost :: setinfo(x , " weight" , weights )
456
+ }
457
+ watch_list <- list (training = dat )
440
458
}
441
459
442
- list (data = dat , watchlist = wlist )
460
+ list (data = dat , watchlist = watch_list )
443
461
}
444
462
445
463
get_event_level <- function (model_spec ){
0 commit comments