Skip to content

Commit 40809c3

Browse files
authored
Merge pull request #151 from StochasticTree/python-update-0.1.1
Updating the Python package for the 0.1.1 release
2 parents 5bbac93 + a5fdcae commit 40809c3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+8205
-2369
lines changed

CMakeLists.txt

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,17 @@ endif()
132132

133133
# Build C++ test program
134134
if(BUILD_TEST)
135-
# Download the GoogleTest dependency if necessary
135+
# Check if user specified a local clone of the GoogleTest repo, use Github repo if not
136+
if (NOT DEFINED GOOGLETEST_GIT_REPO)
137+
set(GOOGLETEST_GIT_REPO https://github.com/google/googletest.git)
138+
endif()
139+
140+
# Fetch and install GoogleTest dependency
136141
include(FetchContent)
137142
FetchContent_Declare(
138143
googletest
139-
GIT_REPOSITORY https://github.com/google/googletest.git
140-
GIT_TAG e2239ee6043f73722e7aa812a459f54a28552929 # release-1.14.0
144+
GIT_REPOSITORY ${GOOGLETEST_GIT_REPO}
145+
GIT_TAG 6910c9d9165801d8827d628cb72eb7ea9dd538c5 # release-1.16.0
141146
)
142147
# For Windows: Prevent overriding the parent project's compiler/linker settings
143148
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)

NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# stochtree 0.1.2
2+
3+
* Fixed indexing bug in cleanup of grow-from-root (GFR) samples in BART and BCF models
4+
* Avoid using covariate preprocessor in `computeForestLeafIndices` function when a `ForestSamples` object is provided
5+
16
# stochtree 0.1.1
27

38
* Fixed initialization bug in several R package code examples for random effects models

