@@ -35,8 +35,8 @@ test_that("MCMC BCF", {
35
35
X_train <- X [train_inds ,]
36
36
Z_test <- Z [test_inds ]
37
37
Z_train <- Z [train_inds ]
38
- pi_test <- pi [test_inds ]
39
- pi_train <- pi [train_inds ]
38
+ pi_test <- pi_X [test_inds ]
39
+ pi_train <- pi_X [train_inds ]
40
40
mu_test <- mu_X [test_inds ]
41
41
mu_train <- mu_X [train_inds ]
42
42
tau_test <- tau_X [test_inds ]
@@ -53,6 +53,32 @@ test_that("MCMC BCF", {
53
53
num_mcmc = 10 , general_params = general_param_list )
54
54
)
55
55
56
+ # 1 chain, no thinning, matrix leaf scale parameter provided
57
+ general_param_list <- list (num_chains = 1 , keep_every = 1 )
58
+ mu_forest_param_list <- list (sigma2_leaf_init = as.matrix(0.5 ))
59
+ tau_forest_param_list <- list (sigma2_leaf_init = as.matrix(0.5 ))
60
+ expect_no_error(
61
+ bcf_model <- bcf(X_train = X_train , y_train = y_train , Z_train = Z_train ,
62
+ propensity_train = pi_train , X_test = X_test , Z_test = Z_test ,
63
+ propensity_test = pi_test , num_gfr = 0 , num_burnin = 10 ,
64
+ num_mcmc = 10 , general_params = general_param_list ,
65
+ mu_forest_params = mu_forest_param_list ,
66
+ tau_forest_params = tau_forest_param_list )
67
+ )
68
+
69
+ # 1 chain, no thinning, scalar leaf scale parameter provided
70
+ general_param_list <- list (num_chains = 1 , keep_every = 1 )
71
+ mu_forest_param_list <- list (sigma2_leaf_init = 0.5 )
72
+ tau_forest_param_list <- list (sigma2_leaf_init = 0.5 )
73
+ expect_no_error(
74
+ bcf_model <- bcf(X_train = X_train , y_train = y_train , Z_train = Z_train ,
75
+ propensity_train = pi_train , X_test = X_test , Z_test = Z_test ,
76
+ propensity_test = pi_test , num_gfr = 0 , num_burnin = 10 ,
77
+ num_mcmc = 10 , general_params = general_param_list ,
78
+ mu_forest_params = mu_forest_param_list ,
79
+ tau_forest_params = tau_forest_param_list )
80
+ )
81
+
56
82
# 3 chains, no thinning
57
83
general_param_list <- list (num_chains = 3 , keep_every = 1 )
58
84
expect_no_error(
@@ -118,8 +144,8 @@ test_that("GFR BCF", {
118
144
X_train <- X [train_inds ,]
119
145
Z_test <- Z [test_inds ]
120
146
Z_train <- Z [train_inds ]
121
- pi_test <- pi [test_inds ]
122
- pi_train <- pi [train_inds ]
147
+ pi_test <- pi_X [test_inds ]
148
+ pi_train <- pi_X [train_inds ]
123
149
mu_test <- mu_X [test_inds ]
124
150
mu_train <- mu_X [train_inds ]
125
151
tau_test <- tau_X [test_inds ]
@@ -219,8 +245,8 @@ test_that("Warmstart BCF", {
219
245
X_train <- X [train_inds ,]
220
246
Z_test <- Z [test_inds ]
221
247
Z_train <- Z [train_inds ]
222
- pi_test <- pi [test_inds ]
223
- pi_train <- pi [train_inds ]
248
+ pi_test <- pi_X [test_inds ]
249
+ pi_train <- pi_X [train_inds ]
224
250
mu_test <- mu_X [test_inds ]
225
251
mu_train <- mu_X [train_inds ]
226
252
tau_test <- tau_X [test_inds ]
@@ -287,8 +313,8 @@ test_that("Warmstart BCF", {
287
313
X_train <- X [train_inds ,]
288
314
Z_test <- Z [test_inds ]
289
315
Z_train <- Z [train_inds ]
290
- pi_test <- pi [test_inds ]
291
- pi_train <- pi [train_inds ]
316
+ pi_test <- pi_X [test_inds ]
317
+ pi_train <- pi_X [train_inds ]
292
318
mu_test <- mu_X [test_inds ]
293
319
mu_train <- mu_X [train_inds ]
294
320
tau_test <- tau_X [test_inds ]
@@ -329,3 +355,75 @@ test_that("Warmstart BCF", {
329
355
general_params = general_param_list )
330
356
)
331
357
})
358
+
359
+ test_that(" Multivariate Treatment MCMC BCF" , {
360
+ skip_on_cran()
361
+
362
+ # Generate simulated data
363
+ n <- 100
364
+ p <- 5
365
+ X <- matrix (runif(n * p ), ncol = p )
366
+ mu_X <- (
367
+ ((0 < = X [,1 ]) & (0.25 > X [,1 ])) * (- 7.5 ) +
368
+ ((0.25 < = X [,1 ]) & (0.5 > X [,1 ])) * (- 2.5 ) +
369
+ ((0.5 < = X [,1 ]) & (0.75 > X [,1 ])) * (2.5 ) +
370
+ ((0.75 < = X [,1 ]) & (1 > X [,1 ])) * (7.5 )
371
+ )
372
+ pi_X_1 <- (
373
+ ((0 < = X [,1 ]) & (0.25 > X [,1 ])) * (0.2 ) +
374
+ ((0.25 < = X [,1 ]) & (0.5 > X [,1 ])) * (0.4 ) +
375
+ ((0.5 < = X [,1 ]) & (0.75 > X [,1 ])) * (0.6 ) +
376
+ ((0.75 < = X [,1 ]) & (1 > X [,1 ])) * (0.8 )
377
+ )
378
+ pi_X_2 <- (
379
+ ((0 < = X [,2 ]) & (0.25 > X [,2 ])) * (0.8 ) +
380
+ ((0.25 < = X [,2 ]) & (0.5 > X [,2 ])) * (0.4 ) +
381
+ ((0.5 < = X [,2 ]) & (0.75 > X [,2 ])) * (0.6 ) +
382
+ ((0.75 < = X [,2 ]) & (1 > X [,2 ])) * (0.2 )
383
+ )
384
+ pi_X <- cbind(pi_X_1 , pi_X_2 )
385
+ tau_X_1 <- (
386
+ ((0 < = X [,2 ]) & (0.25 > X [,2 ])) * (0.5 ) +
387
+ ((0.25 < = X [,2 ]) & (0.5 > X [,2 ])) * (1.0 ) +
388
+ ((0.5 < = X [,2 ]) & (0.75 > X [,2 ])) * (1.5 ) +
389
+ ((0.75 < = X [,2 ]) & (1 > X [,2 ])) * (2.0 )
390
+ )
391
+ tau_X_2 <- (
392
+ ((0 < = X [,3 ]) & (0.25 > X [,3 ])) * (- 0.5 ) +
393
+ ((0.25 < = X [,3 ]) & (0.5 > X [,3 ])) * (- 1.5 ) +
394
+ ((0.5 < = X [,3 ]) & (0.75 > X [,3 ])) * (- 1.0 ) +
395
+ ((0.75 < = X [,3 ]) & (1 > X [,3 ])) * (0.0 )
396
+ )
397
+ tau_X <- cbind(tau_X_1 , tau_X_2 )
398
+ Z_1 <- as.numeric(rbinom(n , 1 , pi_X_1 ))
399
+ Z_2 <- as.numeric(rbinom(n , 1 , pi_X_2 ))
400
+ Z <- cbind(Z_1 , Z_2 )
401
+ noise_sd <- 1
402
+ y <- mu_X + rowSums(tau_X * Z ) + rnorm(n , 0 , noise_sd )
403
+ test_set_pct <- 0.2
404
+ n_test <- round(test_set_pct * n )
405
+ n_train <- n - n_test
406
+ test_inds <- sort(sample(1 : n , n_test , replace = FALSE ))
407
+ train_inds <- (1 : n )[! ((1 : n ) %in% test_inds )]
408
+ X_test <- X [test_inds ,]
409
+ X_train <- X [train_inds ,]
410
+ Z_test <- Z [test_inds ,]
411
+ Z_train <- Z [train_inds ,]
412
+ pi_test <- pi_X [test_inds ,]
413
+ pi_train <- pi_X [train_inds ,]
414
+ mu_test <- mu_X [test_inds ]
415
+ mu_train <- mu_X [train_inds ]
416
+ tau_test <- tau_X [test_inds ,]
417
+ tau_train <- tau_X [train_inds ,]
418
+ y_test <- y [test_inds ]
419
+ y_train <- y [train_inds ]
420
+
421
+ # 1 chain, no thinning
422
+ general_param_list <- list (num_chains = 1 , keep_every = 1 )
423
+ expect_error(
424
+ bcf_model <- bcf(X_train = X_train , y_train = y_train , Z_train = Z_train ,
425
+ propensity_train = pi_train , X_test = X_test , Z_test = Z_test ,
426
+ propensity_test = pi_test , num_gfr = 0 , num_burnin = 10 ,
427
+ num_mcmc = 10 , general_params = general_param_list )
428
+ )
429
+ })
0 commit comments