Skip to content

Commit 8c5f0a2

Browse files
authored
Merge pull request #137 from StochasticTree/leaf-scale-hotfix
Updated handling of scalar-valued leaf scale parameters
2 parents 744ca40 + d8fdb72 commit 8c5f0a2

File tree

4 files changed

+223
-24
lines changed

4 files changed

+223
-24
lines changed

R/bart.R

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,14 +484,26 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
484484
if (has_basis) {
485485
if (ncol(leaf_basis_train) > 1) {
486486
if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(var(resid_train)/(num_trees_mean), ncol(leaf_basis_train))
487-
current_leaf_scale <- sigma_leaf_init
487+
if (!is.matrix(sigma_leaf_init)) {
488+
current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train)))
489+
} else {
490+
current_leaf_scale <- sigma_leaf_init
491+
}
488492
} else {
489-
if (is.null(sigma_leaf_init)) sigma_leaf_init <- var(resid_train)/(num_trees_mean)
490-
current_leaf_scale <- as.matrix(sigma_leaf_init)
493+
if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean))
494+
if (!is.matrix(sigma_leaf_init)) {
495+
current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
496+
} else {
497+
current_leaf_scale <- sigma_leaf_init
498+
}
491499
}
492500
} else {
493-
if (is.null(sigma_leaf_init)) sigma_leaf_init <- var(resid_train)/(num_trees_mean)
494-
current_leaf_scale <- as.matrix(sigma_leaf_init)
501+
if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean))
502+
if (!is.matrix(sigma_leaf_init)) {
503+
current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
504+
} else {
505+
current_leaf_scale <- sigma_leaf_init
506+
}
495507
}
496508
current_sigma2 <- sigma2_init
497509

@@ -522,7 +534,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
522534
is_leaf_constant = F
523535
leaf_regression = T
524536
if (sample_sigma_leaf) {
525-
stop("Sampling leaf scale not yet supported for multivariate leaf models")
537+
warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model.")
538+
sample_sigma_leaf <- F
526539
}
527540
}
528541

R/bcf.R

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,9 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
546546
}
547547
}
548548
}
549+
550+
# Stop if multivariate treatment is provided
551+
if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported")
549552