R/bart.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
445445
}
446446
if (has_rfx_test) {
447447
if (is.null(rfx_basis_test)) {
448-
if (!is.null(rfx_basis_train)) {
448+
if (has_basis_rfx) {
449449
stop("Random effects basis provided for training set, must also be provided for the test set")
450450
}
451451
rfx_basis_test <- matrix(rep(1,nrow(X_test)), nrow = nrow(X_test), ncol = 1)

R/config.R

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ ForestModelConfig <- R6::R6Class(
6161

6262
#' Create a new ForestModelConfig object.
6363
#'
64-
#' @param feature_types Vector of integer-coded feature types (integers where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
64+
#' @param feature_types Vector of integer-coded feature types (where 0 = numeric, 1 = ordered categorical, 2 = unordered categorical)
6565
#' @param num_trees Number of trees in the forest being sampled
6666
#' @param num_features Number of features in training dataset
6767
#' @param num_observations Number of observations in training dataset
@@ -98,6 +98,12 @@ ForestModelConfig <- R6::R6Class(
9898
warning("`variable_weights` not provided, will be assumed to be equal-weighted")
9999
variable_weights <- rep(1/num_features, num_features)
100100
}
101+
if (is.null(num_trees)) {
102+
stop("num_trees must be provided")
103+
}
104+
if (is.null(num_observations)) {
105+
stop("num_observations must be provided")
106+
}
101107
if (num_features != length(feature_types)) {
102108
stop("`feature_types` must have `num_features` total elements")
103109
}
@@ -175,14 +181,14 @@ ForestModelConfig <- R6::R6Class(
175181
},
176182

177183
#' @description
178-
#' Update root node split probability in tree prior
184+
#' Update minimum number of samples per leaf node in the tree prior
179185
#' @param min_samples_leaf Minimum number of samples in a tree leaf
180186
update_min_samples_leaf = function(min_samples_leaf) {
181187
self$min_samples_leaf <- min_samples_leaf
182188
},
183189

184190
#' @description
185-
#' Update root node split probability in tree prior
191+
#' Update max depth in the tree prior
186192
#' @param max_depth Maximum depth of any tree in the ensemble in the model
187193
update_max_depth = function(max_depth) {
188194
self$max_depth <- max_depth
@@ -243,6 +249,27 @@ ForestModelConfig <- R6::R6Class(
243249
return(self$variable_weights)
244250
},
245251

252+
#' @description
253+
#' Query number of trees
254+
#' @returns Number of trees in a forest
255+
get_num_trees = function() {
256+
return(self$num_trees)
257+
},
258+
259+
#' @description
260+
#' Query number of features
261+
#' @returns Number of features in a forest model training set
262+
get_num_features = function() {
263+
return(self$num_features)
264+
},
265+
266+
#' @description
267+
#' Query number of observations
268+
#' @returns Number of observations in a forest model training set
269+
get_num_observations = function() {
270+
return(self$num_observations)
271+
},
272+
246273
#' @description
247274
#' Query root node split probability in tree prior for this ForestModelConfig object
248275
#' @returns Root node split probability in tree prior
@@ -271,6 +298,13 @@ ForestModelConfig <- R6::R6Class(
271298
return(self$max_depth)
272299
},
273300

301+
#' @description
302+
#' Query (integer-coded) type of leaf model
303+
#' @returns Integer coded leaf model type
304+
get_leaf_model_type = function() {
305+
return(self$leaf_model_type)
306+
},
307+
274308
#' @description
275309
#' Query scale parameter used in Gaussian leaf models for this ForestModelConfig object
276310
#' @returns Scale parameter used in Gaussian leaf models

R/cpp11.R

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,22 @@ update_max_depth_tree_prior_cpp <- function(tree_prior_ptr, max_depth) {
596596
invisible(.Call(`_stochtree_update_max_depth_tree_prior_cpp`, tree_prior_ptr, max_depth))
597597
}
598598

599+
get_alpha_tree_prior_cpp <- function(tree_prior_ptr) {
600+
.Call(`_stochtree_get_alpha_tree_prior_cpp`, tree_prior_ptr)
601+
}
602+
603+
get_beta_tree_prior_cpp <- function(tree_prior_ptr) {
604+
.Call(`_stochtree_get_beta_tree_prior_cpp`, tree_prior_ptr)
605+
}
606+
607+
get_min_samples_leaf_tree_prior_cpp <- function(tree_prior_ptr) {
608+
.Call(`_stochtree_get_min_samples_leaf_tree_prior_cpp`, tree_prior_ptr)
609+
}
610+
611+
get_max_depth_tree_prior_cpp <- function(tree_prior_ptr) {
612+
.Call(`_stochtree_get_max_depth_tree_prior_cpp`, tree_prior_ptr)
613+
}
614+
599615
forest_tracker_cpp <- function(data, feature_types, num_trees, n) {
600616
.Call(`_stochtree_forest_tracker_cpp`, data, feature_types, num_trees, n)
601617
}

R/kernel.R

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
#' @param forest_inds (Optional) Indices of the forest sample(s) for which to compute leaf indices. If not provided,
3636
#' this function will return leaf indices for every sample of a forest.
3737
#' This function uses 0-indexing, so the first forest sample corresponds to `forest_inds = 0`, and so on.
38-
#' @return List of vectors. Each vector is of size `num_obs * num_trees`, where `num_obs = nrow(covariates)`
38+
#' @return Vector of size `num_obs * num_trees`, where `num_obs = nrow(covariates)`
3939
#' and `num_trees` is the number of trees in the relevant forest of `model_object`.
4040
#' @export
4141
#'
@@ -83,8 +83,15 @@ computeForestLeafIndices <- function(model_object, covariates, forest_type=NULL,
8383
if ((!is.data.frame(covariates)) && (!is.matrix(covariates))) {
8484
stop("covariates must be a matrix or dataframe")
8585
}
86-
train_set_metadata <- model_object$train_set_metadata
87-
covariates_processed <- preprocessPredictionData(covariates, train_set_metadata)
86+
if (model_type %in% c("bart", "bcf")) {
87+
train_set_metadata <- model_object$train_set_metadata
88+
covariates_processed <- preprocessPredictionData(covariates, train_set_metadata)
89+
} else {
90+
if (!is.matrix(covariates)) {
91+
stop("covariates must be a matrix since no covariate preprocessor is stored in a `ForestSamples` object provided as `model_object`")
92+
}
93+
covariates_processed <- covariates
94+
}
8895

8996
# Preprocess forest indices
9097
num_forests <- forest_container$num_samples()
@@ -199,7 +206,6 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
199206
#' Compute and return the largest possible leaf index computable by `computeForestLeafIndices` for the forests in a designated forest sample container.
200207
#'
201208
#' @param model_object Object of type `bartmodel`, `bcfmodel`, or `ForestSamples` corresponding to a BART / BCF model with at least one forest sample, or a low-level `ForestSamples` object.
202-
#' @param covariates Covariates to use for prediction. Must have the same dimensions / column types as the data used to train a forest.
203209
#' @param forest_type Which forest to use from `model_object`.
204210
#' Valid inputs depend on the model type, and whether or not a
205211
#'
@@ -228,10 +234,10 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
228234
#' X <- matrix(runif(10*100), ncol = 10)
229235
#' y <- -5 + 10*(X[,1] > 0.5) + rnorm(100)
230236
#' bart_model <- bart(X, y, num_gfr=0, num_mcmc=10)
231-
#' computeForestMaxLeafIndex(bart_model, X, "mean")
232-
#' computeForestMaxLeafIndex(bart_model, X, "mean", 0)
233-
#' computeForestMaxLeafIndex(bart_model, X, "mean", c(1,3,9))
234-
computeForestMaxLeafIndex <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) {
237+
#' computeForestMaxLeafIndex(bart_model, "mean")
238+
#' computeForestMaxLeafIndex(bart_model, "mean", 0)
239+
#' computeForestMaxLeafIndex(bart_model, "mean", c(1,3,9))
240+
computeForestMaxLeafIndex <- function(model_object, forest_type=NULL, forest_inds=NULL) {
235241
# Extract relevant forest container
236242
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples"))))
237243
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples"))

R/model.R

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,21 @@ ForestModel <- R6::R6Class(
8585
global_scale <- global_model_config$global_error_variance
8686
cutpoint_grid_size <- forest_model_config$cutpoint_grid_size
8787

88+
# Detect changes to tree prior
89+
if (forest_model_config$alpha != get_alpha_tree_prior_cpp(self$tree_prior_ptr)) {
90+
update_alpha_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$alpha)
91+
}
92+
if (forest_model_config$beta != get_beta_tree_prior_cpp(self$tree_prior_ptr)) {
93+
update_beta_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$beta)
94+
}
95+
if (forest_model_config$min_samples_leaf != get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr)) {
96+
update_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$min_samples_leaf)
97+
}
98+
if (forest_model_config$max_depth != get_max_depth_tree_prior_cpp(self$tree_prior_ptr)) {
99+
update_max_depth_tree_prior_cpp(self$tree_prior_ptr, forest_model_config$max_depth)
100+
}
101+
102+
# Run the sampler
88103
if (gfr) {
89104
sample_gfr_one_iteration_cpp(
90105
forest_dataset$data_ptr, residual$data_ptr,
@@ -165,6 +180,34 @@ ForestModel <- R6::R6Class(
165180
#' @return None
166181
update_max_depth = function(max_depth) {
167182
update_max_depth_tree_prior_cpp(self$tree_prior_ptr, max_depth)
183+
},
184+
185+
#' @description
186+
#' Query alpha in the tree prior
187+
#' @return Value of alpha in the tree prior
188+
get_alpha = function() {
189+
get_alpha_tree_prior_cpp(self$tree_prior_ptr)
190+
},
191+
192+
#' @description
193+
#' Query beta in the tree prior
194+
#' @return Value of beta in the tree prior
195+
get_beta = function() {
196+
get_beta_tree_prior_cpp(self$tree_prior_ptr)
197+
},
198+
199+
#' @description
200+
#' Query min_samples_leaf in the tree prior
201+
#' @return Value of min_samples_leaf in the tree prior
202+
get_min_samples_leaf = function() {
203+
get_min_samples_leaf_tree_prior_cpp(self$tree_prior_ptr)
204+
},
205+
206+
#' @description
207+
#' Query max_depth in the tree prior
208+
#' @return Value of max_depth in the tree prior
209+
get_max_depth = function() {
210+
get_max_depth_tree_prior_cpp(self$tree_prior_ptr)
168211
}
169212
)
170213
)

R/random_effects.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ RandomEffectSamples <- R6::R6Class(
9393
#' Predict random effects for each observation implied by `rfx_group_ids` and `rfx_basis`.
9494
#' If a random effects model is "intercept-only" the `rfx_basis` will be a vector of ones of size `length(rfx_group_ids)`.
9595
#' @param rfx_group_ids Indices of random effects groups in a prediction set
96-
#' @param rfx_basis (Optional ) Basis used for random effects prediction
96+
#' @param rfx_basis (Optional) Basis used for random effects prediction
9797
#' @return Matrix with as many rows as observations provided and as many columns as samples drawn of the model.
9898
predict = function(rfx_group_ids, rfx_basis = NULL) {
9999
num_obs = length(rfx_group_ids)

demo/debug/causal_inference.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,38 +63,38 @@
6363

6464
# Run BCF
6565
bcf_model = BCFModel()
66-
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100)
66+
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=1000)
6767

6868
# Inspect the MCMC (BCF) samples
69-
forest_preds_y_mcmc = bcf_model.y_hat_test[:,bcf_model.num_gfr:]
69+
forest_preds_y_mcmc = bcf_model.y_hat_test
7070
y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True)
7171
y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=["True outcome", "Average estimated outcome"])
7272
sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome")
7373
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
7474
plt.show()
7575

76-
forest_preds_tau_mcmc = bcf_model.tau_hat_test[:,bcf_model.num_gfr:]
76+
forest_preds_tau_mcmc = bcf_model.tau_hat_test
7777
tau_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis = 1, keepdims = True)
7878
tau_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(tau_test,1), tau_avg_mcmc), axis = 1), columns=["True tau", "Average estimated tau"])
7979
sns.scatterplot(data=tau_df_mcmc, x="Average estimated tau", y="True tau")
8080
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
8181
plt.show()
8282

