Skip to content

Commit ae13cb0

Browse files
committed
Updated demos and forest initialization Python code
1 parent a6caffb commit ae13cb0

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

demo/debug/serialization.py

Lines changed: 17 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,8 @@
11
import numpy as np
22
from stochtree import (
33
BARTModel, JSONSerializer, ForestContainer, Dataset, Residual,
4-
RNG, ForestSampler, ForestContainer, GlobalVarianceModel
4+
RNG, ForestSampler, ForestContainer, GlobalVarianceModel,
5+
GlobalModelConfig, ForestModelConfig, Forest
56
)
67

78
# RNG
@@ -53,6 +54,7 @@ def outcome_mean(X, W):
5354
leaf_regression = True
5455
feature_types = np.repeat(0, p_X).astype(int) # 0 = numeric
5556
var_weights = np.repeat(1/p_X, p_X)
57+
leaf_model_type = 1 if p_W == 1 else 2
5658

5759
# Dataset (covariates and basis)
5860
dataset = Dataset()
@@ -64,7 +66,14 @@ def outcome_mean(X, W):
6466

6567
# Forest samplers and temporary tracking data structures
6668
forest_container = ForestContainer(num_trees, W.shape[1], False, False)
67-
forest_sampler = ForestSampler(dataset, feature_types, num_trees, n, alpha, beta, min_samples_leaf)
69+
active_forest = Forest(num_trees, W.shape[1], False, False)
70+
global_config = GlobalModelConfig(global_error_variance=global_variance_init)
71+
forest_config = ForestModelConfig(num_trees=num_trees, num_features=p_X, num_observations=n,
72+
feature_types=feature_types, variable_weights=var_weights,
73+
leaf_dimension=W.shape[1], alpha=alpha, beta=beta,
74+
min_samples_leaf=min_samples_leaf, leaf_model_type=leaf_model_type,
75+
leaf_model_scale=leaf_prior_scale, cutpoint_grid_size=cutpoint_grid_size)
76+
forest_sampler = ForestSampler(dataset, global_config=global_config, forest_config=forest_config)
6877
cpp_rng = RNG(random_seed)
6978
global_var_model = GlobalVarianceModel()
7079

@@ -74,14 +83,18 @@ def outcome_mean(X, W):
7483
num_samples = num_warmstart + num_mcmc
7584
global_var_samples = np.concatenate((np.array([global_variance_init]), np.repeat(0, num_samples)))
7685

86+
# Initialize the forest
87+
constant_leaf_value = np.repeat(0.0, p_W)
88+
active_forest.set_root_leaves(constant_leaf_value)
89+
7790
# Run "grow-from-root" sampler
7891
for i in range(num_warmstart):
79-
forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, True, False)
92+
forest_sampler.sample_one_iteration(forest_container, active_forest, dataset, residual, cpp_rng, global_config, forest_config, True, False)
8093
global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)
8194

8295
# Run MCMC sampler
8396
for i in range(num_warmstart, num_samples):
84-
forest_sampler.sample_one_iteration(forest_container, dataset, residual, cpp_rng, feature_types, cutpoint_grid_size, leaf_prior_scale, var_weights, 1., 1., global_var_samples[i], 1, False, False)
97+
forest_sampler.sample_one_iteration(forest_container, active_forest, dataset, residual, cpp_rng, global_config, forest_config, False, False)
8598
global_var_samples[i+1] = global_var_model.sample_one_iteration(residual, cpp_rng, a_global, b_global)
8699

87100
# Extract predictions from the sampler

stochtree/forest.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -890,10 +890,14 @@ def set_root_leaves(self, leaf_value: Union[float, np.array]) -> None:
890890
if not isinstance(leaf_value, np.ndarray) and not isinstance(leaf_value, float):
891891
raise ValueError("leaf_value must be either a float or np.array")
892892
if isinstance(leaf_value, np.ndarray):
893-
leaf_value = np.squeeze(leaf_value)
894-
if len(leaf_value.shape) != 1:
895-
raise ValueError("leaf_value must be either a one-dimensional array")
896-
self.forest_cpp.SetRootVector(leaf_value, leaf_value.shape[0])
893+
if len(leaf_value.shape) > 1:
894+
leaf_value = np.squeeze(leaf_value)
895+
if len(leaf_value.shape) != 1 or leaf_value.shape[0] != self.output_dimension:
896+
raise ValueError("leaf_value must be a one-dimensional array with dimension equal to the output_dimension field of the forest")
897+
if leaf_value.shape[0] > 1:
898+
self.forest_cpp.SetRootVector(leaf_value, leaf_value.shape[0])
899+
else:
900+
self.forest_cpp.SetRootValue(np.squeeze(leaf_value))
897901
else:
898902
self.forest_cpp.SetRootValue(leaf_value)
899903
self.internal_forest_is_empty = False

0 commit comments

Comments
 (0)