Skip to content

Commit 7bc2188

Browse files
committed
Updated kernel code
1 parent c6d2aa0 commit 7bc2188

File tree

6 files changed

+124
-15
lines changed

6 files changed

+124
-15
lines changed

R/kernel.R

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,6 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
206206
#' Compute and return the largest possible leaf index computable by `computeForestLeafIndices` for the forests in a designated forest sample container.
207207
#'
208208
#' @param model_object Object of type `bartmodel`, `bcfmodel`, or `ForestSamples` corresponding to a BART / BCF model with at least one forest sample, or a low-level `ForestSamples` object.
209-
#' @param covariates Covariates to use for prediction. Must have the same dimensions / column types as the data used to train a forest.
210209
#' @param forest_type Which forest to use from `model_object`.
211210
#' Valid inputs depend on the model type, and whether or not a
212211
#'
@@ -238,7 +237,7 @@ computeForestLeafVariances <- function(model_object, forest_type, forest_inds=NU
238237
#' computeForestMaxLeafIndex(bart_model, "mean")
239238
#' computeForestMaxLeafIndex(bart_model, "mean", 0)
240239
#' computeForestMaxLeafIndex(bart_model, "mean", c(1,3,9))
241-
computeForestMaxLeafIndex <- function(model_object, covariates, forest_type=NULL, forest_inds=NULL) {
240+
computeForestMaxLeafIndex <- function(model_object, forest_type=NULL, forest_inds=NULL) {
242241
# Extract relevant forest container
243242
stopifnot(any(c(inherits(model_object, "bartmodel"), inherits(model_object, "bcfmodel"), inherits(model_object, "ForestSamples"))))
244243
model_type <- ifelse(inherits(model_object, "bartmodel"), "bart", ifelse(inherits(model_object, "bcfmodel"), "bcf", "forest_samples"))

src/kernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ typedef Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::Col
1111

1212
[[cpp11::register]]
1313
int forest_container_get_max_leaf_index_cpp(cpp11::external_pointer<StochTree::ForestContainer> forest_container, int forest_num) {
14-
return forest_container->GetEnsemble(forest_num)->GetMaxLeafIndex();
14+
return forest_container->GetEnsemble(forest_num)->GetMaxLeafIndex() - 1;
1515
}
1616

1717
[[cpp11::register]]

src/py_stochtree.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1861,7 +1861,7 @@ py::array_t<int> cppComputeForestContainerLeafIndices(ForestContainerCpp& forest
18611861
}
18621862

18631863
int cppComputeForestMaxLeafIndex(ForestContainerCpp& forest_container, int forest_num) {
1864-
return forest_container.GetForest(forest_num)->GetMaxLeafIndex();
1864+
return forest_container.GetForest(forest_num)->GetMaxLeafIndex() - 1;
18651865
}
18661866

18671867
void ForestContainerCpp::LoadFromJson(JsonCpp& json, std::string forest_label) {

stochtree/__init__.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,22 @@
44
from .config import ForestModelConfig, GlobalModelConfig
55
from .data import Dataset, Residual
66
from .forest import Forest, ForestContainer
7-
from .kernel import compute_forest_leaf_indices
7+
from .kernel import (
8+
compute_forest_leaf_indices,
9+
compute_forest_max_leaf_index
10+
)
811
from .preprocessing import CovariatePreprocessor
912
from .random_effects import (
10-
RandomEffectsContainer,
11-
RandomEffectsDataset,
12-
RandomEffectsModel,
13-
RandomEffectsTracker,
13+
RandomEffectsContainer,
14+
RandomEffectsDataset,
15+
RandomEffectsModel,
16+
RandomEffectsTracker,
1417
)
1518
from .sampler import (
16-
RNG,
17-
ForestSampler,
18-
GlobalVarianceModel,
19-
LeafVarianceModel
19+
RNG,
20+
ForestSampler,
21+
GlobalVarianceModel,
22+
LeafVarianceModel
2023
)
2124
from .serialization import JSONSerializer
2225
from .utils import (
@@ -58,5 +61,6 @@
5861
"_standardize_array_to_list",
5962
"_standardize_array_to_np",
6063
"compute_forest_leaf_indices",
64+
"compute_forest_max_leaf_index",
6165
"calibrate_global_error_variance",
6266
]

stochtree/kernel.py

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,109 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC
9393
num_forests = forest_container.num_samples()
9494
if forest_inds is None:
9595
forest_inds = np.arange(num_forests)
96-
else:
96+
elif isinstance(forest_inds, int):
97+
if not forest_inds >= 0 or not forest_inds < num_forests:
98+
raise ValueError("The index in forest_inds must be >= 0 and < the total number of samples in a forest container")
99+
forest_inds = np.array([forest_inds])
100+
elif isinstance(forest_inds, np.ndarray):
101+
if forest_inds.size > 1:
102+
forest_inds = np.squeeze(forest_inds)
103+
if forest_inds.ndim > 1:
104+
raise ValueError("forest_inds must be a one-dimensional numpy array")
97105
if not np.all(forest_inds >= 0) or not np.all(forest_inds < num_forests):
98106
raise ValueError("The indices in forest_inds must be >= 0 and < the total number of samples in a forest container")
107+
else:
108+
raise ValueError("forest_inds must be a one-dimensional numpy array")
99109

100110
return cppComputeForestContainerLeafIndices(forest_container.forest_container_cpp, covariates_processed, forest_inds)
111+
112+
def compute_forest_max_leaf_index(model_object: Union[BARTModel, BCFModel, ForestContainer], forest_type: Union[str, None] = None, forest_inds: Union[int, np.ndarray, None] = None) -> np.ndarray:
    """
    Compute and return the largest possible leaf index computable by `compute_forest_leaf_indices` for the forests in a designated forest sample container.

    Parameters
    ----------
    model_object : BARTModel, BCFModel, or ForestContainer
        Object corresponding to a BART / BCF model with at least one forest sample, or a low-level `ForestContainer` object.
    forest_type : str, optional
        Which forest to use from `model_object`. Valid inputs depend on the model type, and whether or not a given forest was sampled in that model.

        * **BART**
            * `'mean'`: Uses the mean forest
            * `'variance'`: Uses the variance forest
        * **BCF**
            * `'prognostic'`: Uses the prognostic forest
            * `'treatment'`: Uses the treatment effect forest
            * `'variance'`: Uses the variance forest
        * **ForestContainer**
            * `None`: It is not necessary to disambiguate when this function is called directly on a `ForestContainer` object, so `forest_type` is ignored. This is the default value of this argument.

    forest_inds : int or np.ndarray, optional
        Indices of the forest sample(s) for which to compute max leaf indices. If not provided, this function will return max leaf indices for every sample of a forest.
        This function uses 0-indexing, so the first forest sample corresponds to `forest_inds = 0`, and so on.

    Returns
    -------
    np.ndarray
        One-dimensional integer array with one entry per requested forest sample, containing the largest possible leaf index computable by `compute_forest_leaf_indices` for that sample.

    Raises
    ------
    ValueError
        If `model_object` is not a supported type, if `forest_type` is missing or not valid for the model type, if the requested forest was not sampled, or if `forest_inds` is not an integer / one-dimensional array of in-range indices.
    """
    # Extract relevant forest container
    if isinstance(model_object, BARTModel):
        if forest_type is None:
            raise ValueError("forest_type must be specified for a BARTModel model_type (either set to 'mean' or 'variance')")
        if forest_type == "mean":
            if not model_object.include_mean_forest:
                raise ValueError("Mean forest was not sampled for model_object, but requested by forest_type")
            forest_container = model_object.forest_container_mean
        elif forest_type == "variance":
            if not model_object.include_variance_forest:
                raise ValueError("Variance forest was not sampled for model_object, but requested by forest_type")
            forest_container = model_object.forest_container_variance
        else:
            # Reject unknown labels rather than silently falling back to the variance forest
            raise ValueError("forest_type must be either 'mean' or 'variance' for a BARTModel model_type")
    elif isinstance(model_object, BCFModel):
        if forest_type is None:
            raise ValueError("forest_type must be specified for a BCFModel model_type (either set to 'prognostic', 'treatment' or 'variance')")
        if forest_type == "prognostic":
            forest_container = model_object.forest_container_mu
        elif forest_type == "treatment":
            forest_container = model_object.forest_container_tau
        elif forest_type == "variance":
            if not model_object.include_variance_forest:
                raise ValueError("Variance forest was not sampled for model_object, but requested by forest_type")
            forest_container = model_object.forest_container_variance
        else:
            # Reject unknown labels rather than silently falling back to the variance forest
            raise ValueError("forest_type must be one of 'prognostic', 'treatment' or 'variance' for a BCFModel model_type")
    elif isinstance(model_object, ForestContainer):
        # A raw container holds a single forest, so forest_type is not consulted
        forest_container = model_object
    else:
        raise ValueError("model_object must be one of BARTModel, BCFModel, or ForestContainer")

    # Preprocess forest indices
    num_forests = forest_container.num_samples()
    if forest_inds is None:
        forest_inds = np.arange(num_forests)
    elif isinstance(forest_inds, (int, np.integer)):
        # np.integer covers scalars obtained by indexing an integer numpy array
        if not forest_inds >= 0 or not forest_inds < num_forests:
            raise ValueError("The index in forest_inds must be >= 0 and < the total number of samples in a forest container")
        forest_inds = np.array([forest_inds])
    elif isinstance(forest_inds, np.ndarray):
        if forest_inds.size > 1:
            forest_inds = np.squeeze(forest_inds)
        if forest_inds.ndim > 1:
            raise ValueError("forest_inds must be a one-dimensional numpy array")
        # Promote a 0-d array to 1-d so that it can be iterated below
        forest_inds = np.atleast_1d(forest_inds)
        if not np.all(forest_inds >= 0) or not np.all(forest_inds < num_forests):
            raise ValueError("The indices in forest_inds must be >= 0 and < the total number of samples in a forest container")
    else:
        raise ValueError("forest_inds must be a one-dimensional numpy array")

    # Compute max index; leaf indices are integers, so avoid np.empty's float64 default
    output = np.empty(len(forest_inds), dtype=int)
    for i, forest_num in enumerate(forest_inds):
        output[i] = cppComputeForestMaxLeafIndex(forest_container.forest_container_cpp, forest_num)

    # Return result
    return output
201+

test/python/test_kernel.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
Dataset,
66
Forest,
77
ForestContainer,
8-
compute_forest_leaf_indices
8+
compute_forest_leaf_indices,
9+
compute_forest_max_leaf_index
910
)
1011

1112

@@ -35,6 +36,7 @@ def test_forest(self):
3536

3637
# Check that regular and "raw" predictions are the same (since the leaf is constant)
3738
computed = compute_forest_leaf_indices(forest_samples, X)
39+
max_leaf_index = compute_forest_max_leaf_index(forest_samples)
3840
expected = np.array([
3941
[0],
4042
[0],
@@ -52,12 +54,14 @@ def test_forest(self):
5254

5355
# Assertion
5456
np.testing.assert_almost_equal(computed, expected)
57+
assert max_leaf_index == [2]
5558

5659
# Split the left leaf of the first tree in the ensemble at X[,2] > 4.0
5760
forest_samples.add_numeric_split(0, 0, 1, 1, 4.0, -7.5, -2.5)
5861

5962
# Check that regular and "raw" predictions are the same (since the leaf is constant)
6063
computed = compute_forest_leaf_indices(forest_samples, X)
64+
max_leaf_index = compute_forest_max_leaf_index(forest_samples)
6165
expected = np.array([
6266
[2],
6367
[1],
@@ -75,3 +79,4 @@ def test_forest(self):
7579

7680
# Assertion
7781
np.testing.assert_almost_equal(computed, expected)
82+
assert max_leaf_index == [3]

0 commit comments

Comments
 (0)