Skip to content

Commit daa9243

Browse files
committed
Updated python interface to support parallel multi-chain
1 parent 294d01e commit daa9243

File tree

10 files changed

+1560
-209
lines changed

10 files changed

+1560
-209
lines changed

R/bart.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1853,7 +1853,7 @@ createBARTModelFromCombinedJsonString <- function(json_string_list){
18531853
}
18541854

18551855
# Unpack covariate preprocessor
1856-
preprocessor_metadata_string <- json_object$get_string("preprocessor_metadata")
1856+
preprocessor_metadata_string <- json_object_default$get_string("preprocessor_metadata")
18571857
output[["train_set_metadata"]] <- createPreprocessorFromJsonString(
18581858
preprocessor_metadata_string
18591859
)

demo/debug/multi_chain.py

Lines changed: 34 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
X = rng.uniform(0, 1, (n, p_X))
2222
W = rng.uniform(0, 1, (n, p_W))
2323

24+
2425
# Define the outcome mean function
2526
def outcome_mean(X, W):
2627
return np.where(
@@ -33,14 +34,17 @@ def outcome_mean(X, W):
3334
),
3435
)
3536

37+
3638
# Generate outcome
3739
f_XW = outcome_mean(X, W)
3840
epsilon = rng.normal(0, 1, n)
3941
y = f_XW + epsilon
4042

4143
# Test-train split
4244
sample_inds = np.arange(n)
43-
train_inds, test_inds = train_test_split(sample_inds, test_size=0.5, random_state=random_seed)
45+
train_inds, test_inds = train_test_split(
46+
sample_inds, test_size=0.5, random_state=random_seed
47+
)
4448
X_train = X[train_inds, :]
4549
X_test = X[test_inds, :]
4650
basis_train = W[train_inds, :]
@@ -61,9 +65,9 @@ def outcome_mean(X, W):
6165
X_test=X_test,
6266
leaf_basis_test=basis_test,
6367
num_gfr=num_warmstart,
64-
num_mcmc=0,
65-
general_params=general_model_params,
66-
mean_forest_params=mean_forest_model_params
68+
num_mcmc=0,
69+
general_params=general_model_params,
70+
mean_forest_params=mean_forest_model_params,
6771
)
6872
bart_model_json = bart_model.to_json()
6973

@@ -78,9 +82,9 @@ def outcome_mean(X, W):
7882
num_gfr=0,
7983
num_mcmc=num_mcmc,
8084
previous_model_json=bart_model_json,
81-
previous_model_warmstart_sample_num=num_warmstart-1,
82-
general_params=general_model_params,
83-
mean_forest_params=mean_forest_model_params
85+
previous_model_warmstart_sample_num=num_warmstart - 1,
86+
general_params=general_model_params,
87+
mean_forest_params=mean_forest_model_params,
8488
)
8589

8690
# Run several BART MCMC samples from the second-to-last GFR forest
@@ -94,9 +98,9 @@ def outcome_mean(X, W):
9498
num_gfr=0,
9599
num_mcmc=num_mcmc,
96100
previous_model_json=bart_model_json,
97-
previous_model_warmstart_sample_num=num_warmstart-2,
98-
general_params=general_model_params,
99-
mean_forest_params=mean_forest_model_params
101+
previous_model_warmstart_sample_num=num_warmstart - 2,
102+
general_params=general_model_params,
103+
mean_forest_params=mean_forest_model_params,
100104
)
101105

102106
# Run several BART MCMC samples from root
@@ -109,8 +113,8 @@ def outcome_mean(X, W):
109113
leaf_basis_test=basis_test,
110114
num_gfr=0,
111115
num_mcmc=num_mcmc,
112-
general_params=general_model_params,
113-
mean_forest_params=mean_forest_model_params
116+
general_params=general_model_params,
117+
mean_forest_params=mean_forest_model_params,
114118
)
115119

116120
# Inspect the model outputs
@@ -121,7 +125,10 @@ def outcome_mean(X, W):
121125
y_hat_mcmc_4 = bart_model_4.predict(X_test, basis_test)
122126
y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True)
123127
y_df = pd.DataFrame(
124-
np.concatenate((y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4, np.expand_dims(y_test, axis=1)), axis=1),
128+
np.concatenate(
129+
(y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4, np.expand_dims(y_test, axis=1)),
130+
axis=1,
131+
),
125132
columns=["First Chain", "Second Chain", "Third Chain", "Outcome"],
126133
)
127134

