Skip to content

Commit fc62e0c

Browse files
authored
Merge pull request #98 from StochasticTree/multivariate_bart_hotfix
Fixed error in multivariate BART sampler
2 parents 834bddc + 2c3d69e commit fc62e0c

File tree

2 files changed

+54
-2
lines changed

2 files changed

+54
-2
lines changed

R/bart.R

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,8 +321,18 @@ bart <- function(X_train, y_train, W_train = NULL, group_ids_train = NULL,
321321
if (is.null(sigma2_init)) sigma2_init <- pct_var_sigma2_init*var(resid_train)
322322
if (is.null(variance_forest_init)) variance_forest_init <- pct_var_variance_forest_init*var(resid_train)
323323
if (is.null(b_leaf)) b_leaf <- var(resid_train)/(2*num_trees_mean)
324-
if (is.null(sigma_leaf_init)) sigma_leaf_init <- var(resid_train)/(num_trees_mean)
325-
current_leaf_scale <- as.matrix(sigma_leaf_init)
324+
if (has_basis) {
325+
if (ncol(W_train) > 1) {
326+
if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(var(resid_train)/(num_trees_mean), ncol(W_train))
327+
current_leaf_scale <- sigma_leaf_init
328+
} else {
329+
if (is.null(sigma_leaf_init)) sigma_leaf_init <- var(resid_train)/(num_trees_mean)
330+
current_leaf_scale <- as.matrix(sigma_leaf_init)
331+
}
332+
} else {
333+
if (is.null(sigma_leaf_init)) sigma_leaf_init <- var(resid_train)/(num_trees_mean)
334+
current_leaf_scale <- as.matrix(sigma_leaf_init)
335+
}
326336
current_sigma2 <- sigma2_init
327337

328338
# Determine leaf model type

tools/debug/multivariate_bart_debug.R

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
library(stochtree)
2+
3+
# Generate the data
4+
n <- 500
5+
p_x <- 10
6+
p_w <- 2
7+
snr <- 3
8+
X <- matrix(runif(n*p_x), ncol = p_x)
9+
W <- matrix(runif(n*p_w), ncol = p_w)
10+
f_XW <- (
11+
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) +
12+
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) +
13+
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) +
14+
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
15+
)
16+
noise_sd <- sd(f_XW) / snr
17+
y <- f_XW + rnorm(n, 0, 1)*noise_sd
18+
19+
# Split data into test and train sets
20+
test_set_pct <- 0.2
21+
n_test <- round(test_set_pct*n)
22+
n_train <- n - n_test
23+
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
24+
train_inds <- (1:n)[!((1:n) %in% test_inds)]
25+
X_test <- as.data.frame(X[test_inds,])
26+
X_train <- as.data.frame(X[train_inds,])
27+
W_test <- W[test_inds,]
28+
W_train <- W[train_inds,]
29+
y_test <- y[test_inds]
30+
y_train <- y[train_inds]
31+
32+
# Sample BART model
33+
num_gfr <- 10
34+
num_burnin <- 0
35+
num_mcmc <- 100
36+
num_samples <- num_gfr + num_burnin + num_mcmc
37+
bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = F, num_trees_mean = 100)
38+
bart_model_warmstart <- stochtree::bart(
39+
X_train = X_train, W_train = W_train, y_train = y_train, X_test = X_test, W_test = W_test,
40+
num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc,
41+
params = bart_params
42+
)

0 commit comments

Comments
 (0)