83-
forest_preds_mu_mcmc = bcf_model.mu_hat_test[:,bcf_model.num_gfr:]
83+
forest_preds_mu_mcmc = bcf_model.mu_hat_test
8484
mu_avg_mcmc = np.squeeze(forest_preds_mu_mcmc).mean(axis = 1, keepdims = True)
8585
mu_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(mu_test,1), mu_avg_mcmc), axis = 1), columns=["True mu", "Average estimated mu"])
8686
sns.scatterplot(data=mu_df_mcmc, x="Average estimated mu", y="True mu")
8787
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
8888
plt.show()
8989

90-
# sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
91-
# sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
92-
# plt.show()
90+
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
91+
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
92+
plt.show()
9393

94-
# b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"])
95-
# sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_0")
96-
# sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_1")
97-
# plt.show()
94+
b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"])
95+
sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_0")
96+
sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_1")
97+
plt.show()
9898

9999
# Compute RMSEs
100100
y_rmse = np.sqrt(np.mean(np.power(np.expand_dims(y_test,1) - y_avg_mcmc, 2)))

demo/debug/kernel.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
from stochtree import Dataset, ForestContainer, compute_forest_leaf_indices
3+
4+
# Create dataset
5+
X = np.array(
6+
[[1.5, 8.7, 1.2],
7+
[2.7, 3.4, 5.4],
8+
[3.6, 1.2, 9.3],
9+
[4.4, 5.4, 10.4],
10+
[5.3, 9.3, 3.6],
11+
[6.1, 10.4, 4.4]]
12+
)
13+
n, p = X.shape
14+
num_trees = 2
15+
output_dim = 1
16+
forest_dataset = Dataset()
17+
forest_dataset.add_covariates(X)
18+
forest_samples = ForestContainer(num_trees, output_dim, True, False)
19+
20+
# Initialize a forest with constant root predictions
21+
forest_samples.add_sample(0.)
22+
23+
# Split the root of the first tree in the ensemble at X[,1] > 4.0
24+
forest_samples.add_numeric_split(0, 0, 0, 0, 4.0, -5., 5.)
25+
26+
# Compute the leaf indices of every observation in X after the first split
27+
computed_indices = compute_forest_leaf_indices(forest_samples, X)
28+
29+
# Split the left leaf of the first tree in the ensemble at X[,2] > 4.0
30+
forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5)
31+
32+
# Recompute the leaf indices of every observation in X after the second split
33+
computed_indices = compute_forest_leaf_indices(forest_samples, X)