550553
# Random effects covariance prior
551554
if (has_rfx) {
@@ -650,20 +653,20 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
650653
# Update feature_types and covariates
651654
feature_types <- as.integer(feature_types)
652655
if (propensity_covariate != "none") {
653-
feature_types <- as.integer(c(feature_types,0))
656+
feature_types <- as.integer(c(feature_types,rep(0, ncol(propensity_train))))
654657
X_train <- cbind(X_train, propensity_train)
655658
if (propensity_covariate == "mu") {
656659
variable_weights_mu <- c(variable_weights_mu, rep(1./num_cov_orig, ncol(propensity_train)))
657-
variable_weights_tau <- c(variable_weights_tau, 0)
658-
if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, 0)
660+
variable_weights_tau <- c(variable_weights_tau, rep(0, ncol(propensity_train)))
661+
if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, rep(0, ncol(propensity_train)))
659662
} else if (propensity_covariate == "tau") {
660-
variable_weights_mu <- c(variable_weights_mu, 0)
663+
variable_weights_mu <- c(variable_weights_mu, rep(0, ncol(propensity_train)))
661664
variable_weights_tau <- c(variable_weights_tau, rep(1./num_cov_orig, ncol(propensity_train)))
662-
if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, 0)
665+
if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, rep(0, ncol(propensity_train)))
663666
} else if (propensity_covariate == "both") {
664667
variable_weights_mu <- c(variable_weights_mu, rep(1./num_cov_orig, ncol(propensity_train)))
665668
variable_weights_tau <- c(variable_weights_tau, rep(1./num_cov_orig, ncol(propensity_train)))
666-
if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, 0)
669+
if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, rep(0, ncol(propensity_train)))
667670
}
668671
if (has_test) X_test <- cbind(X_test, propensity_test)
669672
}
@@ -690,11 +693,37 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
690693
if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train)
691694
if (is.null(b_leaf_mu)) b_leaf_mu <- var(resid_train)/(num_trees_mu)
692695
if (is.null(b_leaf_tau)) b_leaf_tau <- var(resid_train)/(2*num_trees_tau)
693-
if (is.null(sigma_leaf_mu)) sigma_leaf_mu <- var(resid_train)/(num_trees_mu)
694-
if (is.null(sigma_leaf_tau)) sigma_leaf_tau <- var(resid_train)/(2*num_trees_tau)
696+
if (is.null(sigma_leaf_mu)) {
697+
sigma_leaf_mu <- var(resid_train)/(num_trees_mu)
698+
current_leaf_scale_mu <- as.matrix(sigma_leaf_mu)
699+
} else {
700+
if (!is.matrix(sigma_leaf_mu)) {
701+
current_leaf_scale_mu <- as.matrix(sigma_leaf_mu)
702+
} else {
703+
current_leaf_scale_mu <- sigma_leaf_mu
704+
}
705+
}
706+
if (is.null(sigma_leaf_tau)) {
707+
sigma_leaf_tau <- var(resid_train)/(2*num_trees_tau)
708+
current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train)))
709+
} else {
710+
if (!is.matrix(sigma_leaf_tau)) {
711+
current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train)))
712+
} else {
713+
if (ncol(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix")
714+
if (nrow(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix")
715+
current_leaf_scale_tau <- sigma_leaf_tau
716+
}
717+
}
695718
current_sigma2 <- sigma2_init
696-
current_leaf_scale_mu <- as.matrix(sigma_leaf_mu)
697-
current_leaf_scale_tau <- as.matrix(sigma_leaf_tau)
719+
720+
# Switch off leaf scale sampling for multivariate treatments
721+
if (ncol(Z_train) > 1) {
722+
if (sample_sigma_leaf_tau) {
723+
warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.")
724+
sample_sigma_leaf_tau <- F
725+
}
726+
}
698727

699728
# Set mu and tau leaf models / dimensions
700729
leaf_model_mu_forest <- 0

test/R/testthat/test-bart.R

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,65 @@ test_that("MCMC BART", {
5454
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
5555
general_params = general_param_list)
5656
)
57+
58+
# Generate simulated data with a leaf basis
59+
n <- 100
60+
p <- 5
61+
p_w <- 2
62+
X <- matrix(runif(n*p), ncol = p)
63+
W <- matrix(runif(n*p_w), ncol = p_w)
64+
f_XW <- (
65+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) +
66+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) +
67+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) +
68+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
69+
)
70+
noise_sd <- 1
71+
y <- f_XW + rnorm(n, 0, noise_sd)
72+
test_set_pct <- 0.2
73+
n_test <- round(test_set_pct*n)
74+
n_train <- n - n_test
75+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
76+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
77+
X_test <- X[test_inds,]
78+
X_train <- X[train_inds,]
79+
W_test <- W[test_inds,]
80+
W_train <- W[train_inds,]
81+
y_test <- y[test_inds]
82+
y_train <- y[train_inds]
83+
84+
# 3 chains, thinning, leaf regression
85+
general_param_list <- list(num_chains = 3, keep_every = 5)
86+
mean_forest_param_list <- list(sample_sigma2_leaf = F)
87+
expect_no_error(
88+
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
89+
leaf_basis_train = W_train, leaf_basis_test = W_test,
90+
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
91+
general_params = general_param_list,
92+
mean_forest_params = mean_forest_param_list)
93+
)
94+
95+
# 3 chains, thinning, leaf regression with a scalar leaf scale
96+
general_param_list <- list(num_chains = 3, keep_every = 5)
97+
mean_forest_param_list <- list(sample_sigma2_leaf = F, sigma2_leaf_init = 0.5)
98+
expect_no_error(
99+
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
100+
leaf_basis_train = W_train, leaf_basis_test = W_test,
101+
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
102+
general_params = general_param_list,
103+
mean_forest_params = mean_forest_param_list)
104+
)
105+
106+
# 3 chains, thinning, leaf regression with a scalar leaf scale, random leaf scale
107+
general_param_list <- list(num_chains = 3, keep_every = 5)
108+
mean_forest_param_list <- list(sample_sigma2_leaf = T, sigma2_leaf_init = 0.5)
109+
expect_warning(
110+
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
111+
leaf_basis_train = W_train, leaf_basis_test = W_test,
112+
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
113+
general_params = general_param_list,
114+
mean_forest_params = mean_forest_param_list)
115+
)
57116
})
58117

