|
63 | 63 |
|
64 | 64 | # Run BCF
|
65 | 65 | bcf_model = BCFModel()
|
66 |
| -bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=100) |
| 66 | +bcf_model.sample(X_train, Z_train, y_train, pi_train, X_test, Z_test, pi_test, num_gfr=10, num_mcmc=1000) |
67 | 67 |
|
68 | 68 | # Inspect the MCMC (BART) samples
|
69 |
| -forest_preds_y_mcmc = bcf_model.y_hat_test[:,bcf_model.num_gfr:] |
| 69 | +forest_preds_y_mcmc = bcf_model.y_hat_test |
70 | 70 | y_avg_mcmc = np.squeeze(forest_preds_y_mcmc).mean(axis = 1, keepdims = True)
|
71 | 71 | y_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(y_test,1), y_avg_mcmc), axis = 1), columns=["True outcome", "Average estimated outcome"])
|
72 | 72 | sns.scatterplot(data=y_df_mcmc, x="Average estimated outcome", y="True outcome")
|
73 | 73 | plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
|
74 | 74 | plt.show()
|
75 | 75 |
|
76 |
| -forest_preds_tau_mcmc = bcf_model.tau_hat_test[:,bcf_model.num_gfr:] |
| 76 | +forest_preds_tau_mcmc = bcf_model.tau_hat_test |
77 | 77 | tau_avg_mcmc = np.squeeze(forest_preds_tau_mcmc).mean(axis = 1, keepdims = True)
|
78 | 78 | tau_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(tau_test,1), tau_avg_mcmc), axis = 1), columns=["True tau", "Average estimated tau"])
|
79 | 79 | sns.scatterplot(data=tau_df_mcmc, x="Average estimated tau", y="True tau")
|
80 | 80 | plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
|
81 | 81 | plt.show()
|
82 | 82 |
|
83 |
| -forest_preds_mu_mcmc = bcf_model.mu_hat_test[:,bcf_model.num_gfr:] |
| 83 | +forest_preds_mu_mcmc = bcf_model.mu_hat_test |
84 | 84 | mu_avg_mcmc = np.squeeze(forest_preds_mu_mcmc).mean(axis = 1, keepdims = True)
|
85 | 85 | mu_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(mu_test,1), mu_avg_mcmc), axis = 1), columns=["True mu", "Average estimated mu"])
|
86 | 86 | sns.scatterplot(data=mu_df_mcmc, x="Average estimated mu", y="True mu")
|
87 | 87 | plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3,3)))
|
88 | 88 | plt.show()
|
89 | 89 |
|
90 |
| -# sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"]) |
91 |
| -# sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma") |
92 |
| -# plt.show() |
| 90 | +sigma_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.global_var_samples,axis=1)), axis = 1), columns=["Sample", "Sigma"]) |
| 91 | +sns.scatterplot(data=sigma_df_mcmc, x="Sample", y="Sigma") |
| 92 | +plt.show() |
93 | 93 |
|
94 |
| -# b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples - bcf_model.num_gfr),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"]) |
95 |
| -# sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_0") |
96 |
| -# sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_1") |
97 |
| -# plt.show() |
| 94 | +b_df_mcmc = pd.DataFrame(np.concatenate((np.expand_dims(np.arange(bcf_model.num_samples),axis=1), np.expand_dims(bcf_model.b0_samples,axis=1), np.expand_dims(bcf_model.b1_samples,axis=1)), axis = 1), columns=["Sample", "Beta_0", "Beta_1"]) |
| 95 | +sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_0") |
| 96 | +sns.scatterplot(data=b_df_mcmc, x="Sample", y="Beta_1") |
| 97 | +plt.show() |
98 | 98 |
|
99 | 99 | # Compute RMSEs
|
100 | 100 | y_rmse = np.sqrt(np.mean(np.power(np.expand_dims(y_test,1) - y_avg_mcmc, 2)))
|
|
0 commit comments