demo/debug/multivariate_treatment_causal_inference.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,3 @@
4444
# Run BCF
4545
bcf_model = BCFModel()
4646
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100)
47-
48-

demo/debug/r_comparison_debug.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# R Comparison Demo Script
2+
3+
# Load necessary libraries
4+
import numpy as np
5+
import pandas as pd
6+
from stochtree import BARTModel
7+
8+
# Load data
9+
df = pd.read_csv("debug/data/heterosked_train.csv")
10+
y = df.loc[:,'y'].to_numpy()
11+
X = df.loc[:,['X1','X2','X3','X4','X5','X6','X7','X8','X9','X10']].to_numpy()
12+
y = y.astype(np.float64)
13+
X = X.astype(np.float64)
14+
15+
# Run BART
16+
bart_model = BARTModel()
17+
bart_model.sample(X_train=X, y_train=y, num_gfr=0, num_mcmc=10, general_params={'random_seed': 1234, 'standardize': False, 'sample_sigma2_global': True})
18+
19+
# Inspect the MCMC (BART) samples
20+
y_avg_mcmc = np.squeeze(bart_model.y_hat_train).mean(axis = 1, keepdims = True)
21+
print(y_avg_mcmc[:20])
22+
print(bart_model.global_var_samples)

0 commit comments

Comments
 (0)