59118
test_that("GFR BART", {

test/R/testthat/test-bcf.R

Lines changed: 106 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ test_that("MCMC BCF", {
3535
X_train <- X[train_inds,]
3636
Z_test <- Z[test_inds]
3737
Z_train <- Z[train_inds]
38-
pi_test <- pi[test_inds]
39-
pi_train <- pi[train_inds]
38+
pi_test <- pi_X[test_inds]
39+
pi_train <- pi_X[train_inds]
4040
mu_test <- mu_X[test_inds]
4141
mu_train <- mu_X[train_inds]
4242
tau_test <- tau_X[test_inds]
@@ -53,6 +53,32 @@ test_that("MCMC BCF", {
5353
num_mcmc = 10, general_params = general_param_list)
5454
)
5555

56+
# 1 chain, no thinning, matrix leaf scale parameter provided
57+
general_param_list <- list(num_chains = 1, keep_every = 1)
58+
mu_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5))
59+
tau_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5))
60+
expect_no_error(
61+
bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
62+
propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
63+
propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
64+
num_mcmc = 10, general_params = general_param_list,
65+
mu_forest_params = mu_forest_param_list,
66+
tau_forest_params = tau_forest_param_list)
67+
)
68+
69+
# 1 chain, no thinning, scalar leaf scale parameter provided
70+
general_param_list <- list(num_chains = 1, keep_every = 1)
71+
mu_forest_param_list <- list(sigma2_leaf_init = 0.5)
72+
tau_forest_param_list <- list(sigma2_leaf_init = 0.5)
73+
expect_no_error(
74+
bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
75+
propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
76+
propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
77+
num_mcmc = 10, general_params = general_param_list,
78+
mu_forest_params = mu_forest_param_list,
79+
tau_forest_params = tau_forest_param_list)
80+
)
81+
5682
# 3 chains, no thinning
5783
general_param_list <- list(num_chains = 3, keep_every = 1)
5884
expect_no_error(
@@ -118,8 +144,8 @@ test_that("GFR BCF", {
118144
X_train <- X[train_inds,]
119145
Z_test <- Z[test_inds]
120146
Z_train <- Z[train_inds]
121-
pi_test <- pi[test_inds]
122-
pi_train <- pi[train_inds]
147+
pi_test <- pi_X[test_inds]
148+
pi_train <- pi_X[train_inds]
123149
mu_test <- mu_X[test_inds]
124150
mu_train <- mu_X[train_inds]
125151
tau_test <- tau_X[test_inds]
@@ -219,8 +245,8 @@ test_that("Warmstart BCF", {
219245
X_train <- X[train_inds,]
220246
Z_test <- Z[test_inds]
221247
Z_train <- Z[train_inds]
222-
pi_test <- pi[test_inds]
223-
pi_train <- pi[train_inds]
248+
pi_test <- pi_X[test_inds]
249+
pi_train <- pi_X[train_inds]
224250
mu_test <- mu_X[test_inds]
225251
mu_train <- mu_X[train_inds]
226252
tau_test <- tau_X[test_inds]
@@ -287,8 +313,8 @@ test_that("Warmstart BCF", {
287313
X_train <- X[train_inds,]
288314
Z_test <- Z[test_inds]
289315
Z_train <- Z[train_inds]
290-
pi_test <- pi[test_inds]
291-
pi_train <- pi[train_inds]
316+
pi_test <- pi_X[test_inds]
317+
pi_train <- pi_X[train_inds]
292318
mu_test <- mu_X[test_inds]
293319
mu_train <- mu_X[train_inds]
294320
tau_test <- tau_X[test_inds]
@@ -329,3 +355,75 @@ test_that("Warmstart BCF", {
329355
general_params = general_param_list)
330356
)
331357
})
358+
359+
test_that("Multivariate Treatment MCMC BCF", {
  skip_on_cran()

  # Piecewise-constant function of x over the quartile bins
  # [0, 0.25), [0.25, 0.5), [0.5, 0.75), [0.75, 1): returns levels[k] for
  # a value of x that falls in the k-th bin. Inputs here come from runif(),
  # which stays strictly inside (0, 1).
  quartile_step <- function(x, levels) {
    levels[findInterval(x, c(0, 0.25, 0.5, 0.75))]
  }

  # Generate simulated data with a two-column (multivariate) treatment
  n <- 100
  p <- 5
  X <- matrix(runif(n * p), ncol = p)
  mu_X <- quartile_step(X[, 1], c(-7.5, -2.5, 2.5, 7.5))
  pi_X_1 <- quartile_step(X[, 1], c(0.2, 0.4, 0.6, 0.8))
  pi_X_2 <- quartile_step(X[, 2], c(0.8, 0.4, 0.6, 0.2))
  pi_X <- cbind(pi_X_1, pi_X_2)
  tau_X_1 <- quartile_step(X[, 2], c(0.5, 1.0, 1.5, 2.0))
  tau_X_2 <- quartile_step(X[, 3], c(-0.5, -1.5, -1.0, 0.0))
  tau_X <- cbind(tau_X_1, tau_X_2)
  Z_1 <- as.numeric(rbinom(n, 1, pi_X_1))
  Z_2 <- as.numeric(rbinom(n, 1, pi_X_2))
  Z <- cbind(Z_1, Z_2)
  noise_sd <- 1
  y <- mu_X + rowSums(tau_X * Z) + rnorm(n, 0, noise_sd)

  # Train / test split; seq_len() avoids the 1:0 pitfall of 1:n
  test_set_pct <- 0.2
  n_test <- round(test_set_pct * n)
  test_inds <- sort(sample(seq_len(n), n_test, replace = FALSE))
  train_inds <- setdiff(seq_len(n), test_inds)

  # drop = FALSE keeps every subset a matrix regardless of how many rows
  # are selected, so the treatment is always passed as a two-column matrix
  X_test <- X[test_inds, , drop = FALSE]
  X_train <- X[train_inds, , drop = FALSE]
  Z_test <- Z[test_inds, , drop = FALSE]
  Z_train <- Z[train_inds, , drop = FALSE]
  pi_test <- pi_X[test_inds, , drop = FALSE]
  pi_train <- pi_X[train_inds, , drop = FALSE]
  y_train <- y[train_inds]

  # 1 chain, no thinning. The call is expected to fail with the
  # multivariate-treatment error; matching on the message keeps the test
  # from passing on an unrelated failure.
  general_param_list <- list(num_chains = 1, keep_every = 1)
  expect_error(
    bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
                     propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
                     propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
                     num_mcmc = 10, general_params = general_param_list),
    regexp = "Multivariate treatments are not currently supported",
    fixed = TRUE
  )
})

0 commit comments

Comments
 (0)