Addressing first round of CRAN comments for 0.1.0 release #143

Merged: 9 commits, Feb 7, 2025
11 changes: 10 additions & 1 deletion DESCRIPTION
@@ -8,8 +8,17 @@ Authors@R:
     person("Jared", "Murray", role = "aut"),
     person("Carlos", "Carvalho", role = "aut"),
     person("Jingyu", "He", role = "aut"),
-    person("stochtree contributors", role = c("cph"))
+    person("Pedro", "Lima", role = "ctb"),
+    person("stochtree", "contributors", role = c("cph")),
+    person("Eigen", "contributors", role = c("cph"), comment = "C++ source uses the Eigen library for matrix operations, see inst/COPYRIGHTS"),
+    person("xgboost", "contributors", role = c("cph"), comment = "C++ tree code and related operations include or are inspired by code from the xgboost library, see inst/COPYRIGHTS"),
+    person("treelite", "contributors", role = c("cph"), comment = "C++ tree code and related operations include or are inspired by code from the treelite library, see inst/COPYRIGHTS"),
+    person("Microsoft", "Corporation", role = c("cph"), comment = "C++ I/O and various project structure code include or are inspired by code from the LightGBM library, which is a copyright of Microsoft, see inst/COPYRIGHTS"),
+    person("Niels", "Lohmann", role = c("cph"), comment = "C++ source uses the JSON for Modern C++ library for JSON operations, see inst/COPYRIGHTS"),
+    person("Daniel", "Lemire", role = c("cph"), comment = "C++ source uses the fast_double_parser library internally, see inst/COPYRIGHTS"),
+    person("Victor", "Zverovich", role = c("cph"), comment = "C++ source uses the fmt library internally, see inst/COPYRIGHTS")
   )
+Copyright: Copyright details for stochtree's C++ dependencies, which are vendored along with the core stochtree source code, are detailed in inst/COPYRIGHTS
 Description: Flexible stochastic tree ensemble software.
     Robust implementations of Bayesian Additive Regression Trees (BART)
     Chipman, George, McCulloch (2010) <doi:10.1214/09-AOAS285>
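Note: CRAN generally asks that copyright holders of bundled code be credited in Authors@R with a "cph" role, which is what the block above adds; the Copyright field then points at the detailed notices in inst/COPYRIGHTS. For reference, a standalone sketch of how one such entry renders; the names mirror the DESCRIPTION above, but the snippet itself is purely illustrative:

    p <- person("Niels", "Lohmann", role = "cph",
                comment = "C++ source uses the JSON for Modern C++ library")
    format(p)
    # "Niels Lohmann [cph] (C++ source uses the JSON for Modern C++ library)"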
86 changes: 40 additions & 46 deletions R/bart.R
@@ -99,8 +99,6 @@
 #' y_train <- y[train_inds]
 #' bart_model <- bart(X_train = X_train, y_train = y_train, X_test = X_test,
 #'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
-#' plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual")
-#' abline(0,1,col="red",lty=3,lwd=3)
 bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train = NULL,
                  rfx_basis_train = NULL, X_test = NULL, leaf_basis_test = NULL,
                  rfx_group_ids_test = NULL, rfx_basis_test = NULL,
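Note: the scatter-plot lines were dropped from this example (and from the two examples further down), presumably in response to a CRAN comment about example behavior. Anyone who wants the diagnostic can still draw it after fitting, using the objects created in the example above:

    plot(rowMeans(bart_model$y_hat_test), y_test, xlab = "predicted", ylab = "actual")
    abline(0, 1, col = "red", lty = 3, lwd = 3)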
@@ -110,12 +108,12 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
                  variance_forest_params = list()) {
     # Update general BART parameters
     general_params_default <- list(
-        cutpoint_grid_size = 100, standardize = T,
-        sample_sigma2_global = T, sigma2_global_init = NULL,
+        cutpoint_grid_size = 100, standardize = TRUE,
+        sample_sigma2_global = TRUE, sigma2_global_init = NULL,
         sigma2_global_shape = 0, sigma2_global_scale = 0,
         variable_weights = NULL, random_seed = -1,
-        keep_burnin = F, keep_gfr = F, keep_every = 1,
-        num_chains = 1, verbose = F
+        keep_burnin = FALSE, keep_gfr = FALSE, keep_every = 1,
+        num_chains = 1, verbose = FALSE
     )
     general_params_updated <- preprocessParams(
         general_params_default, general_params
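Note: the bulk of this diff swaps the shorthand T/F for TRUE/FALSE. CRAN flags the shorthand because T and F are ordinary variables that merely default to TRUE and FALSE and can be reassigned, whereas TRUE and FALSE are reserved words. A quick illustration:

    T <- 0                  # legal: T is just a global variable bound to TRUE
    if (T) "yes" else "no"  # now returns "no"
    TRUE <- 0               # error: TRUE is a reserved word and cannot be reassigned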
@@ -125,7 +123,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
     mean_forest_params_default <- list(
         num_trees = 200, alpha = 0.95, beta = 2.0,
         min_samples_leaf = 5, max_depth = 10,
-        sample_sigma2_leaf = T, sigma2_leaf_init = NULL,
+        sample_sigma2_leaf = TRUE, sigma2_leaf_init = NULL,
         sigma2_leaf_shape = 3, sigma2_leaf_scale = NULL,
         keep_vars = NULL, drop_vars = NULL
     )
@@ -197,7 +195,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
     }

     # Override keep_gfr if there are no MCMC samples
-    if (num_mcmc == 0) keep_gfr <- T
+    if (num_mcmc == 0) keep_gfr <- TRUE

     # Check if previous model JSON is provided and parse it if so
     has_prev_model <- !is.null(previous_model_json)
@@ -238,10 +236,10 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
     }

     # Determine whether conditional mean, variance, or both will be modeled
-    if (num_trees_variance > 0) include_variance_forest = T
-    else include_variance_forest = F
-    if (num_trees_mean > 0) include_mean_forest = T
-    else include_mean_forest = F
+    if (num_trees_variance > 0) include_variance_forest = TRUE
+    else include_variance_forest = FALSE
+    if (num_trees_mean > 0) include_mean_forest = TRUE
+    else include_mean_forest = FALSE

     # Set the variance forest priors if not set
     if (include_variance_forest) {
Expand All @@ -253,7 +251,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
}

# Override tau sampling if there is no mean forest
if (!include_mean_forest) sample_sigma_leaf <- F
if (!include_mean_forest) sample_sigma_leaf <- FALSE

# Variable weight preprocessing (and initialization if necessary)
if (is.null(variable_weights)) {
@@ -388,19 +386,19 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
     }

     # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
-    has_rfx <- F
-    has_rfx_test <- F
+    has_rfx <- FALSE
+    has_rfx_test <- FALSE
     if (!is.null(rfx_group_ids_train)) {
         group_ids_factor <- factor(rfx_group_ids_train)
         rfx_group_ids_train <- as.integer(group_ids_factor)
-        has_rfx <- T
+        has_rfx <- TRUE
         if (!is.null(rfx_group_ids_test)) {
             group_ids_factor_test <- factor(rfx_group_ids_test, levels = levels(group_ids_factor))
             if (sum(is.na(group_ids_factor_test)) > 0) {
                 stop("All random effect group labels provided in rfx_group_ids_test must be present in rfx_group_ids_train")
             }
             rfx_group_ids_test <- as.integer(group_ids_factor_test)
-            has_rfx_test <- T
+            has_rfx_test <- TRUE
         }
     }

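Note on the recoding above: factor() maps arbitrary group labels to consecutive integer codes, and building the test factor with the training levels turns any unseen label into NA, which is exactly what the stop() call catches. A standalone sketch with made-up labels:

    train_ids <- c("travis", "harris", "travis", "bexar")
    test_ids  <- c("harris", "dallas")  # "dallas" never appears in training
    f_train <- factor(train_ids)
    as.integer(f_train)                 # 3 2 3 1 (levels sort alphabetically)
    f_test <- factor(test_ids, levels = levels(f_train))
    as.integer(f_test)                  # 2 NA -- the NA triggers the error above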
@@ -432,13 +430,13 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
     }

     # Fill in rfx basis as a vector of 1s (random intercept) if a basis not provided
-    has_basis_rfx <- F
+    has_basis_rfx <- FALSE
     num_basis_rfx <- 0
     if (has_rfx) {
         if (is.null(rfx_basis_train)) {
             rfx_basis_train <- matrix(rep(1,nrow(X_train)), nrow = nrow(X_train), ncol = 1)
         } else {
-            has_basis_rfx <- T
+            has_basis_rfx <- TRUE
             num_basis_rfx <- ncol(rfx_basis_train)
         }
         num_rfx_groups <- length(unique(rfx_group_ids_train))
@@ -520,40 +518,40 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
     # Unpack model type info
     if (leaf_model_mean_forest == 0) {
         leaf_dimension = 1
-        is_leaf_constant = T
-        leaf_regression = F
+        is_leaf_constant = TRUE
+        leaf_regression = FALSE
     } else if (leaf_model_mean_forest == 1) {
         stopifnot(has_basis)
         stopifnot(ncol(leaf_basis_train) == 1)
         leaf_dimension = 1
-        is_leaf_constant = F
-        leaf_regression = T
+        is_leaf_constant = FALSE
+        leaf_regression = TRUE
     } else if (leaf_model_mean_forest == 2) {
         stopifnot(has_basis)
         stopifnot(ncol(leaf_basis_train) > 1)
         leaf_dimension = ncol(leaf_basis_train)
-        is_leaf_constant = F
-        leaf_regression = T
+        is_leaf_constant = FALSE
+        leaf_regression = TRUE
         if (sample_sigma_leaf) {
             warning("Sampling leaf scale not yet supported for multivariate leaf models, so the leaf scale parameter will not be sampled in this model.")
-            sample_sigma_leaf <- F
+            sample_sigma_leaf <- FALSE
         }
     }

     # Data
     if (leaf_regression) {
         forest_dataset_train <- createForestDataset(X_train, leaf_basis_train)
         if (has_test) forest_dataset_test <- createForestDataset(X_test, leaf_basis_test)
-        requires_basis <- T
+        requires_basis <- TRUE
     } else {
         forest_dataset_train <- createForestDataset(X_train)
         if (has_test) forest_dataset_test <- createForestDataset(X_test)
-        requires_basis <- F
+        requires_basis <- FALSE
     }
     outcome_train <- createOutcome(resid_train)

     # Random number generator (std::mt19937)
-    if (is.null(random_seed)) random_seed = sample(1:10000,1,F)
+    if (is.null(random_seed)) random_seed = sample(1:10000,1,FALSE)
     rng <- createCppRNG(random_seed)

     # Sampling data structures
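Note: in sample(1:10000, 1, FALSE) the third positional argument is replace, so the edit keeps the call's meaning (draw one seed, without replacement) while spelling the literal out. For reproducible fits, a caller can pin the seed through the general_params list initialized above; a sketch, with argument names taken from this file but the call itself purely illustrative:

    bart_model <- bart(X_train = X_train, y_train = y_train,
                       num_gfr = 10, num_burnin = 0, num_mcmc = 10,
                       general_params = list(random_seed = 42))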
@@ -630,7 +628,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
         if (requires_basis) init_values_mean_forest <- rep(0., ncol(leaf_basis_train))
         else init_values_mean_forest <- 0.
         active_forest_mean$prepare_for_sampler(forest_dataset_train, outcome_train, forest_model_mean, leaf_model_mean_forest, init_values_mean_forest)
-        active_forest_mean$adjust_residual(forest_dataset_train, outcome_train, forest_model_mean, requires_basis, F)
+        active_forest_mean$adjust_residual(forest_dataset_train, outcome_train, forest_model_mean, requires_basis, FALSE)
     }

     # Initialize the leaves of each tree in the variance forest
@@ -643,8 +641,8 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
     if (num_gfr > 0){
         for (i in 1:num_gfr) {
             # Keep all GFR samples at this stage -- remove from ForestSamples after MCMC
-            # keep_sample <- ifelse(keep_gfr, T, F)
-            keep_sample <- T
+            # keep_sample <- ifelse(keep_gfr, TRUE, FALSE)
+            keep_sample <- TRUE
             if (keep_sample) sample_counter <- sample_counter + 1
             # Print progress
             if (verbose) {
@@ -657,14 +655,14 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
                 forest_model_mean$sample_one_iteration(
                     forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
                     active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
-                    global_model_config = global_model_config, keep_forest = keep_sample, gfr = T
+                    global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
                 )
             }
             if (include_variance_forest) {
                 forest_model_variance$sample_one_iteration(
                     forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
                     active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
-                    global_model_config = global_model_config, keep_forest = keep_sample, gfr = T
+                    global_model_config = global_model_config, keep_forest = keep_sample, gfr = TRUE
                 )
             }
             if (sample_sigma_global) {
@@ -771,11 +769,11 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
             is_mcmc <- i > (num_gfr + num_burnin)
             if (is_mcmc) {
                 mcmc_counter <- i - (num_gfr + num_burnin)
-                if (mcmc_counter %% keep_every == 0) keep_sample <- T
-                else keep_sample <- F
+                if (mcmc_counter %% keep_every == 0) keep_sample <- TRUE
+                else keep_sample <- FALSE
             } else {
-                if (keep_burnin) keep_sample <- T
-                else keep_sample <- F
+                if (keep_burnin) keep_sample <- TRUE
+                else keep_sample <- FALSE
             }
             if (keep_sample) sample_counter <- sample_counter + 1
             # Print progress
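Note on the retention logic above: burn-in draws are kept only when keep_burnin is set, and post-burn-in draws are thinned to every keep_every-th iteration. A standalone sketch of which iteration indices survive (all settings made up):

    num_gfr <- 0; num_burnin <- 100; num_mcmc <- 500; keep_every <- 5
    iters <- seq_len(num_gfr + num_burnin + num_mcmc)
    mcmc_counter <- iters - (num_gfr + num_burnin)
    kept <- iters[mcmc_counter > 0 & mcmc_counter %% keep_every == 0]
    head(kept)  # 105 110 115 120 125 130 -- 100 of the 500 MCMC draws are kept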
@@ -796,14 +794,14 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
                 forest_model_mean$sample_one_iteration(
                     forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_mean,
                     active_forest = active_forest_mean, rng = rng, forest_model_config = forest_model_config_mean,
-                    global_model_config = global_model_config, keep_forest = keep_sample, gfr = F
+                    global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
                 )
             }
             if (include_variance_forest) {
                 forest_model_variance$sample_one_iteration(
                     forest_dataset = forest_dataset_train, residual = outcome_train, forest_samples = forest_samples_variance,
                     active_forest = active_forest_variance, rng = rng, forest_model_config = forest_model_config_variance,
-                    global_model_config = global_model_config, keep_forest = keep_sample, gfr = F
+                    global_model_config = global_model_config, keep_forest = keep_sample, gfr = FALSE
                 )
             }
             if (sample_sigma_global) {
@@ -994,8 +992,6 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
 #' bart_model <- bart(X_train = X_train, y_train = y_train,
 #'                    num_gfr = 10, num_burnin = 0, num_mcmc = 10)
 #' y_hat_test <- predict(bart_model, X_test)$y_hat
-#' plot(rowMeans(y_hat_test), y_test, xlab = "predicted", ylab = "actual")
-#' abline(0,1,col="red",lty=3,lwd=3)
 predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL, rfx_basis = NULL, ...){
     # Preprocess covariates
     if ((!is.data.frame(X)) && (!is.matrix(X))) {
@@ -1033,15 +1029,15 @@ predict.bartmodel <- function(object, X, leaf_basis = NULL, rfx_group_ids = NULL
     }

     # Recode group IDs to integer vector (if passed as, for example, a vector of county names, etc...)
-    has_rfx <- F
+    has_rfx <- FALSE
     if (!is.null(rfx_group_ids)) {
         rfx_unique_group_ids <- object$rfx_unique_group_ids
         group_ids_factor <- factor(rfx_group_ids, levels = rfx_unique_group_ids)
         if (sum(is.na(group_ids_factor)) > 0) {
             stop("All random effect group labels provided in rfx_group_ids must be present in rfx_group_ids_train")
         }
         rfx_group_ids <- as.integer(group_ids_factor)
-        has_rfx <- T
+        has_rfx <- TRUE
     }

     # Produce basis for the "intercept-only" random effects case
@@ -1557,8 +1553,6 @@ createBARTModelFromJsonFile <- function(json_filename){
 #' bart_json <- saveBARTModelToJsonString(bart_model)
 #' bart_model_roundtrip <- createBARTModelFromJsonString(bart_json)
 #' y_hat_mean_roundtrip <- rowMeans(predict(bart_model_roundtrip, X_train)$y_hat)
-#' plot(rowMeans(bart_model$y_hat_train), y_hat_mean_roundtrip,
-#'      xlab = "original", ylab = "roundtrip")
 createBARTModelFromJsonString <- function(json_string){
     # Load a `CppJson` object from string
     bart_json <- createCppJsonString(json_string)
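Note: with the round-trip plot removed from this example, an equivalent numeric check can be run on the objects the example creates:

    # TRUE if the JSON round-trip reproduces the fitted values exactly
    all.equal(rowMeans(bart_model$y_hat_train), y_hat_mean_roundtrip)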