Skip to content

Commit a6caffb

Browse files
committed
Updated demo scripts
1 parent f5b2c72 commit a6caffb

File tree

2 files changed

+11
-13
lines changed

2 files changed

+11
-13
lines changed

demo/debug/causal_inference.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,38 +63,38 @@
6363

6464
# Run BCF
6565
bcf_model = BCFModel()
66-
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100)
66+
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=1000)
6767

6868
# Inspect the MCMC (BART) samples
69-
forest_preds_y_mcmc = bcf_model.y_hat_test[:,bcf_model.num_gfr:]
69+
forest_preds_y_mcmc = bcf_model.y_hat_test
7070
y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True)
7171
y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=["True outcome", "Average estimated outcome"])
7272
sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome")
7373
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
7474
plt.show()
7575

76-
forest_preds_tau_mcmc = bcf_model.tau_hat_test[:,bcf_model.num_gfr:]
76+
forest_preds_tau_mcmc = bcf_model.tau_hat_test
7777
tau_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis = 1, keepdims = True)
7878
tau_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(tau_test,1), tau_avg_mcmc), axis = 1), columns=["True tau", "Average estimated tau"])
7979
sns.scatterplot(data=tau_df_mcmc, x="Average estimated tau", y="True tau")
8080
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
8181
plt.show()
8282

83-
forest_preds_mu_mcmc = bcf_model.mu_hat_test[:,bcf_model.num_gfr:]
83+
forest_preds_mu_mcmc = bcf_model.mu_hat_test
8484
mu_avg_mcmc = np.squeeze(forest_preds_mu_mcmc).mean(axis = 1, keepdims = True)
8585
mu_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(mu_test,1), mu_avg_mcmc), axis = 1), columns=["True mu", "Average estimated mu"])
8686
sns.scatterplot(data=mu_df_mcmc, x="Average estimated mu", y="True mu")
8787
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
8888
plt.show()
8989

90-
# sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
91-
# sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
92-
# plt.show()
90+
sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"])
91+
sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma")
92+
plt.show()
9393

94-
# b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"])
95-
# sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_0")
96-
# sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_1")
97-
# plt.show()
94+
b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"])
95+
sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_0")
96+
sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_1")
97+
plt.show()
9898

9999
# Compute RMSEs
100100
y_rmse = np.sqrt(np.mean(np.power(np.expand_dims(y_test,1) - y_avg_mcmc, 2)))

demo/debug/multivariate_treatment_causal_inference.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,3 @@
4444
# Run BCF
4545
bcf_model = BCFModel()
4646
bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100)
47-
48-

0 commit comments

Comments
 (0)