Overhaul sampler - forest interaction to allow correct burn-in, thinning, and warm-start #104
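
The diff below adds an `active_forest` argument and a `keep_forest` flag to the per-iteration sampler bindings. A minimal sketch of the pattern that seems to be implied follows: the sampler always updates one active state, and a draw is copied into the stored container only when the iteration is past burn-in and falls on the thinning grid. The names `run_chain`, `num_burnin`, `num_mcmc`, and `keep_every` are hypothetical, and the toy update rule stands in for a real forest update; this is an illustration, not the package's implementation.

```python
import random


def run_chain(num_burnin, num_mcmc, keep_every=1, seed=0):
    """Toy sampler loop with burn-in and thinning (all names hypothetical).

    Returns the list of retained draws and the final active state.
    """
    if keep_every < 1:
        raise ValueError("keep_every must be at least 1")
    rng = random.Random(seed)
    active_state = 0.0   # stand-in for the "active forest"
    container = []       # stand-in for the container of stored samples
    total_iterations = num_burnin + num_mcmc * keep_every
    for i in range(1, total_iterations + 1):
        # The active state is updated on every iteration, kept or not.
        active_state = 0.9 * active_state + rng.gauss(0.0, 1.0)
        # Keep a draw only after burn-in and only every `keep_every` steps.
        mcmc_counter = i - num_burnin
        keep_sample = mcmc_counter > 0 and mcmc_counter % keep_every == 0
        if keep_sample:
            container.append(active_state)
    return container, active_state
```

Under these assumptions, `run_chain(num_burnin=10, num_mcmc=5, keep_every=3)` runs 25 iterations and retains 5 draws, the last of which equals the final active state.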

Merged on Dec 3, 2024 (46 commits)
46 commits
1a328ba
Initial overhaul of sampler - forest interaction to allow correct bur…
andrewherren Oct 24, 2024
5f13424
Expanding "active forest" interface
andrewherren Oct 28, 2024
e95c533
Not-fully-functional update
andrewherren Oct 29, 2024
cc20dba
Fixed bug in the active forest interface
andrewherren Oct 30, 2024
a0906c7
Updated BART and BCF function calls
andrewherren Oct 30, 2024
149620a
Fixed R unit tests to use the new "active forest" interface
andrewherren Oct 30, 2024
5d27541
Incomplete update of python interface
andrewherren Oct 31, 2024
9a249f8
Updates to the python interface
andrewherren Oct 31, 2024
ac27432
Updates to BCF and BART python implementations
andrewherren Oct 31, 2024
e096c42
Updating python interface
andrewherren Nov 4, 2024
9d2796e
Support discarding burn-in / GFR samples directly in the R bart inter…
andrewherren Nov 4, 2024
baec700
Updated BCF interface to burn-in correctly
andrewherren Nov 4, 2024
5c9077d
Updated keep_gfr documented default for BART
andrewherren Nov 4, 2024
d16b5b9
Added thinning to the R interface
andrewherren Nov 5, 2024
1da658e
Merge branch 'main' into expand-sampler-api
andrewherren Nov 5, 2024
c1d6c3c
Fixed incorrect merge
andrewherren Nov 5, 2024
624f8f2
Fixed R bart thinning
andrewherren Nov 5, 2024
baf98a7
Updated python interface and demos
andrewherren Nov 5, 2024
4d0dd07
Updated demo notebook
andrewherren Nov 5, 2024
a2ae2e1
Updated python test suite
andrewherren Nov 5, 2024
cf2af65
WIP update of the warm start overhaul
andrewherren Nov 8, 2024
6b8175f
Initial working implementation of proper warm-start functionality
andrewherren Nov 10, 2024
85464e6
Updated demo warmstart script
andrewherren Nov 11, 2024
dcff264
Merge branch 'main' into expand-sampler-api
andrewherren Nov 11, 2024
825936e
Merge branch 'main' into expand-sampler-api
andrewherren Nov 11, 2024
02afd5c
Merge branch 'main' into expand-sampler-api
andrewherren Nov 11, 2024
1ab49cf
Updated citations in python prototype interface vignette
andrewherren Nov 12, 2024
99173a6
Added ability to run multiple chains from root (still debugging python)
andrewherren Nov 13, 2024
57cbb15
Updated warmstart MCMC in R and Python
andrewherren Nov 14, 2024
8bca3fb
Merge branch 'main' into expand-sampler-api
andrewherren Nov 14, 2024
6b41dac
Extended the python tree-inspection interface to Forest objects as well
andrewherren Nov 14, 2024
729d432
Merge branch 'main' into expand-sampler-api
andrewherren Nov 15, 2024
c0e9d55
Merge branch 'main' into expand-sampler-api
andrewherren Nov 15, 2024
0b409f2
Added C++ code to remove samples from a forest container
andrewherren Nov 20, 2024
0ac99a2
Added functionality to remove samples from forest and RFX containers
andrewherren Nov 20, 2024
8962e7e
Edited warmstart interface to properly reset RFX terms
andrewherren Nov 20, 2024
253d2fb
Updated BART interface to make parallel warmstart straightforward
andrewherren Nov 22, 2024
fe15fc5
Updated parallel warmstart vignette
andrewherren Nov 22, 2024
a6c220b
Updated multi-chain vignette
andrewherren Nov 22, 2024
5f53d2e
Fixed python bugs; removing GFR samples post-hoc if requested
andrewherren Nov 23, 2024
d425ee5
Updated R functions to post-hoc remove GFR samples if requested
andrewherren Nov 23, 2024
c6fc41b
Not-fully-working BCF update
andrewherren Nov 26, 2024
6545608
Updated R package
andrewherren Dec 2, 2024
adc64ab
Updated vignettes and package
andrewherren Dec 2, 2024
dfb781e
Updated R package to pass CRAN checks (had to manually remove some pr…
andrewherren Dec 3, 2024
9cb107a
Update to pass C++ unit tests
andrewherren Dec 3, 2024
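
Several commits above mention warm-starting chains, resetting the active forest, and removing GFR samples post hoc. The sketch below illustrates that flow under stated assumptions: initialization draws are stored, each chain resets its active state from one stored draw, and the initialization draws are deleted from the container afterwards unless the caller asks to keep them. Every name here (`warm_start_chains`, `update`, `keep_gfr`, the choice of which stored draw seeds which chain) is hypothetical and the toy state is a list of floats rather than a forest.

```python
import copy
import random


def update(state, rng):
    """Toy transition: perturb every entry of the state in place."""
    for k in range(len(state)):
        state[k] += rng.gauss(0.0, 1.0)


def warm_start_chains(num_gfr, num_chains, num_mcmc, keep_gfr=False, seed=0):
    """Run `num_chains` toy chains, each seeded from a stored initialization draw."""
    if num_chains > num_gfr:
        raise ValueError("need at least one initialization draw per chain")
    rng = random.Random(seed)
    active = [0.0, 0.0]   # stand-in for the active forest
    container = []        # stand-in for the sample container

    # Phase 1: initialization draws (GFR in the PR), all stored for now.
    for _ in range(num_gfr):
        update(active, rng)
        container.append(copy.deepcopy(active))

    # Phase 2: each chain resets the active state from a stored draw.
    for chain in range(num_chains):
        start_index = num_gfr - 1 - chain   # assumed choice: latest draws first
        active = copy.deepcopy(container[start_index])
        for _ in range(num_mcmc):
            update(active, rng)
            container.append(copy.deepcopy(active))

    # Phase 3: remove the initialization draws post hoc unless requested.
    if not keep_gfr:
        del container[:num_gfr]
    return container
```

Storing the initialization draws first and deleting them at the end keeps them available as starting points for every chain, which is one plausible reason for a post-hoc removal step.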
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -15,7 +15,7 @@ Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
LinkingTo:
-cpp11
+cpp11, BH
Suggests:
doParallel,
foreach,
10 changes: 10 additions & 0 deletions NAMESPACE
@@ -17,12 +17,14 @@ export(createBARTModelFromCombinedJsonString)
export(createBARTModelFromJson)
export(createBARTModelFromJsonFile)
export(createBARTModelFromJsonString)
export(createBCFModelFromCombinedJsonString)
export(createBCFModelFromJson)
export(createBCFModelFromJsonFile)
export(createBCFModelFromJsonString)
export(createCppJson)
export(createCppJsonFile)
export(createCppJsonString)
export(createForest)
export(createForestContainer)
export(createForestCovariates)
export(createForestCovariatesFromMetadata)
@@ -55,13 +57,21 @@ export(preprocessPredictionMatrix)
export(preprocessTrainData)
export(preprocessTrainDataFrame)
export(preprocessTrainMatrix)
export(resetActiveForest)
export(resetForestModel)
export(resetRandomEffectsModel)
export(resetRandomEffectsTracker)
export(rootResetActiveForest)
export(rootResetRandomEffectsModel)
export(rootResetRandomEffectsTracker)
export(sample_sigma2_one_iteration)
export(sample_tau_one_iteration)
export(saveBARTModelToJsonFile)
export(saveBARTModelToJsonString)
export(saveBCFModelToJsonFile)
export(saveBCFModelToJsonString)
importFrom(R6,R6Class)
importFrom(stats,coef)
importFrom(stats,lm)
importFrom(stats,model.matrix)
importFrom(stats,qgamma)
433 changes: 301 additions & 132 deletions R/bart.R

Large diffs are not rendered by default.

793 changes: 590 additions & 203 deletions R/bcf.R

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions R/calibration.R
@@ -1,13 +1,13 @@
-#' Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al (2022) [1]
+#' Calibrate the scale parameter on an inverse gamma prior for the global error variance as in Chipman et al (2022)
#'
-#' [1] Chipman, H., George, E., Hahn, R., McCulloch, R., Pratola, M. and Sparapani, R. (2022). Bayesian Additive Regression Trees, Computational Approaches. In Wiley StatsRef: Statistics Reference Online (eds N. Balakrishnan, T. Colton, B. Everitt, W. Piegorsch, F. Ruggeri and J.L. Teugels). https://doi.org/10.1002/9781118445112.stat08288
+#' Chipman, H., George, E., Hahn, R., McCulloch, R., Pratola, M. and Sparapani, R. (2022). Bayesian Additive Regression Trees, Computational Approaches. In Wiley StatsRef: Statistics Reference Online (eds N. Balakrishnan, T. Colton, B. Everitt, W. Piegorsch, F. Ruggeri and J.L. Teugels). https://doi.org/10.1002/9781118445112.stat08288
#'
#' @param y Outcome to be modeled using BART, BCF or another nonparametric ensemble method.
#' @param X Covariates to be used to partition trees in an ensemble or series of ensemble.
-#' @param W [Optional] Basis used to define a "leaf regression" model for each decision tree. The "classic" BART model assumes a constant leaf parameter, which is equivalent to a "leaf regression" on a basis of all ones, though it is not necessary to pass a vector of ones, here or to the BART function. Default: `NULL`.
+#' @param W (Optional) Basis used to define a "leaf regression" model for each decision tree. The "classic" BART model assumes a constant leaf parameter, which is equivalent to a "leaf regression" on a basis of all ones, though it is not necessary to pass a vector of ones, here or to the BART function. Default: `NULL`.
#' @param nu The shape parameter for the global error variance's IG prior. The scale parameter in the Sparapani et al (2021) parameterization is defined as `nu*lambda` where `lambda` is the output of this function. Default: `3`.
-#' @param quant [Optional] Quantile of the inverse gamma prior distribution represented by a linear-regression-based overestimate of `sigma^2`. Default: `0.9`.
-#' @param standardize [Optional] Whether or not outcome should be standardized (`(y-mean(y))/sd(y)`) before calibration of `lambda`. Default: `TRUE`.
+#' @param quant (Optional) Quantile of the inverse gamma prior distribution represented by a linear-regression-based overestimate of `sigma^2`. Default: `0.9`.
+#' @param standardize (Optional) Whether or not outcome should be standardized (`(y-mean(y))/sd(y)`) before calibration of `lambda`. Default: `TRUE`.
#'
#' @return Value of `lambda` which determines the scale parameter of the global error variance prior (`sigma^2 ~ IG(nu,nu*lambda)`)
#' @export
136 changes: 128 additions & 8 deletions R/cpp11.R
@@ -144,8 +144,8 @@ rfx_label_mapper_cpp <- function(rfx_tracker) {
.Call(`_stochtree_rfx_label_mapper_cpp`, rfx_tracker)
}

-rfx_model_sample_random_effects_cpp <- function(rfx_model, rfx_dataset, residual, rfx_tracker, rfx_container, global_variance, rng) {
-invisible(.Call(`_stochtree_rfx_model_sample_random_effects_cpp`, rfx_model, rfx_dataset, residual, rfx_tracker, rfx_container, global_variance, rng))
+rfx_model_sample_random_effects_cpp <- function(rfx_model, rfx_dataset, residual, rfx_tracker, rfx_container, keep_sample, global_variance, rng) {
+invisible(.Call(`_stochtree_rfx_model_sample_random_effects_cpp`, rfx_model, rfx_dataset, residual, rfx_tracker, rfx_container, keep_sample, global_variance, rng))
}

rfx_model_predict_cpp <- function(rfx_model, rfx_dataset, rfx_tracker) {
@@ -168,6 +168,10 @@ rfx_container_num_groups_cpp <- function(rfx_container) {
.Call(`_stochtree_rfx_container_num_groups_cpp`, rfx_container)
}

+rfx_container_delete_sample_cpp <- function(rfx_container, sample_num) {
+invisible(.Call(`_stochtree_rfx_container_delete_sample_cpp`, rfx_container, sample_num))
+}
+
rfx_model_set_working_parameter_cpp <- function(rfx_model, working_param_init) {
invisible(.Call(`_stochtree_rfx_model_set_working_parameter_cpp`, rfx_model, working_param_init))
}
@@ -216,6 +220,22 @@ rfx_label_mapper_to_list_cpp <- function(label_mapper_ptr) {
.Call(`_stochtree_rfx_label_mapper_to_list_cpp`, label_mapper_ptr)
}

+reset_rfx_model_cpp <- function(rfx_model, rfx_container, sample_num) {
+invisible(.Call(`_stochtree_reset_rfx_model_cpp`, rfx_model, rfx_container, sample_num))
+}
+
+reset_rfx_tracker_cpp <- function(tracker, dataset, residual, rfx_model) {
+invisible(.Call(`_stochtree_reset_rfx_tracker_cpp`, tracker, dataset, residual, rfx_model))
+}
+
+root_reset_rfx_tracker_cpp <- function(tracker, dataset, residual, rfx_model) {
+invisible(.Call(`_stochtree_root_reset_rfx_tracker_cpp`, tracker, dataset, residual, rfx_model))
+}
+
+active_forest_cpp <- function(num_trees, output_dimension, is_leaf_constant, is_exponentiated) {
+.Call(`_stochtree_active_forest_cpp`, num_trees, output_dimension, is_leaf_constant, is_exponentiated)
+}
+
forest_container_cpp <- function(num_trees, output_dimension, is_leaf_constant, is_exponentiated) {
.Call(`_stochtree_forest_container_cpp`, num_trees, output_dimension, is_leaf_constant, is_exponentiated)
}
@@ -280,6 +300,10 @@ is_leaf_constant_forest_container_cpp <- function(forest_samples) {
.Call(`_stochtree_is_leaf_constant_forest_container_cpp`, forest_samples)
}

+is_exponentiated_forest_container_cpp <- function(forest_samples) {
+.Call(`_stochtree_is_exponentiated_forest_container_cpp`, forest_samples)
+}
+
all_roots_forest_container_cpp <- function(forest_samples, forest_num) {
.Call(`_stochtree_all_roots_forest_container_cpp`, forest_samples, forest_num)
}
@@ -412,6 +436,10 @@ propagate_basis_update_forest_container_cpp <- function(data, residual, forest_s
invisible(.Call(`_stochtree_propagate_basis_update_forest_container_cpp`, data, residual, forest_samples, tracker, forest_num))
}

+remove_sample_forest_container_cpp <- function(forest_samples, forest_num) {
+invisible(.Call(`_stochtree_remove_sample_forest_container_cpp`, forest_samples, forest_num))
+}
+
predict_forest_cpp <- function(forest_samples, dataset) {
.Call(`_stochtree_predict_forest_cpp`, forest_samples, dataset)
}
@@ -428,6 +456,98 @@ predict_forest_raw_single_tree_cpp <- function(forest_samples, dataset, forest_n
.Call(`_stochtree_predict_forest_raw_single_tree_cpp`, forest_samples, dataset, forest_num, tree_num)
}

+predict_active_forest_cpp <- function(active_forest, dataset) {
+.Call(`_stochtree_predict_active_forest_cpp`, active_forest, dataset)
+}
+
+predict_raw_active_forest_cpp <- function(active_forest, dataset) {
+.Call(`_stochtree_predict_raw_active_forest_cpp`, active_forest, dataset)
+}
+
+output_dimension_active_forest_cpp <- function(active_forest) {
+.Call(`_stochtree_output_dimension_active_forest_cpp`, active_forest)
+}
+
+average_max_depth_active_forest_cpp <- function(active_forest) {
+.Call(`_stochtree_average_max_depth_active_forest_cpp`, active_forest)
+}
+
+num_trees_active_forest_cpp <- function(active_forest) {
+.Call(`_stochtree_num_trees_active_forest_cpp`, active_forest)
+}
+
+ensemble_tree_max_depth_active_forest_cpp <- function(active_forest, tree_num) {
+.Call(`_stochtree_ensemble_tree_max_depth_active_forest_cpp`, active_forest, tree_num)
+}
+
+is_leaf_constant_active_forest_cpp <- function(active_forest) {
+.Call(`_stochtree_is_leaf_constant_active_forest_cpp`, active_forest)
+}
+
+is_exponentiated_active_forest_cpp <- function(active_forest) {
+.Call(`_stochtree_is_exponentiated_active_forest_cpp`, active_forest)
+}
+
+all_roots_active_forest_cpp <- function(active_forest) {
+.Call(`_stochtree_all_roots_active_forest_cpp`, active_forest)
+}
+
+set_leaf_value_active_forest_cpp <- function(active_forest, leaf_value) {
+invisible(.Call(`_stochtree_set_leaf_value_active_forest_cpp`, active_forest, leaf_value))
+}
+
+set_leaf_vector_active_forest_cpp <- function(active_forest, leaf_vector) {
+invisible(.Call(`_stochtree_set_leaf_vector_active_forest_cpp`, active_forest, leaf_vector))
+}
+
+add_numeric_split_tree_value_active_forest_cpp <- function(active_forest, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value) {
+invisible(.Call(`_stochtree_add_numeric_split_tree_value_active_forest_cpp`, active_forest, tree_num, leaf_num, feature_num, split_threshold, left_leaf_value, right_leaf_value))
+}
+
+add_numeric_split_tree_vector_active_forest_cpp <- function(active_forest, tree_num, leaf_num, feature_num, split_threshold, left_leaf_vector, right_leaf_vector) {
+invisible(.Call(`_stochtree_add_numeric_split_tree_vector_active_forest_cpp`, active_forest, tree_num, leaf_num, feature_num, split_threshold, left_leaf_vector, right_leaf_vector))
+}
+
+get_tree_leaves_active_forest_cpp <- function(active_forest, tree_num) {
+.Call(`_stochtree_get_tree_leaves_active_forest_cpp`, active_forest, tree_num)
+}
+
+get_tree_split_counts_active_forest_cpp <- function(active_forest, tree_num, num_features) {
+.Call(`_stochtree_get_tree_split_counts_active_forest_cpp`, active_forest, tree_num, num_features)
+}
+
+get_overall_split_counts_active_forest_cpp <- function(active_forest, num_features) {
+.Call(`_stochtree_get_overall_split_counts_active_forest_cpp`, active_forest, num_features)
+}
+
+get_granular_split_count_array_active_forest_cpp <- function(active_forest, num_features) {
+.Call(`_stochtree_get_granular_split_count_array_active_forest_cpp`, active_forest, num_features)
+}
+
+initialize_forest_model_active_forest_cpp <- function(data, residual, active_forest, tracker, init_values, leaf_model_int) {
+invisible(.Call(`_stochtree_initialize_forest_model_active_forest_cpp`, data, residual, active_forest, tracker, init_values, leaf_model_int))
+}
+
+adjust_residual_active_forest_cpp <- function(data, residual, active_forest, tracker, requires_basis, add) {
+invisible(.Call(`_stochtree_adjust_residual_active_forest_cpp`, data, residual, active_forest, tracker, requires_basis, add))
+}
+
+propagate_basis_update_active_forest_cpp <- function(data, residual, active_forest, tracker) {
+invisible(.Call(`_stochtree_propagate_basis_update_active_forest_cpp`, data, residual, active_forest, tracker))
+}
+
+reset_active_forest_cpp <- function(active_forest, forest_samples, forest_num) {
+invisible(.Call(`_stochtree_reset_active_forest_cpp`, active_forest, forest_samples, forest_num))
+}
+
+reset_forest_model_cpp <- function(forest_tracker, forest, data, residual, is_mean_model) {
+invisible(.Call(`_stochtree_reset_forest_model_cpp`, forest_tracker, forest, data, residual, is_mean_model))
+}
+
+root_reset_active_forest_cpp <- function(active_forest) {
+invisible(.Call(`_stochtree_root_reset_active_forest_cpp`, active_forest))
+}
+
forest_container_get_max_leaf_index_cpp <- function(forest_container, forest_num) {
.Call(`_stochtree_forest_container_get_max_leaf_index_cpp`, forest_container, forest_num)
}
@@ -436,20 +556,20 @@ compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums)
.Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums)
}

-sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized) {
-invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized))
+sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) {
+invisible(.Call(`_stochtree_sample_gfr_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized))
}

-sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized) {
-invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized))
+sample_mcmc_one_iteration_cpp <- function(data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized) {
+invisible(.Call(`_stochtree_sample_mcmc_one_iteration_cpp`, data, residual, forest_samples, active_forest, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, keep_forest, pre_initialized))
}

sample_sigma2_one_iteration_cpp <- function(residual, dataset, rng, a, b) {
.Call(`_stochtree_sample_sigma2_one_iteration_cpp`, residual, dataset, rng, a, b)
}

-sample_tau_one_iteration_cpp <- function(forest_samples, rng, a, b, sample_num) {
-.Call(`_stochtree_sample_tau_one_iteration_cpp`, forest_samples, rng, a, b, sample_num)
+sample_tau_one_iteration_cpp <- function(active_forest, rng, a, b) {
+.Call(`_stochtree_sample_tau_one_iteration_cpp`, active_forest, rng, a, b)
}

rng_cpp <- function(random_seed) {