
Updating the Python package for the 0.1.1 release #151

Merged 37 commits on Mar 27, 2025
Commits
bdcae37
Formatting the core stochtree python code with ruff
andrewherren Feb 14, 2025
ea7407e
Fixing python code issues
andrewherren Feb 14, 2025
486b2fe
Formatting core stochtree python imports with ruff
andrewherren Feb 14, 2025
6595736
Used ruff to format python demo notebooks
andrewherren Feb 15, 2025
d6afe59
Used ruff to format imports in python demo notebooks
andrewherren Feb 15, 2025
6759c46
Updated demo notebooks
andrewherren Feb 15, 2025
859e04f
Removed pandas import from calibration python file
andrewherren Feb 15, 2025
207a3d0
Updated R config code and added python ForestModelConfig object code
andrewherren Feb 18, 2025
c6f51f6
Formatted new python code with ruff
andrewherren Feb 18, 2025
4f4e5bf
Added global variance parameter config
andrewherren Feb 18, 2025
c73a10f
Add utility unit tests
andrewherren Feb 20, 2025
3111b4c
Formatted test code
andrewherren Feb 20, 2025
3073a9c
Added config tests
andrewherren Feb 20, 2025
d86ae36
Partial update of sampler interface to use config objects
andrewherren Feb 20, 2025
eca9c32
Add flexibility in use of config objects in R and python interfaces
andrewherren Feb 23, 2025
74b8d6a
Refactored pre_initialized parameter out of python interface
andrewherren Feb 23, 2025
d383760
Updated python package to use config objects
andrewherren Feb 26, 2025
082b19a
Removed unnecessary code from test
andrewherren Feb 26, 2025
a9bb16d
Merge branch 'main' into python-update-0.1.1
andrewherren Feb 27, 2025
0cd663a
Update how GoogleTest is used
andrewherren Feb 28, 2025
1843ab3
Added ability to propagate scalar-valued leaf variance parameters
andrewherren Feb 28, 2025
d9efdc4
Refactored "mu" and "tau" notation out of python BCF function signature
andrewherren Feb 28, 2025
cad3dab
Added python wrappers for C++ random effects sampler
andrewherren Mar 4, 2025
96a6126
Added unit tests and demo script for random effects
andrewherren Mar 4, 2025
2aa3e90
Reformatted python code
andrewherren Mar 4, 2025
79b7d95
Refactoring "basis" parameter names in python BART interface
andrewherren Mar 6, 2025
a084306
Partial addition of random effects to BART interface
andrewherren Mar 6, 2025
93cbd65
Updated random effects so that serialization works as in R
andrewherren Mar 13, 2025
f454c1f
Placeholder kernel computations and R / Python comparison scripts
andrewherren Mar 16, 2025
f5b2c72
Merge branch 'main' into python-update-0.1.1
andrewherren Mar 17, 2025
a6caffb
Updated demo scripts
andrewherren Mar 17, 2025
ae13cb0
Updated demos and forest initialization python code
andrewherren Mar 17, 2025
eb32b11
Added kernel indices module
andrewherren Mar 18, 2025
e2413f0
Updated R kernel code
andrewherren Mar 18, 2025
c6d2aa0
Added kernel debugging script and updated kernel unit tests
andrewherren Mar 18, 2025
7bc2188
Updated kernel code
andrewherren Mar 19, 2025
a5fdcae
Updated R package
andrewherren Mar 27, 2025
11 changes: 8 additions & 3 deletions CMakeLists.txt
Expand Up @@ -132,12 +132,17 @@ endif()

# Build C++ test program
if(BUILD_TEST)
# Download the GoogleTest dependency if necessary
# Check if user specified a local clone of the GoogleTest repo, use Github repo if not
if (NOT DEFINED GOOGLETEST_GIT_REPO)
set(GOOGLETEST_GIT_REPO https://github.com/google/googletest.git)
endif()

# Fetch and install GoogleTest dependency
include(FetchContent)
FetchContent_Declare(
googletest
GIT_REPOSITORY https://github.com/google/googletest.git
GIT_TAG e2239ee6043f73722e7aa812a459f54a28552929 # release-1.14.0
GIT_REPOSITORY ${GOOGLETEST_GIT_REPO}
GIT_TAG 6910c9d9165801d8827d628cb72eb7ea9dd538c5 # release-1.16.0
)
# For Windows: Prevent overriding the parent project's compiler/linker settings
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
@@ -1,3 +1,8 @@
# stochtree 0.1.2

* Fixed indexing bug in cleanup of grow-from-root (GFR) samples in BART and BCF models
* Avoid using covariate preprocessor in `computeForestLeafIndices` function when a `ForestSamples` object is provided

# stochtree 0.1.1

* Fixed initialization bug in several R package code examples for random effects models
Expand Down
2 changes: 1 addition & 1 deletion R/bart.R
Expand Up @@ -445,7 +445,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
}
if (has_rfx_test) {
if (is.null(rfx_basis_test)) {
if (!is.null(rfx_basis_train)) {
if (has_basis_rfx) {
stop("Random effects basis provided for training set, must also be provided for the test set")
}
rfx_basis_test <- matrix(rep(1,nrow(X_test)), nrow = nrow(X_test), ncol = 1)
Expand Down
40 changes: 37 additions & 3 deletions R/config.R
Expand Up @@ -61,7 +61,7 @@ ForestModelConfig <- R6::R6Class(

#' Create a new ForestModelConfig object.
#'
#' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
#' @param feature_types Vector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
#' @param num_trees Number of trees in the forest being sampled
#' @param num_features Number of features in training dataset
#' @param num_observations Number of observations in training dataset
Expand Down Expand Up @@ -98,6 +98,12 @@ ForestModelConfig <- R6::R6Class(
warning("`variable_weights` not provided, will be assumed to be equal-weighted")
variable_weights <- rep(1/num_features, num_features)
}
if (is.null(num_trees)) {
stop("num_trees must be provided")
}
if (is.null(num_observations)) {
stop("num_observations must be provided")
}
if (num_features != length(feature_types)) {
stop("`feature_types` must have `num_features` total elements")
}
Expand Down Expand Up @@ -175,14 +181,14 @@ ForestModelConfig <- R6::R6Class(
},

#' @description
#' Update root node split probability in tree prior
#' Update minimum number of samples per leaf node in the tree prior
#' @param min_samples_leaf Minimum number of samples in a tree leaf
update_min_samples_leaf = function(min_samples_leaf) {
self$min_samples_leaf <- min_samples_leaf
},

#' @description
#' Update root node split probability in tree prior
#' Update max depth in the tree prior
#' @param max_depth Maximum depth of any tree in the ensemble in the model
update_max_depth = function(max_depth) {
self$max_depth <- max_depth
Expand Down Expand Up @@ -243,6 +249,27 @@ ForestModelConfig <- R6::R6Class(
return(self$variable_weights)
},

#' @description
#' Query number of trees
#' @returns Number of trees in a forest
get_num_trees = function() {
return(self$num_trees)
},

#' @description
#' Query number of features
#' @returns Number of features in a forest model training set
get_num_features = function() {
return(self$num_features)
},

#' @description
#' Query number of observations
#' @returns Number of observations in a forest model training set
get_num_observations = function() {
return(self$num_observations)
},

#' @description
#' Query root node split probability in tree prior for this ForestModelConfig object
#' @returns Root node split probability in tree prior
Expand Down Expand Up @@ -271,6 +298,13 @@ ForestModelConfig <- R6::R6Class(
return(self$max_depth)
},

#' @description
#' Query (integer-coded) type of leaf model
#' @returns Integer coded leaf model type
get_leaf_model_type = function() {
return(self$leaf_model_type)
},

#' @description
#' Query scale parameter used in Gaussian leaf models for this ForestModelConfig object
#' @returns Scale parameter used in Gaussian leaf models
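The constructor checks and getters added to the R `ForestModelConfig` above could be mirrored in Python roughly as follows. This is a minimal sketch only: `ForestConfigSketch` and its fields are hypothetical names chosen for illustration, not stochtree's actual Python API.

```python
from typing import Optional, Sequence


class ForestConfigSketch:
    """Hypothetical stand-in for a forest model config object (illustration only)."""

    def __init__(
        self,
        feature_types: Sequence[int],
        num_trees: Optional[int] = None,
        num_features: Optional[int] = None,
        num_observations: Optional[int] = None,
    ):
        # Mirror the required-argument checks shown in the R constructor
        if num_trees is None:
            raise ValueError("num_trees must be provided")
        if num_observations is None:
            raise ValueError("num_observations must be provided")
        if num_features != len(feature_types):
            raise ValueError("`feature_types` must have `num_features` total elements")
        self.feature_types = list(feature_types)
        self.num_trees = num_trees
        self.num_features = num_features
        self.num_observations = num_observations

    # Simple query methods, analogous to the getters added in the diff
    def get_num_trees(self) -> int:
        return self.num_trees

    def get_num_features(self) -> int:
        return self.num_features

    def get_num_observations(self) -> int:
        return self.num_observations


config = ForestConfigSketch(
    feature_types=[0, 0, 1], num_trees=50, num_features=3, num_observations=100
)
print(config.get_num_trees(), config.get_num_features(), config.get_num_observations())
```

Failing fast in the constructor means a missing `num_trees` or `num_observations` surfaces as a clear error at configuration time rather than deep inside the sampler.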
Expand Down
16 changes: 16 additions & 0 deletions R/cpp11.R
Expand Up @@ -596,6 +596,22 @@ update_max_depth_tree_prior_cpp <- function(tree_prior_ptr, max_depth) {
invisible(.Call(`_stochtree_update_max_depth_tree_prior_cpp`, tree_prior_ptr, max_depth))
}

get_alpha_tree_prior_cpp <- function(tree_prior_ptr) {
.Call(`_stochtree_get_alpha_tree_prior_cpp`, tree_prior_ptr)
}

get_beta_tree_prior_cpp <- function(tree_prior_ptr) {
.Call(`_stochtree_get_beta_tree_prior_cpp`, tree_prior_ptr)
}

get_min_samples_leaf_tree_prior_cpp <- function(tree_prior_ptr) {
.Call(`_stochtree_get_min_samples_leaf_tree_prior_cpp`, tree_prior_ptr)
}

get_max_depth_tree_prior_cpp <- function(tree_prior_ptr) {
.Call(`_stochtree_get_max_depth_tree_prior_cpp`, tree_prior_ptr)
}

forest_tracker_cpp <- function(data, feature_types, num_trees, n) {
.Call(`_stochtree_forest_tracker_cpp`, data, feature_types, num_trees, n)
}
Expand Down
22 changes: 14 additions & 8 deletions R/kernel.R
Expand Up @@ -35,7 +35,7 @@
#' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided,
#' this function will return leaf indices for every sample of a forest.
#' This function uses 0-indexing, so the first forest sample corresponds to `forest_num = 0`, and so on.
#' @return List of vectors. Each vector is of size `num_obs * num_trees`, where `num_obs = nrow(covariates)`
#' @return Vector of size `num_obs * num_trees`, where `num_obs = nrow(covariates)`
#' and `num_trees` is the number of trees in the relevant forest of `model_object`.
#' @export
#'
Expand Down Expand Up @@ -83,8 +83,15 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL,
if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) {
stop("covariates must be a matrix or dataframe")
}
train_set_metadata <- model_object$train_set_metadata
covariates_processed <- preprocessPredictionData(covariates, train_set_metadata)
if (model_type %in% c("bart", "bcf")) {
train_set_metadata <- model_object$train_set_metadata
covariates_processed <- preprocessPredictionData(covariates, train_set_metadata)
} else {
if (!is.matrix(covariates)) {
stop("covariates must be a matrix since no covariate preprocessor is stored in a `ForestSamples` object provided as `model_object`")
}
covariates_processed <- covariates
}

# Preprocess forest indices
num_forests <- forest_container$num_samples()
Expand Down Expand Up @@ -199,7 +206,6 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
#' Compute and return the largest possible leaf index computable by `computeForestLeafIndices` for the forests in a designated forest sample container.
#'
#' @param model_object Object of type `bartmodel`, `bcfmodel`, or `ForestSamples` corresponding to a BART / BCF model with at least one forest sample, or a low-level `ForestSamples` object.
#' @param covariates Covariates to use for prediction. Must have the same dimensions / column types as the data used to train a forest.
#' @param forest_type Which forest to use from `model_object`.
#' Valid inputs depend on the model type, and whether or not a
#'
Expand Down Expand Up @@ -228,10 +234,10 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
#' X <- matrix(runif(10*100), ncol = 10)
#' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100)
#' bart_model <- bart(X, y, num_gfr=0, num_mcmc=10)
#' computeForestMaxLeafIndex(bart_model, X, "mean")
#' computeForestMaxLeafIndex(bart_model, X, "mean", 0)
#' computeForestMaxLeafIndex(bart_model, X, "mean", c(1,3,9))
computeForestMaxLeafIndex <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) {
#' computeForestMaxLeafIndex(bart_model, "mean")
#' computeForestMaxLeafIndex(bart_model, "mean", 0)
#' computeForestMaxLeafIndex(bart_model, "mean", c(1,3,9))
computeForestMaxLeafIndex <- function(model_object, forest_type=NULL, forest_inds=NULL) {
# Extract relevant forest container
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples"))))
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples"))
Expand Down
43 changes: 43 additions & 0 deletions R/model.R
Expand Up @@ -85,6 +85,21 @@ ForestModel <- R6::R6Class(
global_scale <- global_model_config$global_error_variance
cutpoint_grid_size <- forest_model_config$cutpoint_grid_size

# Detect changes to tree prior
if (forest_model_config$alpha != get_alpha_tree_prior_cpp(self$tree_prior_ptr)) {
update_alpha_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$alpha)
}
if (forest_model_config$beta != get_beta_tree_prior_cpp(self$tree_prior_ptr)) {
update_beta_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$beta)
}
if (forest_model_config$min_samples_leaf != get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr)) {
update_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$min_samples_leaf)
}
if (forest_model_config$max_depth != get_max_depth_tree_prior_cpp(self$tree_prior_ptr)) {
update_max_depth_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$max_depth)
}

# Run the sampler
if (gfr) {
sample_gfr_one_iteration_cpp(
forest_dataset$data_ptr, residual$data_ptr,
Expand Down Expand Up @@ -165,6 +180,34 @@ ForestModel <- R6::R6Class(
#' @return None
update_max_depth = function(max_depth) {
update_max_depth_tree_prior_cpp(self$tree_prior_ptr, max_depth)
},

#' @description
#' Query alpha in the tree prior
#' @return Value of alpha in the tree prior
get_alpha = function() {
get_alpha_tree_prior_cpp(self$tree_prior_ptr)
},

#' @description
#' Query beta in the tree prior
#' @return Value of beta in the tree prior
get_beta = function() {
get_beta_tree_prior_cpp(self$tree_prior_ptr)
},

#' @description
#' Query min_samples_leaf in the tree prior
#' @return Value of min_samples_leaf in the tree prior
get_min_samples_leaf = function() {
get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr)
},

#' @description
#' Query max_depth in the tree prior
#' @return Value of max_depth in the tree prior
get_max_depth = function() {
get_max_depth_tree_prior_cpp(self$tree_prior_ptr)
}
)
)
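The "detect changes to tree prior" step in `R/model.R` compares each config value against the value currently held by the tree prior and only calls the update function when they differ. A minimal Python sketch of the same pattern is shown below; `TreePriorSketch`, `ConfigSketch`, and `sync_tree_prior` are hypothetical names, and the default values are placeholders rather than stochtree's defaults.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class TreePriorSketch:
    """Hypothetical stand-in for the state held behind the tree prior pointer."""
    alpha: float
    beta: float
    min_samples_leaf: int
    max_depth: int


@dataclass
class ConfigSketch:
    """Hypothetical stand-in for the tree prior fields of a forest model config."""
    alpha: float = 0.95
    beta: float = 2.0
    min_samples_leaf: int = 5
    max_depth: int = 10


def sync_tree_prior(prior: TreePriorSketch, config: ConfigSketch) -> List[str]:
    """Copy config values into the prior only where they differ; return changed names."""
    changed = []
    for name in ("alpha", "beta", "min_samples_leaf", "max_depth"):
        new_value = getattr(config, name)
        if getattr(prior, name) != new_value:
            setattr(prior, name, new_value)
            changed.append(name)
    return changed


prior = TreePriorSketch(alpha=0.95, beta=2.0, min_samples_leaf=5, max_depth=10)
config = ConfigSketch(alpha=0.9, max_depth=20)
first_sync = sync_tree_prior(prior, config)   # only the differing fields are updated
second_sync = sync_tree_prior(prior, config)  # nothing left to update
print(first_sync, second_sync)
```

Running the check before every sampler iteration keeps the config object as the single source of truth while avoiding redundant updates when nothing has changed.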
Expand Down
2 changes: 1 addition & 1 deletion R/random_effects.R
Expand Up @@ -93,7 +93,7 @@ RandomEffectSamples <- R6::R6Class(
#' Predict random effects for each observation implied by `rfx_group_ids` and `rfx_basis`.
#' If a random effects model is "intercept-only" the `rfx_basis` will be a vector of ones of size `length(rfx_group_ids)`.
#' @param rfx_group_ids Indices of random effects groups in a prediction set
#' @param rfx_basis (Optional ) Basis used for random effects prediction
#' @param rfx_basis (Optional) Basis used for random effects prediction
#' @return Matrix with as many rows as observations provided and as many columns as samples drawn of the model.
predict = function(rfx_group_ids, rfx_basis = NULL) {
num_obs = length(rfx_group_ids)
Expand Down
22 changes: 11 additions & 11 deletions demo/debug/causal_inference.py
Expand Up @@ -63,38 +63,38 @@

# Run BCF
bcf_model = BCFModel()
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100)
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=1000)

# Inspect the MCMC (BART) samples
forest_preds_y_mcmc = bcf_model.y_hat_test[:,bcf_model.num_gfr:]
forest_preds_y_mcmc = bcf_model.y_hat_test
y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True)
y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=["True outcome", "Average estimated outcome"])
sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

forest_preds_tau_mcmc = bcf_model.tau_hat_test[:,bcf_model.num_gfr:]
forest_preds_tau_mcmc = bcf_model.tau_hat_test
tau_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis = 1, keepdims = True)
tau_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(tau_test,1), tau_avg_mcmc), axis = 1), columns=["True tau", "Average estimated tau"])
sns.scatterplot(data=tau_df_mcmc, x="Average estimated tau", y="True tau")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

forest_preds_mu_mcmc = bcf_model.mu_hat_test[:,bcf_model.num_gfr:]
forest_preds_mu_mcmc = bcf_model.mu_hat_test
mu_avg_mcmc = np.squeeze(forest_preds_mu_mcmc).mean(axis = 1, keepdims = True)
mu_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(mu_test,1), mu_avg_mcmc), axis = 1), columns=["True mu", "Average estimated mu"])
sns.scatterplot(data=mu_df_mcmc, x="Average estimated mu", y="True mu")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
plt.show()

# sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
# sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
# plt.show()
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
plt.show()

# b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"])
# sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_0")
# sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_1")
# plt.show()
b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"])
sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_0")
sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_1")
plt.show()

# Compute RMSEs
y_rmse = np.sqrt(np.mean(np.power(np.expand_dims(y_test,1) - y_avg_mcmc, 2)))
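The updated demo drops the `[:, bcf_model.num_gfr:]` slices, which presumably means the stored prediction arrays no longer include the GFR draws (an inference from the diff, not something stated in it). The posterior-mean and RMSE computations can be tried on synthetic data without fitting a model. In the sketch below, `y_hat_test` is a made-up array of shape `(n_test, num_samples)` standing in for the model output.

```python
import numpy as np

rng = np.random.default_rng(0)
n_test, num_samples = 50, 200

# Synthetic "truth" and a stand-in for posterior draws (one column per sample)
y_test = rng.normal(size=n_test)
y_hat_test = y_test[:, None] + rng.normal(scale=0.1, size=(n_test, num_samples))

# Average over the sample axis, keeping a column vector as in the demo script
y_avg = y_hat_test.mean(axis=1, keepdims=True)

# Root mean squared error between the true outcome and the posterior mean
y_rmse = np.sqrt(np.mean(np.power(np.expand_dims(y_test, 1) - y_avg, 2)))
print(y_avg.shape, y_rmse)
```

Keeping `keepdims=True` makes `y_avg` a column of shape `(n_test, 1)`, so it lines up with `np.expand_dims(y_test, 1)` in both the RMSE and the `np.concatenate` calls used for plotting.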
Expand Down
33 changes: 33 additions & 0 deletions demo/debug/kernel.py
@@ -0,0 +1,33 @@
import numpy as np
from stochtree import Dataset, ForestContainer, compute_forest_leaf_indices

# Create dataset
X = np.array(
[[1.5, 8.7, 1.2],
[2.7, 3.4, 5.4],
[3.6, 1.2, 9.3],
[4.4, 5.4, 10.4],
[5.3, 9.3, 3.6],
[6.1, 10.4, 4.4]]
)
n, p = X.shape
num_trees = 2
output_dim = 1
forest_dataset = Dataset()
forest_dataset.add_covariates(X)
forest_samples = ForestContainer(num_trees, output_dim, True, False)

# Initialize a forest with constant root predictions
forest_samples.add_sample(0.)

# Split the root of the first tree in the ensemble on the first feature (X[:, 0]) at threshold 4.0
forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.)

# Compute leaf indices after the first split
computed_indices = compute_forest_leaf_indices(forest_samples, X)

# Split the left leaf of the first tree in the ensemble on the second feature (X[:, 1]) at threshold 4.0
forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5)

# Recompute leaf indices after the second split
computed_indices = compute_forest_leaf_indices(forest_samples, X)
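The two splits built in `demo/debug/kernel.py` can be traced by hand with a small standalone sketch. It assumes that observations with a feature value at or below the threshold go left, that the root's children are nodes 1 and 2, and that node 1's children are nodes 3 and 4. These conventions and the `splits` encoding are assumptions made for illustration; the values returned by `compute_forest_leaf_indices` may be numbered differently.

```python
import numpy as np

X = np.array(
    [[1.5, 8.7, 1.2],
     [2.7, 3.4, 5.4],
     [3.6, 1.2, 9.3],
     [4.4, 5.4, 10.4],
     [5.3, 9.3, 3.6],
     [6.1, 10.4, 4.4]]
)

# Hypothetical tree encoding: node id -> (feature, threshold, left child, right child)
splits = {
    0: (0, 4.0, 1, 2),  # root split on the first feature
    1: (1, 4.0, 3, 4),  # left child split on the second feature
}


def leaf_node_ids(X, splits):
    """Route each row from the root to a leaf and return the leaf's node id."""
    ids = np.zeros(X.shape[0], dtype=int)
    for i, row in enumerate(X):
        node = 0
        while node in splits:
            feature, threshold, left, right = splits[node]
            # Assumed rule: values <= threshold go left, otherwise right
            node = left if row[feature] <= threshold else right
        ids[i] = node
    return ids


leaf_ids = leaf_node_ids(X, splits)
print(leaf_ids)
```

Under these assumptions, the first three rows fall in the left subtree and are separated by the second split, while the last three rows all land in the untouched right child of the root.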
2 changes: 0 additions & 2 deletions demo/debug/multivariate_treatment_causal_inference.py
Expand Up @@ -44,5 +44,3 @@
# Run BCF
bcf_model = BCFModel()
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100)


22 changes: 22 additions & 0 deletions demo/debug/r_comparison_debug.py
@@ -0,0 +1,22 @@
# R Comparison Demo Script

# Load necessary libraries
import numpy as np
import pandas as pd
from stochtree import BARTModel

# Load data
df = pd.read_csv("debug/data/heterosked_train.csv")
y = df.loc[:,'y'].to_numpy()
X = df.loc[:,['X1','X2','X3','X4','X5','X6','X7','X8','X9','X10']].to_numpy()
y = y.astype(np.float64)
X = X.astype(np.float64)

# Run BART
bart_model = BARTModel()
bart_model.sample(X_train=X, y_train=y, num_gfr=0, num_mcmc=10, general_params={'random_seed': 1234, 'standardize': False, 'sample_sigma2_global': True})

# Inspect the MCMC (BART) samples
y_avg_mcmc = np.squeeze(bart_model.y_hat_train).mean(axis = 1, keepdims = True)
print(y_avg_mcmc[:20])
print(bart_model.global_var_samples)