# Multi Chain Demo Script

# Load necessary libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split

from stochtree import BARTModel

# Generate sample data
# RNG
random_seed = 1234
rng = np.random.default_rng(random_seed)

# Generate covariates and basis
n = 500
p_X = 10
p_W = 1
X = rng.uniform(0, 1, (n, p_X))
W = rng.uniform(0, 1, (n, p_W))

# Define the outcome mean function: the slope on the basis W depends on
# which quarter of [0, 1] the first covariate X[:, 0] falls in
def outcome_mean(X, W):
    return np.where(
        (X[:, 0] >= 0.0) & (X[:, 0] < 0.25),
        -7.5 * W[:, 0],
        np.where(
            (X[:, 0] >= 0.25) & (X[:, 0] < 0.5),
            -2.5 * W[:, 0],
            np.where((X[:, 0] >= 0.5) & (X[:, 0] < 0.75), 2.5 * W[:, 0], 7.5 * W[:, 0]),
        ),
    )
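
# (Aside) A quick, commented-out sanity check: with W = 1, the four quarters of
# X[:, 0] recover the slopes -7.5, -2.5, 2.5, and 7.5 (the values below are
# illustrative, not part of the original demo)
# print(outcome_mean(np.array([[0.1], [0.3], [0.6], [0.9]]), np.ones((4, 1))))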

# Generate outcome
f_XW = outcome_mean(X, W)
epsilon = rng.normal(0, 1, n)
y = f_XW + epsilon

# Train-test split
sample_inds = np.arange(n)
train_inds, test_inds = train_test_split(sample_inds, test_size=0.5, random_state=random_seed)
X_train = X[train_inds, :]
X_test = X[test_inds, :]
basis_train = W[train_inds, :]
basis_test = W[test_inds, :]
y_train = y[train_inds]
y_test = y[test_inds]

# Run the GFR (grow-from-root) algorithm for a small number of iterations
general_model_params = {"random_seed": -1}  # -1 => the sampler seeds itself
mean_forest_model_params = {"num_trees": 20}
num_warmstart = 10
num_mcmc = 10
bart_model = BARTModel()
bart_model.sample(
    X_train=X_train,
    y_train=y_train,
    leaf_basis_train=basis_train,
    X_test=X_test,
    leaf_basis_test=basis_test,
    num_gfr=num_warmstart,
    num_mcmc=0,
    general_params=general_model_params,
    mean_forest_params=mean_forest_model_params,
)

# Serialize the warm-start forests so they can seed the MCMC chains below
bart_model_json = bart_model.to_json()
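
# (Optional) The JSON string can also be written to disk and reloaded in a
# later session. A minimal sketch, assuming `BARTModel.from_json` accepts the
# string produced by `to_json`:
#
# with open("bart_warmstart.json", "w") as f:
#     f.write(bart_model_json)
# bart_model_reloaded = BARTModel()
# with open("bart_warmstart.json", "r") as f:
#     bart_model_reloaded.from_json(f.read())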

# Run several BART MCMC samples from the last GFR forest ("First Chain" below)
bart_model_2 = BARTModel()
bart_model_2.sample(
    X_train=X_train,
    y_train=y_train,
    leaf_basis_train=basis_train,
    X_test=X_test,
    leaf_basis_test=basis_test,
    num_gfr=0,
    num_mcmc=num_mcmc,
    previous_model_json=bart_model_json,
    previous_model_warmstart_sample_num=num_warmstart - 1,
    general_params=general_model_params,
    mean_forest_params=mean_forest_model_params,
)

# Run several BART MCMC samples from the second-to-last GFR forest ("Second Chain" below)
bart_model_3 = BARTModel()
bart_model_3.sample(
    X_train=X_train,
    y_train=y_train,
    leaf_basis_train=basis_train,
    X_test=X_test,
    leaf_basis_test=basis_test,
    num_gfr=0,
    num_mcmc=num_mcmc,
    previous_model_json=bart_model_json,
    previous_model_warmstart_sample_num=num_warmstart - 2,
    general_params=general_model_params,
    mean_forest_params=mean_forest_model_params,
)

# Run several BART MCMC samples from root, with no warm-start ("Third Chain" below)
bart_model_4 = BARTModel()
bart_model_4.sample(
    X_train=X_train,
    y_train=y_train,
    leaf_basis_train=basis_train,
    X_test=X_test,
    leaf_basis_test=basis_test,
    num_gfr=0,
    num_mcmc=num_mcmc,
    general_params=general_model_params,
    mean_forest_params=mean_forest_model_params,
)

# Inspect the model outputs: average each chain's test-set predictions over its MCMC draws
y_hat_mcmc_2 = bart_model_2.predict(X_test, basis_test)
y_avg_mcmc_2 = np.squeeze(y_hat_mcmc_2).mean(axis=1, keepdims=True)
y_hat_mcmc_3 = bart_model_3.predict(X_test, basis_test)
y_avg_mcmc_3 = np.squeeze(y_hat_mcmc_3).mean(axis=1, keepdims=True)
y_hat_mcmc_4 = bart_model_4.predict(X_test, basis_test)
y_avg_mcmc_4 = np.squeeze(y_hat_mcmc_4).mean(axis=1, keepdims=True)
y_df = pd.DataFrame(
    np.concatenate((y_avg_mcmc_2, y_avg_mcmc_3, y_avg_mcmc_4, np.expand_dims(y_test, axis=1)), axis=1),
    columns=["First Chain", "Second Chain", "Third Chain", "Outcome"],
)

# Compare first warm-start chain to root chain with equal number of MCMC draws
sns.scatterplot(data=y_df, x="First Chain", y="Third Chain")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()

# Compare first warm-start chain to outcome
sns.scatterplot(data=y_df, x="First Chain", y="Outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()

# Compare root chain to outcome
sns.scatterplot(data=y_df, x="Third Chain", y="Outcome")
plt.axline((0, 0), slope=1, color="black", linestyle=(0, (3, 3)))
plt.show()

# Compute RMSEs
rmse_1 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_2) - y_test) ** 2))
rmse_2 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_3) - y_test) ** 2))
rmse_3 = np.sqrt(np.mean((np.squeeze(y_avg_mcmc_4) - y_test) ** 2))
print("Chain 1 rmse: {:0.3f}; Chain 2 rmse: {:0.3f}; Chain 3 rmse: {:0.3f}".format(rmse_1, rmse_2, rmse_3))
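
# (Optional) Independent chains are commonly pooled into a single posterior
# sample. A minimal sketch pooling the two warm-started chains, assuming
# `predict` returns one column of test-set predictions per retained draw:
y_hat_pooled = np.concatenate((np.squeeze(y_hat_mcmc_2), np.squeeze(y_hat_mcmc_3)), axis=1)
rmse_pooled = np.sqrt(np.mean((y_hat_pooled.mean(axis=1) - y_test) ** 2))
print("Pooled warm-start rmse: {:0.3f}".format(rmse_pooled))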