Skip to content

Updated handling of scalar-valued leaf scale parameters #137

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions R/bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -484,14 +484,26 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
if (has_basis) {
if (ncol(leaf_basis_train) > 1) {
if (is.null(sigma_leaf_init)) sigma_leaf_init <- diag(var(resid_train)/(num_trees_mean), ncol(leaf_basis_train))
current_leaf_scale <- sigma_leaf_init
if (!is.matrix(sigma_leaf_init)) {
current_leaf_scale <- as.matrix(diag(sigma_leaf_init, ncol(leaf_basis_train)))
} else {
current_leaf_scale <- sigma_leaf_init
}
} else {
if (is.null(sigma_leaf_init)) sigma_leaf_init <- var(resid_train)/(num_trees_mean)
current_leaf_scale <- as.matrix(sigma_leaf_init)
if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean))
if (!is.matrix(sigma_leaf_init)) {
current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
} else {
current_leaf_scale <- sigma_leaf_init
}
}
} else {
if (is.null(sigma_leaf_init)) sigma_leaf_init <- var(resid_train)/(num_trees_mean)
current_leaf_scale <- as.matrix(sigma_leaf_init)
if (is.null(sigma_leaf_init)) sigma_leaf_init <- as.matrix(var(resid_train)/(num_trees_mean))
if (!is.matrix(sigma_leaf_init)) {
current_leaf_scale <- as.matrix(diag(sigma_leaf_init, 1))
} else {
current_leaf_scale <- sigma_leaf_init
}
}
current_sigma2 <- sigma2_init

Expand Down Expand Up @@ -522,7 +534,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
is_leaf_constant = F
leaf_regression = T
if (sample_sigma_leaf) {
stop("Sampling leaf scale not yet supported for multivariate leaf models")
warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model.")
sample_sigma_leaf <- F
}
}

Expand Down
49 changes: 39 additions & 10 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,9 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
}
}
}

# Stop if multivariate treatment is provided
if (ncol(Z_train) > 1) stop("Multivariate treatments are not currently supported")

# Random effects covariance prior
if (has_rfx) {
Expand Down Expand Up @@ -650,20 +653,20 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
# Update feature_types and covariates
feature_types <- as.integer(feature_types)
if (propensity_covariate != "none") {
feature_types <- as.integer(c(feature_types,0))
feature_types <- as.integer(c(feature_types,rep(0, ncol(propensity_train))))
X_train <- cbind(X_train, propensity_train)
if (propensity_covariate == "mu") {
variable_weights_mu <- c(variable_weights_mu, rep(1./num_cov_orig, ncol(propensity_train)))
variable_weights_tau <- c(variable_weights_tau, 0)
if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, 0)
variable_weights_tau <- c(variable_weights_tau, rep(0, ncol(propensity_train)))
if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, rep(0, ncol(propensity_train)))
} else if (propensity_covariate == "tau") {
variable_weights_mu <- c(variable_weights_mu, 0)
variable_weights_mu <- c(variable_weights_mu, rep(0, ncol(propensity_train)))
variable_weights_tau <- c(variable_weights_tau, rep(1./num_cov_orig, ncol(propensity_train)))
if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, 0)
if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, rep(0, ncol(propensity_train)))
} else if (propensity_covariate == "both") {
variable_weights_mu <- c(variable_weights_mu, rep(1./num_cov_orig, ncol(propensity_train)))
variable_weights_tau <- c(variable_weights_tau, rep(1./num_cov_orig, ncol(propensity_train)))
if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, 0)
if (include_variance_forest) variable_weights_variance <- c(variable_weights_variance, rep(0, ncol(propensity_train)))
}
if (has_test) X_test <- cbind(X_test, propensity_test)
}
Expand All @@ -690,11 +693,37 @@ bcf <- function(X_train, Z_train, y_train, propensity_train = NULL, rfx_group_id
if (is.null(variance_forest_init)) variance_forest_init <- 1.0*var(resid_train)
if (is.null(b_leaf_mu)) b_leaf_mu <- var(resid_train)/(num_trees_mu)
if (is.null(b_leaf_tau)) b_leaf_tau <- var(resid_train)/(2*num_trees_tau)
if (is.null(sigma_leaf_mu)) sigma_leaf_mu <- var(resid_train)/(num_trees_mu)
if (is.null(sigma_leaf_tau)) sigma_leaf_tau <- var(resid_train)/(2*num_trees_tau)
if (is.null(sigma_leaf_mu)) {
sigma_leaf_mu <- var(resid_train)/(num_trees_mu)
current_leaf_scale_mu <- as.matrix(sigma_leaf_mu)
} else {
if (!is.matrix(sigma_leaf_mu)) {
current_leaf_scale_mu <- as.matrix(sigma_leaf_mu)
} else {
current_leaf_scale_mu <- sigma_leaf_mu
}
}
if (is.null(sigma_leaf_tau)) {
sigma_leaf_tau <- var(resid_train)/(2*num_trees_tau)
current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train)))
} else {
if (!is.matrix(sigma_leaf_tau)) {
current_leaf_scale_tau <- as.matrix(diag(sigma_leaf_tau, ncol(Z_train)))
} else {
if (ncol(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix")
if (nrow(sigma_leaf_tau) != ncol(Z_train)) stop("sigma_leaf_init for the tau forest must have the same number of columns / rows as columns in the Z_train matrix")
current_leaf_scale_tau <- sigma_leaf_tau
}
}
current_sigma2 <- sigma2_init
current_leaf_scale_mu <- as.matrix(sigma_leaf_mu)
current_leaf_scale_tau <- as.matrix(sigma_leaf_tau)

# Switch off leaf scale sampling for multivariate treatments
if (ncol(Z_train) > 1) {
if (sample_sigma_leaf_tau) {
warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled for the treatment forest in this model.")
sample_sigma_leaf_tau <- F
}
}

# Set mu and tau leaf models / dimensions
leaf_model_mu_forest <- 0
Expand Down
59 changes: 59 additions & 0 deletions test/R/testthat/test-bart.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,65 @@ test_that("MCMC BART", {
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
general_params = general_param_list)
)

# Generate simulated data with a leaf basis
n <- 100
p <- 5
p_w <- 2
X <- matrix(runif(n*p), ncol = p)
W <- matrix(runif(n*p_w), ncol = p_w)
f_XW <- (
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5*W[,1]) +
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5*W[,1]) +
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5*W[,1]) +
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5*W[,1])
)
noise_sd <- 1
y <- f_XW + rnorm(n, 0, noise_sd)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
W_test <- W[test_inds,]
W_train <- W[train_inds,]
y_test <- y[test_inds]
y_train <- y[train_inds]