@@ -141,7 +148,17 @@ def outcome_mean(X, W):
141148
plt.show()
142149

143150
# Compute RMSEs
144-
rmse_1 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_2)-y_test)*(np.squeeze(y_avg_mcmc_2)-y_test)))
145-
rmse_2 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_3)-y_test)*(np.squeeze(y_avg_mcmc_3)-y_test)))
146-
rmse_3 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_4)-y_test)*(np.squeeze(y_avg_mcmc_4)-y_test)))
147-
print("Chain 1 rmse: {:0.3f}; Chain 2 rmse: {:0.3f}; Chain 3 rmse: {:0.3f}".format(rmse_1, rmse_2, rmse_3))
151+
rmse_1 = np.sqrt(
152+
np.mean((np.squeeze(y_avg_mcmc_2) - y_test) * (np.squeeze(y_avg_mcmc_2) - y_test))
153+
)
154+
rmse_2 = np.sqrt(
155+
np.mean((np.squeeze(y_avg_mcmc_3) - y_test) * (np.squeeze(y_avg_mcmc_3) - y_test))
156+
)
157+
rmse_3 = np.sqrt(
158+
np.mean((np.squeeze(y_avg_mcmc_4) - y_test) * (np.squeeze(y_avg_mcmc_4) - y_test))
159+
)
160+
print(
161+
"Chain 1 rmse: {:0.3f}; Chain 2 rmse: {:0.3f}; Chain 3 rmse: {:0.3f}".format(
162+
rmse_1, rmse_2, rmse_3
163+
)
164+
)

demo/debug/parallel_multi_chain.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# Multi Chain Demo Script
2+
3+
# Load necessary libraries
4+
from multiprocessing import Pool, cpu_count
5+
6+
import matplotlib.pyplot as plt
7+
import numpy as np
8+
import pandas as pd
9+
import seaborn as sns
10+
from sklearn.model_selection import train_test_split
11+
12+
from stochtree import BARTModel
13+
14+
15+
def fit_bart(
    model_string,
    X_train,
    y_train,
    basis_train,
    X_test,
    basis_test,
    num_mcmc,
    gen_param_list,
    mean_list,
    i,
):
    """Run one MCMC-only BART chain warm-started from a serialized model.

    Draws `num_mcmc` MCMC samples (no GFR iterations) starting from forest
    sample `i` of the previously fit model encoded in `model_string`.

    Returns:
        Tuple of (JSON string of the fitted chain, cached test-set predictions).
    """
    chain = BARTModel()
    chain.sample(
        X_train=X_train,
        y_train=y_train,
        leaf_basis_train=basis_train,
        X_test=X_test,
        leaf_basis_test=basis_test,
        num_gfr=0,
        num_mcmc=num_mcmc,
        previous_model_json=model_string,
        previous_model_warmstart_sample_num=i,
        general_params=gen_param_list,
        mean_forest_params=mean_list,
    )
    return (chain.to_json(), chain.y_hat_test)
42+
43+
44+
def bart_warmstart_parallel(X_train, y_train, basis_train, X_test, basis_test):
    """Fit BART via several parallel warm-started MCMC chains.

    First runs a short grow-from-root (GFR) phase, then launches four
    MCMC chains in a process pool, each warm-started from a different
    GFR forest sample of the serialized initial model.

    Returns:
        Tuple of (combined BARTModel rebuilt from all chains' JSON,
        test-set predictions from every chain concatenated along axis 1).
    """
    # Shared sampler settings for the initial fit and every chain
    general_model_params = {"random_seed": -1}
    mean_forest_model_params = {"num_trees": 100}
    num_warmstart = 10
    num_mcmc = 100

    # Run the GFR algorithm for a small number of iterations
    gfr_model = BARTModel()
    gfr_model.sample(
        X_train=X_train,
        y_train=y_train,
        leaf_basis_train=basis_train,
        X_test=X_test,
        leaf_basis_test=basis_test,
        num_gfr=num_warmstart,
        num_mcmc=0,
        general_params=general_model_params,
        mean_forest_params=mean_forest_model_params,
    )
    gfr_model_json = gfr_model.to_json()

    # One task per chain; chain `chain_num` warm-starts from GFR sample `chain_num`
    process_tasks = [
        (
            gfr_model_json,
            X_train,
            y_train,
            basis_train,
            X_test,
            basis_test,
            num_mcmc,
            general_model_params,
            mean_forest_model_params,
            chain_num,
        )
        for chain_num in range(4)
    ]
    with Pool(processes=cpu_count()) as pool:
        results = pool.starmap(fit_bart, process_tasks)

    # Unzip the per-chain outputs into parallel tuples
    json_strings, chain_preds = zip(*results)

    # Rebuild a single combined model from every chain's JSON and stack
    # all cached predictions sample-wise (axis 1)
    combined_bart_model = BARTModel()
    combined_bart_model.from_json_string_list(json_strings)
    combined_bart_preds = np.concatenate(chain_preds, axis=1)

    return (combined_bart_model, combined_bart_preds)
