1
1
import numpy as np
2
2
from stochtree import (
3
3
BARTModel , JSONSerializer , ForestContainer , Dataset , Residual ,
4
- RNG , ForestSampler , ForestContainer , GlobalVarianceModel
4
+ RNG , ForestSampler , ForestContainer , GlobalVarianceModel ,
5
+ GlobalModelConfig , ForestModelConfig , Forest
5
6
)
6
7
7
8
# RNG
@@ -53,6 +54,7 @@ def outcome_mean(X, W):
53
54
leaf_regression = True
54
55
feature_types = np .repeat (0 , p_X ).astype (int ) # 0 = numeric
55
56
var_weights = np .repeat (1 / p_X , p_X )
57
+ leaf_model_type = 1 if p_W == 1 else 2
56
58
57
59
# Dataset (covariates and basis)
58
60
dataset = Dataset ()
@@ -64,7 +66,14 @@ def outcome_mean(X, W):
64
66
65
67
# Forest samplers and temporary tracking data structures
66
68
forest_container = ForestContainer (num_trees , W .shape [1 ], False , False )
67
- forest_sampler = ForestSampler (dataset , feature_types , num_trees , n , alpha , beta , min_samples_leaf )
69
+ active_forest = Forest (num_trees , W .shape [1 ], False , False )
70
+ global_config = GlobalModelConfig (global_error_variance = global_variance_init )
71
+ forest_config = ForestModelConfig (num_trees = num_trees , num_features = p_X , num_observations = n ,
72
+ feature_types = feature_types , variable_weights = var_weights ,
73
+ leaf_dimension = W .shape [1 ], alpha = alpha , beta = beta ,
74
+ min_samples_leaf = min_samples_leaf , leaf_model_type = leaf_model_type ,
75
+ leaf_model_scale = leaf_prior_scale , cutpoint_grid_size = cutpoint_grid_size )
76
+ forest_sampler = ForestSampler (dataset , global_config = global_config , forest_config = forest_config )
68
77
cpp_rng = RNG (random_seed )
69
78
global_var_model = GlobalVarianceModel ()
70
79
@@ -74,14 +83,18 @@ def outcome_mean(X, W):
74
83
num_samples = num_warmstart + num_mcmc
75
84
global_var_samples = np .concatenate ((np .array ([global_variance_init ]), np .repeat (0 , num_samples )))
76
85
86
+ # Initialize the forest
87
+ constant_leaf_value = np .repeat (0.0 , p_W )
88
+ active_forest .set_root_leaves (constant_leaf_value )
89
+
77
90
# Run "grow-from-root" sampler
78
91
for i in range (num_warmstart ):
79
- forest_sampler .sample_one_iteration (forest_container , dataset , residual , cpp_rng , feature_types , cutpoint_grid_size , leaf_prior_scale , var_weights , 1. , 1. , global_var_samples [ i ], 1 , True , False )
92
+ forest_sampler .sample_one_iteration (forest_container , active_forest , dataset , residual , cpp_rng , global_config , forest_config , True , False )
80
93
global_var_samples [i + 1 ] = global_var_model .sample_one_iteration (residual , cpp_rng , a_global , b_global )
81
94
82
95
# Run MCMC sampler
83
96
for i in range (num_warmstart , num_samples ):
84
- forest_sampler .sample_one_iteration (forest_container , dataset , residual , cpp_rng , feature_types , cutpoint_grid_size , leaf_prior_scale , var_weights , 1. , 1. , global_var_samples [ i ], 1 , False , False )
97
+ forest_sampler .sample_one_iteration (forest_container , active_forest , dataset , residual , cpp_rng , global_config , forest_config , False , False )
85
98
global_var_samples [i + 1 ] = global_var_model .sample_one_iteration (residual , cpp_rng , a_global , b_global )
86
99
87
100
# Extract predictions from the sampler
0 commit comments