# 3 chains, thinning, leaf regression
general_param_list <- list(num_chains = 3, keep_every = 5)
mean_forest_param_list <- list(sample_sigma2_leaf = F)
expect_no_error(
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
leaf_basis_train = W_train, leaf_basis_test = W_test,
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
general_params = general_param_list,
mean_forest_params = mean_forest_param_list)
)

# 3 chains, thinning, leaf regression with a scalar leaf scale
general_param_list <- list(num_chains = 3, keep_every = 5)
mean_forest_param_list <- list(sample_sigma2_leaf = F, sigma2_leaf_init = 0.5)
expect_no_error(
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
leaf_basis_train = W_train, leaf_basis_test = W_test,
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
general_params = general_param_list,
mean_forest_params = mean_forest_param_list)
)

# 3 chains, thinning, leaf regression with a scalar leaf scale, random leaf scale
general_param_list <- list(num_chains = 3, keep_every = 5)
mean_forest_param_list <- list(sample_sigma2_leaf = T, sigma2_leaf_init = 0.5)
expect_warning(
bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
leaf_basis_train = W_train, leaf_basis_test = W_test,
num_gfr = 0, num_burnin = 10, num_mcmc = 10,
general_params = general_param_list,
mean_forest_params = mean_forest_param_list)
)
})

test_that("GFR BART", {
Expand Down
114 changes: 106 additions & 8 deletions test/R/testthat/test-bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ test_that("MCMC BCF", {
X_train <- X[train_inds,]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
pi_test <- pi[test_inds]
pi_train <- pi[train_inds]
pi_test <- pi_X[test_inds]
pi_train <- pi_X[train_inds]
mu_test <- mu_X[test_inds]
mu_train <- mu_X[train_inds]
tau_test <- tau_X[test_inds]
Expand All @@ -53,6 +53,32 @@ test_that("MCMC BCF", {
num_mcmc = 10, general_params = general_param_list)
)

# 1 chain, no thinning, matrix leaf scale parameter provided
general_param_list <- list(num_chains = 1, keep_every = 1)
mu_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5))
tau_forest_param_list <- list(sigma2_leaf_init = as.matrix(0.5))
expect_no_error(
bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
num_mcmc = 10, general_params = general_param_list,
mu_forest_params = mu_forest_param_list,
tau_forest_params = tau_forest_param_list)
)

# 1 chain, no thinning, scalar leaf scale parameter provided
general_param_list <- list(num_chains = 1, keep_every = 1)
mu_forest_param_list <- list(sigma2_leaf_init = 0.5)
tau_forest_param_list <- list(sigma2_leaf_init = 0.5)
expect_no_error(
bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
num_mcmc = 10, general_params = general_param_list,
mu_forest_params = mu_forest_param_list,
tau_forest_params = tau_forest_param_list)
)

