Skip to content

Commit 99d1bc3

Browse files
committed
update tests for xgboost/boost_tree args
1 parent 2100e8f commit 99d1bc3

File tree

3 files changed

+65
-105
lines changed

3 files changed

+65
-105
lines changed

man/xgb_train.Rd

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/test_boost_tree.R

Lines changed: 0 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -11,100 +11,6 @@ hpc <- hpc_data[1:150, c(2:5, 8)]
1111

1212
# ------------------------------------------------------------------------------
1313

14-
test_that('primary arguments', {
15-
basic <- boost_tree(mode = "classification")
16-
basic_xgboost <- translate(basic %>% set_engine("xgboost"))
17-
basic_C5.0 <- translate(basic %>% set_engine("C5.0"))
18-
expect_equal(basic_xgboost$method$fit$args,
19-
list(
20-
x = expr(missing_arg()),
21-
y = expr(missing_arg()),
22-
nthread = 1,
23-
verbose = 0
24-
)
25-
)
26-
expect_equal(basic_C5.0$method$fit$args,
27-
list(
28-
x = expr(missing_arg()),
29-
y = expr(missing_arg()),
30-
weights = expr(missing_arg())
31-
)
32-
)
33-
34-
trees <- boost_tree(trees = 15, mode = "classification")
35-
trees_C5.0 <- translate(trees %>% set_engine("C5.0"))
36-
trees_xgboost <- translate(trees %>% set_engine("xgboost"))
37-
expect_equal(trees_C5.0$method$fit$args,
38-
list(
39-
x = expr(missing_arg()),
40-
y = expr(missing_arg()),
41-
weights = expr(missing_arg()),
42-
trials = new_empty_quosure(15)
43-
)
44-
)
45-
expect_equal(trees_xgboost$method$fit$args,
46-
list(
47-
x = expr(missing_arg()),
48-
y = expr(missing_arg()),
49-
nrounds = new_empty_quosure(15),
50-
nthread = 1,
51-
verbose = 0
52-
)
53-
)
54-
55-
split_num <- boost_tree(min_n = 15, mode = "classification")
56-
split_num_C5.0 <- translate(split_num %>% set_engine("C5.0"))
57-
split_num_xgboost <- translate(split_num %>% set_engine("xgboost"))
58-
expect_equal(split_num_C5.0$method$fit$args,
59-
list(
60-
x = expr(missing_arg()),
61-
y = expr(missing_arg()),
62-
weights = expr(missing_arg()),
63-
minCases = new_empty_quosure(15)
64-
)
65-
)
66-
expect_equal(split_num_xgboost$method$fit$args,
67-
list(
68-
x = expr(missing_arg()),
69-
y = expr(missing_arg()),
70-
min_child_weight = new_empty_quosure(15),
71-
nthread = 1,
72-
verbose = 0
73-
)
74-
)
75-
76-
})
77-
78-
test_that('engine arguments', {
79-
xgboost_print <- boost_tree(mode = "regression")
80-
expect_equal(
81-
translate(
82-
xgboost_print %>%
83-
set_engine("xgboost", print_every_n = 10L))$method$fit$args,
84-
list(
85-
x = expr(missing_arg()),
86-
y = expr(missing_arg()),
87-
print_every_n = new_empty_quosure(10L),
88-
nthread = 1,
89-
verbose = 0
90-
)
91-
)
92-
93-
C5.0_rules <- boost_tree(mode = "classification")
94-
expect_equal(
95-
translate(
96-
C5.0_rules %>% set_engine("C5.0", rules = TRUE))$method$fit$args,
97-
list(
98-
x = expr(missing_arg()),
99-
y = expr(missing_arg()),
100-
weights = expr(missing_arg()),
101-
rules = new_empty_quosure(TRUE)
102-
)
103-
)
104-
105-
})
106-
107-
10814
test_that('updating', {
10915
expr1 <- boost_tree() %>% set_engine("xgboost", verbose = 0)
11016
expr1_exp <- boost_tree(trees = 10) %>% set_engine("xgboost", verbose = 0)

tests/testthat/test_boost_tree_xgboost.R

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,64 @@ test_that('xgboost execution, classification', {
1919

2020
skip_if_not_installed("xgboost")
2121

22-
expect_error(
23-
res <- parsnip::fit(
22+
set.seed(1)
23+
wts <- ifelse(runif(nrow(hpc)) < .1, 0, 1)
24+
wts <- importance_weights(wts)
25+
26+
expect_error({
27+
set.seed(1)
28+
res_f <- parsnip::fit(
2429
hpc_xgboost,
2530
class ~ compounds + input_fields,
2631
data = hpc,
2732
control = ctrl
28-
),
29-
regexp = NA
33+
)
34+
},
35+
regexp = NA
3036
)
31-
expect_error(
32-
res <- parsnip::fit_xy(
37+
expect_error({
38+
set.seed(1)
39+
res_xy <- parsnip::fit_xy(
3340
hpc_xgboost,
34-
x = hpc[, num_pred],
41+
x = hpc[, c("compounds", "input_fields")],
3542
y = hpc$class,
3643
control = ctrl
37-
),
38-
regexp = NA
44+
)
45+
},
46+
regexp = NA
47+
)
48+
expect_error({
49+
set.seed(1)
50+
res_f_wts <- parsnip::fit(
51+
hpc_xgboost,
52+
class ~ compounds + input_fields,
53+
data = hpc,
54+
control = ctrl,
55+
case_weights = wts
56+
)
57+
},
58+
regexp = NA
59+
)
60+
expect_error({
61+
set.seed(1)
62+
res_xy_wts <- parsnip::fit_xy(
63+
hpc_xgboost,
64+
x = hpc[, c("compounds", "input_fields")],
65+
y = hpc$class,
66+
control = ctrl,
67+
case_weights = wts
68+
)
69+
},
70+
regexp = NA
3971
)
4072

41-
expect_true(has_multi_predict(res))
42-
expect_equal(multi_predict_args(res), "trees")
73+
expect_equal(res_f$fit$evaluation_log, res_xy$fit$evaluation_log)
74+
expect_equal(res_f_wts$fit$evaluation_log, res_xy_wts$fit$evaluation_log)
75+
# Check to see if the case weights had an effect
76+
expect_true(!isTRUE(all.equal(res_f$fit$evaluation_log, res_f_wts$fit$evaluation_log)))
77+
78+
expect_true(has_multi_predict(res_xy))
79+
expect_equal(multi_predict_args(res_xy), "trees")
4380

4481
expect_error(
4582
res <- parsnip::fit(
@@ -312,6 +349,7 @@ test_that('xgboost data conversion', {
312349
mtcar_x <- mtcars[, -1]
313350
mtcar_mat <- as.matrix(mtcar_x)
314351
mtcar_smat <- Matrix::Matrix(mtcar_mat, sparse = TRUE)
352+
wts <- 1:32
315353

316354
expect_error(from_df <- parsnip:::as_xgb_data(mtcar_x, mtcars$mpg), regexp = NA)
317355
expect_true(inherits(from_df$data, "xgb.DMatrix"))
@@ -352,6 +390,13 @@ test_that('xgboost data conversion', {
352390
expect_warning(from_df <- parsnip:::as_xgb_data(mtcar_x, mtcars_y, event_level = "second"),
353391
regexp = "`event_level` can only be set for binary variables.")
354392

393+
# case weights added
394+
expect_error(wted <- parsnip:::as_xgb_data(mtcar_x, mtcars$mpg, weights = wts), regexp = NA)
395+
expect_equal(wts, xgboost::getinfo(wted$data, "weight"))
396+
expect_error(wted_val <- parsnip:::as_xgb_data(mtcar_x, mtcars$mpg, weights = wts, validation = 1/4), regexp = NA)
397+
expect_true(all(xgboost::getinfo(wted_val$data, "weight") %in% wts))
398+
expect_null(xgboost::getinfo(wted_val$watchlist$validation, "weight"))
399+
355400
})
356401

357402

@@ -361,6 +406,7 @@ test_that('xgboost data and sparse matrices', {
361406
mtcar_x <- mtcars[, -1]
362407
mtcar_mat <- as.matrix(mtcar_x)
363408
mtcar_smat <- Matrix::Matrix(mtcar_mat, sparse = TRUE)
409+
wts <- 1:32
364410

365411
xgb_spec <-
366412
boost_tree(trees = 10) %>%
@@ -377,6 +423,13 @@ test_that('xgboost data and sparse matrices', {
377423
expect_equal(from_df$fit, from_mat$fit)
378424
expect_equal(from_df$fit, from_sparse$fit)
379425

426+
# case weights added
427+
expect_error(wted <- parsnip:::as_xgb_data(mtcar_smat, mtcars$mpg, weights = wts), regexp = NA)
428+
expect_equal(wts, xgboost::getinfo(wted$data, "weight"))
429+
expect_error(wted_val <- parsnip:::as_xgb_data(mtcar_smat, mtcars$mpg, weights = wts, validation = 1/4), regexp = NA)
430+
expect_true(all(xgboost::getinfo(wted_val$data, "weight") %in% wts))
431+
expect_null(xgboost::getinfo(wted_val$watchlist$validation, "weight"))
432+
380433
})
381434

382435

0 commit comments

Comments
 (0)