@@ -19,27 +19,64 @@ test_that('xgboost execution, classification', {
19
19
20
20
skip_if_not_installed(" xgboost" )
21
21
22
- expect_error(
23
- res <- parsnip :: fit(
22
+ set.seed(1 )
23
+ wts <- ifelse(runif(nrow(hpc )) < .1 , 0 , 1 )
24
+ wts <- importance_weights(wts )
25
+
26
+ expect_error({
27
+ set.seed(1 )
28
+ res_f <- parsnip :: fit(
24
29
hpc_xgboost ,
25
30
class ~ compounds + input_fields ,
26
31
data = hpc ,
27
32
control = ctrl
28
- ),
29
- regexp = NA
33
+ )
34
+ },
35
+ regexp = NA
30
36
)
31
- expect_error(
32
- res <- parsnip :: fit_xy(
37
+ expect_error({
38
+ set.seed(1 )
39
+ res_xy <- parsnip :: fit_xy(
33
40
hpc_xgboost ,
34
- x = hpc [, num_pred ],
41
+ x = hpc [, c( " compounds " , " input_fields " ) ],
35
42
y = hpc $ class ,
36
43
control = ctrl
37
- ),
38
- regexp = NA
44
+ )
45
+ },
46
+ regexp = NA
47
+ )
48
+ expect_error({
49
+ set.seed(1 )
50
+ res_f_wts <- parsnip :: fit(
51
+ hpc_xgboost ,
52
+ class ~ compounds + input_fields ,
53
+ data = hpc ,
54
+ control = ctrl ,
55
+ case_weights = wts
56
+ )
57
+ },
58
+ regexp = NA
59
+ )
60
+ expect_error({
61
+ set.seed(1 )
62
+ res_xy_wts <- parsnip :: fit_xy(
63
+ hpc_xgboost ,
64
+ x = hpc [, c(" compounds" , " input_fields" )],
65
+ y = hpc $ class ,
66
+ control = ctrl ,
67
+ case_weights = wts
68
+ )
69
+ },
70
+ regexp = NA
39
71
)
40
72
41
- expect_true(has_multi_predict(res ))
42
- expect_equal(multi_predict_args(res ), " trees" )
73
+ expect_equal(res_f $ fit $ evaluation_log , res_xy $ fit $ evaluation_log )
74
+ expect_equal(res_f_wts $ fit $ evaluation_log , res_xy_wts $ fit $ evaluation_log )
75
+ # Check to see if the case weights had an effect
76
+ expect_true(! isTRUE(all.equal(res_f $ fit $ evaluation_log , res_f_wts $ fit $ evaluation_log )))
77
+
78
+ expect_true(has_multi_predict(res_xy ))
79
+ expect_equal(multi_predict_args(res_xy ), " trees" )
43
80
44
81
expect_error(
45
82
res <- parsnip :: fit(
@@ -312,6 +349,7 @@ test_that('xgboost data conversion', {
312
349
mtcar_x <- mtcars [, - 1 ]
313
350
mtcar_mat <- as.matrix(mtcar_x )
314
351
mtcar_smat <- Matrix :: Matrix(mtcar_mat , sparse = TRUE )
352
+ wts <- 1 : 32
315
353
316
354
expect_error(from_df <- parsnip ::: as_xgb_data(mtcar_x , mtcars $ mpg ), regexp = NA )
317
355
expect_true(inherits(from_df $ data , " xgb.DMatrix" ))
@@ -352,6 +390,13 @@ test_that('xgboost data conversion', {
352
390
expect_warning(from_df <- parsnip ::: as_xgb_data(mtcar_x , mtcars_y , event_level = " second" ),
353
391
regexp = " `event_level` can only be set for binary variables." )
354
392
393
+ # case weights added
394
+ expect_error(wted <- parsnip ::: as_xgb_data(mtcar_x , mtcars $ mpg , weights = wts ), regexp = NA )
395
+ expect_equal(wts , xgboost :: getinfo(wted $ data , " weight" ))
396
+ expect_error(wted_val <- parsnip ::: as_xgb_data(mtcar_x , mtcars $ mpg , weights = wts , validation = 1 / 4 ), regexp = NA )
397
+ expect_true(all(xgboost :: getinfo(wted_val $ data , " weight" ) %in% wts ))
398
+ expect_null(xgboost :: getinfo(wted_val $ watchlist $ validation , " weight" ))
399
+
355
400
})
356
401
357
402
@@ -361,6 +406,7 @@ test_that('xgboost data and sparse matrices', {
361
406
mtcar_x <- mtcars [, - 1 ]
362
407
mtcar_mat <- as.matrix(mtcar_x )
363
408
mtcar_smat <- Matrix :: Matrix(mtcar_mat , sparse = TRUE )
409
+ wts <- 1 : 32
364
410
365
411
xgb_spec <-
366
412
boost_tree(trees = 10 ) %> %
@@ -377,6 +423,13 @@ test_that('xgboost data and sparse matrices', {
377
423
expect_equal(from_df $ fit , from_mat $ fit )
378
424
expect_equal(from_df $ fit , from_sparse $ fit )
379
425
426
+ # case weights added
427
+ expect_error(wted <- parsnip ::: as_xgb_data(mtcar_smat , mtcars $ mpg , weights = wts ), regexp = NA )
428
+ expect_equal(wts , xgboost :: getinfo(wted $ data , " weight" ))
429
+ expect_error(wted_val <- parsnip ::: as_xgb_data(mtcar_smat , mtcars $ mpg , weights = wts , validation = 1 / 4 ), regexp = NA )
430
+ expect_true(all(xgboost :: getinfo(wted_val $ data , " weight" ) %in% wts ))
431
+ expect_null(xgboost :: getinfo(wted_val $ watchlist $ validation , " weight" ))
432
+
380
433
})
381
434
382
435
0 commit comments