# 3 chains, no thinning
general_param_list <- list(num_chains = 3, keep_every = 1)
expect_no_error(
Expand Down Expand Up @@ -118,8 +144,8 @@ test_that("GFR BCF", {
X_train <- X[train_inds,]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
pi_test <- pi[test_inds]
pi_train <- pi[train_inds]
pi_test <- pi_X[test_inds]
pi_train <- pi_X[train_inds]
mu_test <- mu_X[test_inds]
mu_train <- mu_X[train_inds]
tau_test <- tau_X[test_inds]
Expand Down Expand Up @@ -219,8 +245,8 @@ test_that("Warmstart BCF", {
X_train <- X[train_inds,]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
pi_test <- pi[test_inds]
pi_train <- pi[train_inds]
pi_test <- pi_X[test_inds]
pi_train <- pi_X[train_inds]
mu_test <- mu_X[test_inds]
mu_train <- mu_X[train_inds]
tau_test <- tau_X[test_inds]
Expand Down Expand Up @@ -287,8 +313,8 @@ test_that("Warmstart BCF", {
X_train <- X[train_inds,]
Z_test <- Z[test_inds]
Z_train <- Z[train_inds]
pi_test <- pi[test_inds]
pi_train <- pi[train_inds]
pi_test <- pi_X[test_inds]
pi_train <- pi_X[train_inds]
mu_test <- mu_X[test_inds]
mu_train <- mu_X[train_inds]
tau_test <- tau_X[test_inds]
Expand Down Expand Up @@ -329,3 +355,75 @@ test_that("Warmstart BCF", {
general_params = general_param_list)
)
})

test_that("Multivariate Treatment MCMC BCF", {
skip_on_cran()

# Generate simulated data
n <- 100
p <- 5
X <- matrix(runif(n*p), ncol = p)
mu_X <- (
((0 <= X[,1]) & (0.25 > X[,1])) * (-7.5) +
((0.25 <= X[,1]) & (0.5 > X[,1])) * (-2.5) +
((0.5 <= X[,1]) & (0.75 > X[,1])) * (2.5) +
((0.75 <= X[,1]) & (1 > X[,1])) * (7.5)
)
pi_X_1 <- (
((0 <= X[,1]) & (0.25 > X[,1])) * (0.2) +
((0.25 <= X[,1]) & (0.5 > X[,1])) * (0.4) +
((0.5 <= X[,1]) & (0.75 > X[,1])) * (0.6) +
((0.75 <= X[,1]) & (1 > X[,1])) * (0.8)
)
pi_X_2 <- (
((0 <= X[,2]) & (0.25 > X[,2])) * (0.8) +
((0.25 <= X[,2]) & (0.5 > X[,2])) * (0.4) +
((0.5 <= X[,2]) & (0.75 > X[,2])) * (0.6) +
((0.75 <= X[,2]) & (1 > X[,2])) * (0.2)
)
pi_X <- cbind(pi_X_1, pi_X_2)
tau_X_1 <- (
((0 <= X[,2]) & (0.25 > X[,2])) * (0.5) +
((0.25 <= X[,2]) & (0.5 > X[,2])) * (1.0) +
((0.5 <= X[,2]) & (0.75 > X[,2])) * (1.5) +
((0.75 <= X[,2]) & (1 > X[,2])) * (2.0)
)
tau_X_2 <- (
((0 <= X[,3]) & (0.25 > X[,3])) * (-0.5) +
((0.25 <= X[,3]) & (0.5 > X[,3])) * (-1.5) +
((0.5 <= X[,3]) & (0.75 > X[,3])) * (-1.0) +
((0.75 <= X[,3]) & (1 > X[,3])) * (0.0)
)
tau_X <- cbind(tau_X_1, tau_X_2)
Z_1 <- as.numeric(rbinom(n, 1, pi_X_1))
Z_2 <- as.numeric(rbinom(n, 1, pi_X_2))
Z <- cbind(Z_1, Z_2)
noise_sd <- 1
y <- mu_X + rowSums(tau_X*Z) + rnorm(n, 0, noise_sd)
test_set_pct <- 0.2
n_test <- round(test_set_pct*n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, n_test, replace = FALSE))
train_inds <- (1:n)[!((1:n) %in% test_inds)]
X_test <- X[test_inds,]
X_train <- X[train_inds,]
Z_test <- Z[test_inds,]
Z_train <- Z[train_inds,]
pi_test <- pi_X[test_inds,]
pi_train <- pi_X[train_inds,]
mu_test <- mu_X[test_inds]
mu_train <- mu_X[train_inds]
tau_test <- tau_X[test_inds,]
tau_train <- tau_X[train_inds,]
y_test <- y[test_inds]
y_train <- y[train_inds]

# 1 chain, no thinning
general_param_list <- list(num_chains = 1, keep_every = 1)
expect_error(
bcf_model <- bcf(X_train = X_train, y_train = y_train, Z_train = Z_train,
propensity_train = pi_train, X_test = X_test, Z_test = Z_test,
propensity_test = pi_test, num_gfr = 0, num_burnin = 10,
num_mcmc = 10, general_params = general_param_list)
)
})
Loading