97+
98+
99+
if __name__ == "__main__":
    # RNG
    random_seed = 1234
    rng = np.random.default_rng(random_seed)

    # Generate covariates and basis
    n = 1000
    p_X = 10
    p_W = 1
    X = rng.uniform(0, 1, (n, p_X))
    W = rng.uniform(0, 1, (n, p_W))

    # Define the outcome mean function: slope of the linear term in W[:, 0]
    # depends on which quartile of [0, 1] contains X[:, 0]
    def outcome_mean(X, W):
        return np.where(
            (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
            -7.5 * W[:, 0],
            np.where(
                (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
                -2.5 * W[:, 0],
                np.where(
                    (X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * W[:, 0], 7.5 * W[:, 0]
                ),
            ),
        )

    # Generate outcome with unit-variance Gaussian noise
    f_XW = outcome_mean(X, W)
    epsilon = rng.normal(0, 1, n)
    y = f_XW + epsilon

    # Test-train split
    sample_inds = np.arange(n)
    train_inds, test_inds = train_test_split(
        sample_inds, test_size=0.2, random_state=random_seed
    )
    X_train = X[train_inds, :]
    X_test = X[test_inds, :]
    basis_train = W[train_inds, :]
    basis_test = W[test_inds, :]
    y_train = y[train_inds]
    y_test = y[test_inds]

    # Run the parallel BART
    combined_bart, combined_bart_preds = bart_warmstart_parallel(
        X_train, y_train, basis_train, X_test, basis_test
    )

    # Inspect the model outputs
    y_hat_mcmc = combined_bart.predict(X_test, basis_test)
    y_avg_mcmc = np.squeeze(y_hat_mcmc).mean(axis=1, keepdims=True)
    y_df = pd.DataFrame(
        np.concatenate((y_avg_mcmc, np.expand_dims(y_test, axis=1)), axis=1),
        columns=["Average BART Predictions", "Outcome"],
    )

    # Compare average predictions across all chains to the outcome
    sns.scatterplot(data=y_df, x="Average BART Predictions", y="Outcome")
    plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
    plt.show()

    # Compare cached predictions to deserialized predictions for first chain
    chain_index = 0
    num_mcmc = 100  # must match num_mcmc inside bart_warmstart_parallel
    offset_index = num_mcmc * chain_index
    chain_inds = slice(offset_index, (offset_index + num_mcmc))
    # Predictions are laid out (observations, samples) — this script averages
    # over axis 1 and bart_warmstart_parallel concatenates chains along
    # axis 1 — so a single chain is a contiguous COLUMN block. Slice axis 1;
    # the previous row slice `[chain_inds]` selected the first num_mcmc test
    # observations (averaged over all chains) rather than the first chain.
    chain_1_preds_original = np.squeeze(combined_bart_preds[:, chain_inds]).mean(
        axis=1, keepdims=True
    )
    chain_1_preds_reloaded = np.squeeze(y_hat_mcmc[:, chain_inds]).mean(
        axis=1, keepdims=True
    )
    chain_df = pd.DataFrame(
        np.concatenate((chain_1_preds_reloaded, chain_1_preds_original), axis=1),
        columns=["New Predictions", "Original Predictions"],
    )
    sns.scatterplot(data=chain_df, x="New Predictions", y="Original Predictions")
    plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
    plt.show()

src/py_stochtree.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,8 @@ class ForestContainerCpp {
325325

326326
void LoadFromJson(JsonCpp& json, std::string forest_label);
327327

328+
void AppendFromJson(JsonCpp& json, std::string forest_label);
329+
328330
std::string DumpJsonString() {
329331
return forest_samples_->DumpJsonString();
330332
}
@@ -1289,6 +1291,7 @@ class RandomEffectsContainerCpp {
12891291
rfx_container_->LoadFromJsonString(json_string);
12901292
}
12911293
void LoadFromJson(JsonCpp& json, std::string rfx_container_label);
1294+
void AppendFromJson(JsonCpp& json, std::string rfx_container_label);
12921295
StochTree::RandomEffectsContainer* GetRandomEffectsContainer() {
12931296
return rfx_container_.get();
12941297
}
@@ -1870,6 +1873,11 @@ void ForestContainerCpp::LoadFromJson(JsonCpp& json, std::string forest_label) {
18701873
forest_samples_->from_json(forest_json);
18711874
}
18721875

1876+
void ForestContainerCpp::AppendFromJson(JsonCpp& json, std::string forest_label) {
  // Extract the forest entry stored under `forest_label` and append its
  // sampled forests to this container (LoadFromJson, by contrast, calls
  // from_json on the same subset).
  nlohmann::json forest_subset = json.SubsetJsonForest(forest_label);
  forest_samples_->append_from_json(forest_subset);
}
1880+
18731881
void ForestContainerCpp::AdjustResidual(ForestDatasetCpp& dataset, ResidualCpp& residual, ForestSamplerCpp& sampler, bool requires_basis, int forest_num, bool add) {
18741882
// Determine whether or not we are adding forest_num to the residuals
18751883
std::function<double(double, double)> op;
@@ -1896,6 +1904,11 @@ void RandomEffectsContainerCpp::LoadFromJson(JsonCpp& json, std::string rfx_cont
18961904
rfx_container_->from_json(rfx_json);
18971905
}
18981906

1907+
void RandomEffectsContainerCpp::AppendFromJson(JsonCpp& json, std::string rfx_container_label) {
  // Look up the labeled entry inside the JSON's random-effects subset and
  // append its samples to this container.
  nlohmann::json rfx_subset = json.SubsetJsonRFX().at(rfx_container_label);
  rfx_container_->append_from_json(rfx_subset);
}
1911+
18991912
void RandomEffectsContainerCpp::AddSample(RandomEffectsModelCpp& rfx_model) {
19001913
rfx_container_->AddSample(*rfx_model.GetModel());
19011914
}
@@ -2012,6 +2025,7 @@ PYBIND11_MODULE(stochtree_cpp, m) {
20122025
.def("SaveToJsonFile", &ForestContainerCpp::SaveToJsonFile)
20132026
.def("LoadFromJsonFile", &ForestContainerCpp::LoadFromJsonFile)
20142027
.def("LoadFromJson", &ForestContainerCpp::LoadFromJson)
2028+
.def("AppendFromJson", &ForestContainerCpp::AppendFromJson)
20152029
.def("DumpJsonString", &ForestContainerCpp::DumpJsonString)
20162030
.def("LoadFromJsonString", &ForestContainerCpp::LoadFromJsonString)
20172031
.def("AddSampleValue", &ForestContainerCpp::AddSampleValue)
@@ -2125,6 +2139,7 @@ PYBIND11_MODULE(stochtree_cpp, m) {
21252139
.def("DumpJsonString", &RandomEffectsContainerCpp::DumpJsonString)
21262140
.def("LoadFromJsonString", &RandomEffectsContainerCpp::LoadFromJsonString)
21272141
.def("LoadFromJson", &RandomEffectsContainerCpp::LoadFromJson)
2142+
.def("AppendFromJson", &RandomEffectsContainerCpp::AppendFromJson)
21282143
.def("GetRandomEffectsContainer", &RandomEffectsContainerCpp::GetRandomEffectsContainer);
21292144

21302145
py::class_<RandomEffectsTrackerCpp>(m, "RandomEffectsTrackerCpp")

0 commit comments

Comments
 (0)