
Commit 8cf53ee

Initial commit of warm-start interface in Python

1 parent dc7bad9

3 files changed: +218 −1 lines changed


R/bart.R

Lines changed: 1 addition & 1 deletion

@@ -206,7 +206,7 @@ bart <- function(X_train, y_train, leaf_basis_train = NULL, rfx_group_ids_train
     if (previous_bart_model$model_params$include_mean_forest) {
         previous_forest_samples_mean <- previous_bart_model$mean_forests
     } else previous_forest_samples_mean <- NULL
-    if (previous_bart_model$model_params$include_mean_forest) {
+    if (previous_bart_model$model_params$include_variance_forest) {
         previous_forest_samples_variance <- previous_bart_model$variance_forests
     } else previous_forest_samples_variance <- NULL
     if (previous_bart_model$model_params$sample_sigma_global) {

demo/debug/multi_chain.py

Lines changed: 147 additions & 0 deletions
@@ -0,0 +1,147 @@

# Multi Chain Demo Script

# Load necessary libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split

from stochtree import BARTModel

# Generate sample data
# RNG
random_seed = 1234
rng = np.random.default_rng(random_seed)

# Generate covariates and basis
n = 500
p_X = 10
p_W = 1
X = rng.uniform(0, 1, (n, p_X))
W = rng.uniform(0, 1, (n, p_W))

# Define the outcome mean function
def outcome_mean(X, W):
    return np.where(
        (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
        -7.5 * W[:, 0],
        np.where(
            (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
            -2.5 * W[:, 0],
            np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * W[:, 0], 7.5 * W[:, 0]),
        ),
    )

# Generate outcome
f_XW = outcome_mean(X, W)
epsilon = rng.normal(0, 1, n)
y = f_XW + epsilon

# Test-train split
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=0.5, random_state=random_seed)
X_train = X[train_inds, :]
X_test = X[test_inds, :]
basis_train = W[train_inds, :]
basis_test = W[test_inds, :]
y_train = y[train_inds]
y_test = y[test_inds]

# Run the GFR algorithm for a small number of iterations
general_model_params = {"random_seed": -1}
mean_forest_model_params = {"num_trees": 20}
num_warmstart = 10
num_mcmc = 10
bart_model = BARTModel()
bart_model.sample(
    X_train=X_train,
    y_train=y_train,
    leaf_basis_train=basis_train,
    X_test=X_test,
    leaf_basis_test=basis_test,
    num_gfr=num_warmstart,
    num_mcmc=0,
    general_params=general_model_params,
    mean_forest_params=mean_forest_model_params,
)
bart_model_json = bart_model.to_json()

# Run several BART MCMC samples from the last GFR forest
bart_model_2 = BARTModel()
bart_model_2.sample(
    X_train=X_train,
    y_train=y_train,
    leaf_basis_train=basis_train,
    X_test=X_test,
    leaf_basis_test=basis_test,
    num_gfr=0,
    num_mcmc=num_mcmc,
    previous_model_json=bart_model_json,
    previous_model_warmstart_sample_num=num_warmstart - 1,
    general_params=general_model_params,
    mean_forest_params=mean_forest_model_params,
)

# Run several BART MCMC samples from the second-to-last GFR forest
bart_model_3 = BARTModel()
bart_model_3.sample(
    X_train=X_train,
    y_train=y_train,
    leaf_basis_train=basis_train,
    X_test=X_test,
    leaf_basis_test=basis_test,
    num_gfr=0,
    num_mcmc=num_mcmc,
    previous_model_json=bart_model_json,
    previous_model_warmstart_sample_num=num_warmstart - 2,
    general_params=general_model_params,
    mean_forest_params=mean_forest_model_params,
)

# Run several BART MCMC samples from root
bart_model_4 = BARTModel()
bart_model_4.sample(
    X_train=X_train,
    y_train=y_train,
    leaf_basis_train=basis_train,
    X_test=X_test,
    leaf_basis_test=basis_test,
    num_gfr=0,
    num_mcmc=num_mcmc,
    general_params=general_model_params,
    mean_forest_params=mean_forest_model_params,
)

# Inspect the model outputs
y_hat_mcmc_2 = bart_model_2.predict(X_test, basis_test)
y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True)
y_hat_mcmc_3 = bart_model_3.predict(X_test, basis_test)
y_avg_mcmc_3 = np.squeeze(y_hat_mcmc_3).mean(axis=1, keepdims=True)
y_hat_mcmc_4 = bart_model_4.predict(X_test, basis_test)
y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True)
y_df = pd.DataFrame(
    np.concatenate((y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4, np.expand_dims(y_test, axis=1)), axis=1),
    columns=["First Chain", "Second Chain", "Third Chain", "Outcome"],
)

# Compare first warm-start chain to root chain with equal number of MCMC draws
sns.scatterplot(data=y_df, x="First Chain", y="Third Chain")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()

# Compare first warm-start chain to outcome
sns.scatterplot(data=y_df, x="First Chain", y="Outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()

# Compare root chain to outcome
sns.scatterplot(data=y_df, x="Third Chain", y="Outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()

# Compute RMSEs
rmse_1 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_2) - y_test) ** 2))
rmse_2 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_3) - y_test) ** 2))
rmse_3 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_4) - y_test) ** 2))
print("Chain 1 RMSE: {:0.3f}; Chain 2 RMSE: {:0.3f}; Chain 3 RMSE: {:0.3f}".format(rmse_1, rmse_2, rmse_3))
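The docstring for `sample` (in stochtree/bart.py below) frames warm-starting as a way to run parallel chains, while the demo above runs its chains sequentially. As a rough sketch of the parallel workflow, not part of this commit: the two warm-started chains could be dispatched to separate processes. `ProcessPoolExecutor` usage and the `run_chain` helper are illustrative assumptions layered on the demo's variables, not stochtree API.

# Hypothetical parallel-chain sketch (not part of this commit). Assumes the
# demo's variables (X_train, y_train, basis_train, X_test, basis_test,
# num_mcmc, num_warmstart, bart_model_json, and the parameter dicts) are in scope.
from concurrent.futures import ProcessPoolExecutor

def run_chain(warmstart_sample_num):
    # Each process runs an independent MCMC chain warm-started from the
    # serialized GFR forests in bart_model_json.
    chain = BARTModel()
    chain.sample(
        X_train=X_train,
        y_train=y_train,
        leaf_basis_train=basis_train,
        X_test=X_test,
        leaf_basis_test=basis_test,
        num_gfr=0,
        num_mcmc=num_mcmc,
        previous_model_json=bart_model_json,
        previous_model_warmstart_sample_num=warmstart_sample_num,
        general_params=general_model_params,
        mean_forest_params=mean_forest_model_params,
    )
    # Return test-set predictions (a numpy array), which pickle cleanly
    # across process boundaries.
    return chain.predict(X_test, basis_test)

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=2) as pool:
        chain_preds = list(pool.map(run_chain, [num_warmstart - 1, num_warmstart - 2]))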

stochtree/bart.py

Lines changed: 70 additions & 0 deletions
@@ -77,6 +77,8 @@ def sample(
         general_params: Optional[Dict[str, Any]] = None,
         mean_forest_params: Optional[Dict[str, Any]] = None,
         variance_forest_params: Optional[Dict[str, Any]] = None,
+        previous_model_json: Optional[str] = None,
+        previous_model_warmstart_sample_num: Optional[int] = None,
     ) -> None:
         """Runs a BART sampler on provided training set. Predictions will be cached for the training set and (if provided) the test set.
         Does not require a leaf regression basis.
@@ -154,6 +156,11 @@ def sample(
             * `var_forest_prior_scale` (`float`): Scale parameter in the [optional] `IG(var_forest_prior_shape, var_forest_prior_scale)` conditional error variance forest (which is only sampled if `num_trees > 0`). Calibrated internally as `num_trees / leaf_prior_calibration_param^2` if not set here.
             * `keep_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be included in the variance forest. Defaults to `None`.
             * `drop_vars` (`list` or `np.array`): Vector of variable names or column indices denoting variables that should be excluded from the variance forest. Defaults to `None`. If both `drop_vars` and `keep_vars` are set, `drop_vars` will be ignored.
+
+        previous_model_json : str, optional
+            JSON string containing a previous BART model. This can be used to "continue" a sampler interactively after inspecting the samples or to run parallel chains "warm-started" from existing forest samples. Defaults to `None`.
+        previous_model_warmstart_sample_num : int, optional
+            Sample number from `previous_model_json` that will be used to warm-start this BART sampler. Zero-indexed (so that the first sample is used for warm-start by setting `previous_model_warmstart_sample_num = 0`). Defaults to `None`.

         Returns
         -------
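The docstring above names two uses of `previous_model_json`: continuing a sampler and warm-starting parallel chains. The demo script covers the parallel-chain case; a minimal sketch of the continuation case, using only the arguments added in this commit (data variables are placeholders, and `num_samples` is assumed to be exposed on a fitted model as it is on the deserialized model in the parsing code below), might look like:

# Minimal continuation sketch (illustrative; X_train / y_train are placeholders)
from stochtree import BARTModel

first_run = BARTModel()
first_run.sample(X_train=X_train, y_train=y_train, num_gfr=10, num_mcmc=100)

# After inspecting the draws, resume MCMC from the last retained sample
continued = BARTModel()
continued.sample(
    X_train=X_train,
    y_train=y_train,
    num_gfr=0,
    num_mcmc=100,
    previous_model_json=first_run.to_json(),
    previous_model_warmstart_sample_num=first_run.num_samples - 1,
)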
@@ -612,6 +619,51 @@ def sample(
         else:
             variable_subset_variance = [i for i in range(X_train.shape[1])]

+        # Check if previous model JSON is provided and parse it if so
+        has_prev_model = previous_model_json is not None
+        if has_prev_model:
+            if num_gfr > 0:
+                if num_mcmc == 0:
+                    raise ValueError("A previous model is being used to initialize this sampler, so `num_mcmc` must be greater than zero")
+                else:
+                    warnings.warn("A previous model is being used to initialize this sampler, so num_gfr will be ignored and the MCMC sampler will be run from the previous samples")
+            previous_bart_model = BARTModel()
+            previous_bart_model.from_json(previous_model_json)
+            previous_y_bar = previous_bart_model.y_bar
+            previous_y_scale = previous_bart_model.y_std
+            previous_model_num_samples = previous_bart_model.num_samples
+            if previous_bart_model.include_mean_forest:
+                previous_forest_samples_mean = previous_bart_model.forest_container_mean
+            else:
+                previous_forest_samples_mean = None
+            if previous_bart_model.include_variance_forest:
+                previous_forest_samples_variance = previous_bart_model.forest_container_variance
+            else:
+                previous_forest_samples_variance = None
+            if previous_bart_model.sample_sigma_global:
+                previous_global_var_samples = previous_bart_model.global_var_samples / (previous_y_scale * previous_y_scale)
+            else:
+                previous_global_var_samples = None
+            if previous_bart_model.sample_sigma_leaf:
+                previous_leaf_var_samples = previous_bart_model.leaf_scale_samples
+            else:
+                previous_leaf_var_samples = None
+            if previous_bart_model.has_rfx:
+                previous_rfx_samples = previous_bart_model.rfx_container
+            else:
+                previous_rfx_samples = None
+            if previous_model_warmstart_sample_num + 1 > previous_model_num_samples:
+                raise ValueError("`previous_model_warmstart_sample_num` exceeds the number of samples in `previous_model_json`")
+        else:
+            previous_y_bar = None
+            previous_y_scale = None
+            previous_global_var_samples = None
+            previous_leaf_var_samples = None
+            previous_rfx_samples = None
+            previous_forest_samples_mean = None
+            previous_forest_samples_variance = None
+            previous_model_num_samples = 0
+
         # Update variable weights if the covariates have been resized (by e.g. one-hot encoding)
         if X_train_processed.shape[1] != X_train.shape[1]:
             variable_counts = [
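One subtle line in the hunk above is the division of `global_var_samples` by `previous_y_scale` squared. The implied convention, an assumption here rather than something this diff states outright, is that the sampler operates on a standardized outcome while serialized global variance draws live on the original outcome scale, so restoring a draw for a new run means undoing that scaling:

# Illustrative arithmetic only; not stochtree API. Assumes serialized global
# variance draws are stored on the original outcome scale.
import numpy as np

y = np.array([3.1, 0.4, 2.2, 5.0])
y_std = np.std(y)                         # scale used to standardize y
sigma2_original = 1.7                     # a stored draw on the original scale
sigma2_standardized = sigma2_original / (y_std * y_std)
# The sampler resumes with sigma2_standardized, mirroring
# previous_bart_model.global_var_samples / (previous_y_scale * previous_y_scale)
print(sigma2_standardized)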
@@ -992,6 +1044,22 @@ def sample(
                     )
                     if sample_sigma_global:
                         current_sigma2 = self.global_var_samples[forest_ind]
+                elif has_prev_model:
+                    if self.include_mean_forest:
+                        active_forest_mean.reset(previous_bart_model.forest_container_mean, previous_model_warmstart_sample_num)
+                        forest_sampler_mean.reconstitute_from_forest(active_forest_mean, forest_dataset_train, residual_train, True)
+                        if sample_sigma_leaf and previous_leaf_var_samples is not None:
+                            leaf_scale_double = previous_leaf_var_samples[previous_model_warmstart_sample_num]
+                            current_leaf_scale[0, 0] = leaf_scale_double
+                            forest_model_config_mean.update_leaf_model_scale(leaf_scale_double)
+                    if self.include_variance_forest:
+                        active_forest_variance.reset(previous_bart_model.forest_container_variance, previous_model_warmstart_sample_num)
+                        forest_sampler_variance.reconstitute_from_forest(active_forest_variance, forest_dataset_train, residual_train, True)
+                    # if self.has_rfx:
+                    #     pass
+                    if self.sample_sigma_global:
+                        current_sigma2 = previous_global_var_samples[previous_model_warmstart_sample_num]
+                        global_model_config.update_global_error_variance(current_sigma2)
                 else:
                     if self.include_mean_forest:
                         active_forest_mean.reset_root()
@@ -1069,12 +1137,14 @@ def sample(
                 current_sigma2 = global_var_model.sample_one_iteration(
                     residual_train, cpp_rng, a_global, b_global
                 )
+                global_model_config.update_global_error_variance(current_sigma2)
                 if keep_sample:
                     self.global_var_samples[sample_counter] = current_sigma2
             if self.sample_sigma_leaf:
                 current_leaf_scale[0, 0] = leaf_var_model.sample_one_iteration(
                     active_forest_mean, cpp_rng, a_leaf, b_leaf
                 )
+                forest_model_config_mean.update_leaf_model_scale(current_leaf_scale)
                 if keep_sample:
                     self.leaf_scale_samples[sample_counter] = (
                         current_leaf_scale[0, 0]
