|
| 1 | +library(stochtree) |
| 2 | + |
| 3 | +# Generate the data |
| 4 | +n <- 500 |
| 5 | +p_x <- 10 |
| 6 | +p_w <- 2 |
| 7 | +snr <- 3 |
| 8 | +X <- matrix(runif(n*p_x), ncol = p_x) |
| 9 | +W <- matrix(runif(n*p_w), ncol = p_w) |
| 10 | +f_XW <- ( |
| 11 | + ((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) + |
| 12 | + ((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) + |
| 13 | + ((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) + |
| 14 | + ((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1]) |
| 15 | +) |
| 16 | +noise_sd <- sd(f_XW) / snr |
| 17 | +y <- f_XW + rnorm(n, 0, 1)*noise_sd |
| 18 | + |
| 19 | +# Split data into test and train sets |
| 20 | +test_set_pct <- 0.2 |
| 21 | +n_test <- round(test_set_pct*n) |
| 22 | +n_train <- n - n_test |
| 23 | +test_inds <- sort(sample(1:n, n_test, replace = FALSE)) |
| 24 | +train_inds <- (1:n)[!((1:n) %in% test_inds)] |
| 25 | +X_test <- as.data.frame(X[test_inds,]) |
| 26 | +X_train <- as.data.frame(X[train_inds,]) |
| 27 | +W_test <- W[test_inds,] |
| 28 | +W_train <- W[train_inds,] |
| 29 | +y_test <- y[test_inds] |
| 30 | +y_train <- y[train_inds] |
| 31 | + |
| 32 | +# Sample BART model |
| 33 | +num_gfr <- 10 |
| 34 | +num_burnin <- 0 |
| 35 | +num_mcmc <- 100 |
| 36 | +num_samples <- num_gfr + num_burnin + num_mcmc |
| 37 | +bart_params <- list(sample_sigma_global = T, sample_sigma_leaf = F, num_trees_mean = 100) |
| 38 | +bart_model_warmstart <- stochtree::bart( |
| 39 | + X_train = X_train, W_train = W_train, y_train = y_train, X_test = X_test, W_test = W_test, |
| 40 | + num_gfr = num_gfr, num_burnin = num_burnin, num_mcmc = num_mcmc, |
| 41 | + params = bart_params |
| 42 | +) |
0 commit comments