Skip to content

Commit eb32b11

Browse files
committed
Added kernel indices module
1 parent ae13cb0 commit eb32b11

File tree

7 files changed

+272
-18
lines changed

7 files changed

+272
-18
lines changed

src/py_stochtree.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ class ForestContainerCpp {
177177
return forest_samples_->OutputDimension();
178178
}
179179

180+
int NumTrees() {
181+
return num_trees_;
182+
}
183+
180184
int NumSamples() {
181185
return forest_samples_->NumSamples();
182186
}
@@ -660,6 +664,10 @@ class ForestCpp {
660664
return forest_->OutputDimension();
661665
}
662666

667+
int NumTrees() {
668+
return num_trees_;
669+
}
670+
663671
int NumLeavesForest() {
664672
return forest_->NumLeaves();
665673
}
@@ -1825,6 +1833,37 @@ class JsonCpp {
18251833
std::unique_ptr<nlohmann::json> json_;
18261834
};
18271835

1836+
py::array_t<int> cppComputeForestContainerLeafIndices(ForestContainerCpp& forest_container, py::array_t<double>& covariates, py::array_t<int>& forest_nums) {
1837+
// Wrap an Eigen Map around the raw data of the covariate matrix
1838+
StochTree::data_size_t num_obs = covariates.shape(0);
1839+
int num_covariates = covariates.shape(1);
1840+
double* covariate_data_ptr = static_cast<double*>(covariates.mutable_data());
1841+
Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>> covariates_eigen(covariate_data_ptr, num_obs, num_covariates);
1842+
1843+
// Extract other output dimensions
1844+
int num_trees = forest_container.NumTrees();
1845+
int num_samples = forest_nums.size();
1846+
1847+
// Convert forest_nums to std::vector
1848+
std::vector<int> forest_indices(num_samples);
1849+
for (int i = 0; i < num_samples; i++) {
1850+
forest_indices[i] = forest_nums.at(i);
1851+
}
1852+
1853+
// Compute leaf indices
1854+
auto result = py::array_t<int, py::array::f_style>(py::detail::any_container<py::ssize_t>({num_obs*num_trees, num_samples}));
1855+
int* output_data_ptr = static_cast<int*>(result.mutable_data());
1856+
Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>> output_eigen(output_data_ptr, num_obs*num_trees, num_samples);
1857+
forest_container.GetContainer()->PredictLeafIndicesInplace(covariates_eigen, output_eigen, forest_indices, num_trees, num_obs);
1858+
1859+
// Return matrix
1860+
return result;
1861+
}
1862+
1863+
int cppComputeForestMaxLeafIndex(ForestContainerCpp& forest_container, int forest_num) {
1864+
return forest_container.GetForest(forest_num)->GetMaxLeafIndex();
1865+
}
1866+
18281867
void ForestContainerCpp::LoadFromJson(JsonCpp& json, std::string forest_label) {
18291868
nlohmann::json forest_json = json.SubsetJsonForest(forest_label);
18301869
forest_samples_->Reset();
@@ -1891,6 +1930,9 @@ void RandomEffectsModelCpp::SampleRandomEffects(RandomEffectsDatasetCpp& rfx_dat
18911930
}
18921931

18931932
PYBIND11_MODULE(stochtree_cpp, m) {
1933+
m.def("cppComputeForestContainerLeafIndices", &cppComputeForestContainerLeafIndices, "Compute leaf indices of the forests in a forest container");
1934+
m.def("cppComputeForestMaxLeafIndex", &cppComputeForestMaxLeafIndex, "Compute max leaf index of a forest in a forest container");
1935+
18941936
py::class_<JsonCpp>(m, "JsonCpp")
18951937
.def(py::init<>())
18961938
.def("LoadFile", &JsonCpp::LoadFile)
@@ -1958,6 +2000,7 @@ PYBIND11_MODULE(stochtree_cpp, m) {
19582000
py::class_<ForestContainerCpp>(m, "ForestContainerCpp")
19592001
.def(py::init<int,int,bool,bool>())
19602002
.def("OutputDimension", &ForestContainerCpp::OutputDimension)
2003+
.def("NumTrees", &ForestContainerCpp::NumTrees)
19612004
.def("NumSamples", &ForestContainerCpp::NumSamples)
19622005
.def("DeleteSample", &ForestContainerCpp::DeleteSample)
19632006
.def("Predict", &ForestContainerCpp::Predict)
@@ -2003,6 +2046,7 @@ PYBIND11_MODULE(stochtree_cpp, m) {
20032046
py::class_<ForestCpp>(m, "ForestCpp")
20042047
.def(py::init<int,int,bool,bool>())
20052048
.def("OutputDimension", &ForestCpp::OutputDimension)
2049+
.def("NumTrees", &ForestCpp::NumTrees)
20062050
.def("NumLeavesForest", &ForestCpp::NumLeavesForest)
20072051
.def("SumLeafSquared", &ForestCpp::SumLeafSquared)
20082052
.def("ResetRoot", &ForestCpp::ResetRoot)

stochtree/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .config import ForestModelConfig, GlobalModelConfig
55
from .data import Dataset, Residual
66
from .forest import Forest, ForestContainer
7+
from .kernel import compute_forest_leaf_indices
78
from .preprocessing import CovariatePreprocessor
89
from .random_effects import (
910
RandomEffectsContainer,
@@ -56,5 +57,6 @@
5657
"_check_matrix_square",
5758
"_standardize_array_to_list",
5859
"_standardize_array_to_np",
60+
"compute_forest_leaf_indices",
5961
"calibrate_global_error_variance",
6062
]

stochtree/bart.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,8 +1242,6 @@ def predict(
12421242
)
12431243
covariates_processed = covariates
12441244
else:
1245-
self._covariate_preprocessor = CovariatePreprocessor()
1246-
self._covariate_preprocessor.fit(covariates)
12471245
covariates_processed = self._covariate_preprocessor.transform(covariates)
12481246

12491247
# Dataset construction
@@ -1364,8 +1362,6 @@ def predict_mean(
13641362
)
13651363
covariates_processed = covariates
13661364
else:
1367-
self._covariate_preprocessor = CovariatePreprocessor()
1368-
self._covariate_preprocessor.fit(covariates)
13691365
covariates_processed = self._covariate_preprocessor.transform(covariates)
13701366

13711367
# Dataset construction
@@ -1448,8 +1444,6 @@ def predict_variance(self, covariates: np.array) -> np.array:
14481444
)
14491445
covariates_processed = covariates
14501446
else:
1451-
self._covariate_preprocessor = CovariatePreprocessor()
1452-
self._covariate_preprocessor.fit(covariates)
14531447
covariates_processed = self._covariate_preprocessor.transform(covariates)
14541448

14551449
# Dataset construction

stochtree/bcf.py

Lines changed: 72 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
Bayesian Causal Forests (BCF) module
33
"""
44

5+
import warnings
56
from typing import Any, Dict, Optional, Union
67

78
import numpy as np
@@ -1958,11 +1959,32 @@ def predict_tau(
19581959
propensity = np.ones(X.shape[0])
19591960
propensity = np.expand_dims(propensity, 1)
19601961

1962+
# Covariate preprocessing
1963+
if not self._covariate_preprocessor._check_is_fitted():
1964+
if not isinstance(X, np.ndarray):
1965+
raise ValueError(
1966+
"Prediction cannot proceed on a pandas dataframe, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe."
1967+
)
1968+
else:
1969+
warnings.warn(
1970+
"This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.",
1971+
RuntimeWarning,
1972+
)
1973+
if not np.issubdtype(
1974+
X.dtype, np.floating
1975+
) and not np.issubdtype(X.dtype, np.integer):
1976+
raise ValueError(
1977+
"Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe."
1978+
)
1979+
covariates_processed = X
1980+
else:
1981+
covariates_processed = self._covariate_preprocessor.transform(X)
1982+
19611983
# Update covariates to include propensities if requested
19621984
if self.propensity_covariate == "none":
1963-
X_combined = X
1985+
X_combined = covariates_processed
19641986
else:
1965-
X_combined = np.c_[X, propensity]
1987+
X_combined = np.c_[covariates_processed, propensity]
19661988

19671989
# Forest dataset
19681990
forest_dataset_test = Dataset()
@@ -2022,17 +2044,38 @@ def predict_variance(
20222044
if propensity.ndim == 1:
20232045
propensity = np.expand_dims(propensity, 1)
20242046

2047+
# Covariate preprocessing
2048+
if not self._covariate_preprocessor._check_is_fitted():
2049+
if not isinstance(covariates, np.ndarray):
2050+
raise ValueError(
2051+
"Prediction cannot proceed on a pandas dataframe, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe."
2052+
)
2053+
else:
2054+
warnings.warn(
2055+
"This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.",
2056+
RuntimeWarning,
2057+
)
2058+
if not np.issubdtype(
2059+
covariates.dtype, np.floating
2060+
) and not np.issubdtype(covariates.dtype, np.integer):
2061+
raise ValueError(
2062+
"Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe."
2063+
)
2064+
covariates_processed = covariates
2065+
else:
2066+
covariates_processed = self._covariate_preprocessor.transform(covariates)
2067+
20252068
# Update covariates to include propensities if requested
20262069
if self.propensity_covariate == "none":
2027-
X_combined = covariates
2070+
X_combined = covariates_processed
20282071
else:
20292072
if propensity is not None:
2030-
X_combined = np.c_[covariates, propensity]
2073+
X_combined = np.c_[covariates_processed, propensity]
20312074
else:
20322075
# Dummy propensities if not provided but also not needed
2033-
propensity = np.ones(covariates.shape[0])
2076+
propensity = np.ones(covariates_processed.shape[0])
20342077
propensity = np.expand_dims(propensity, 1)
2035-
X_combined = np.c_[covariates, propensity]
2078+
X_combined = np.c_[covariates_processed, propensity]
20362079

20372080
# Forest dataset
20382081
pred_dataset = Dataset()
@@ -2124,12 +2167,33 @@ def predict(self, X: np.array, Z: np.array, propensity: np.array = None, rfx_gro
21242167
propensity = np.mean(
21252168
self.bart_propensity_model.predict(X), axis=1, keepdims=True
21262169
)
2170+
2171+
# Covariate preprocessing
2172+
if not self._covariate_preprocessor._check_is_fitted():
2173+
if not isinstance(X, np.ndarray):
2174+
raise ValueError(
2175+
"Prediction cannot proceed on a pandas dataframe, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing covariate data as a Pandas dataframe."
2176+
)
2177+
else:
2178+
warnings.warn(
2179+
"This BCF model has not run any covariate preprocessing routines. We will attempt to predict on the raw covariate values, but this will trigger an error with non-numeric columns. Please refit your model by passing non-numeric covariate data a a Pandas dataframe.",
2180+
RuntimeWarning,
2181+
)
2182+
if not np.issubdtype(
2183+
X.dtype, np.floating
2184+
) and not np.issubdtype(X.dtype, np.integer):
2185+
raise ValueError(
2186+
"Prediction cannot proceed on a non-numeric numpy array, since the BCF model was not fit with a covariate preprocessor. Please refit your model by passing non-numeric covariate data as a Pandas dataframe."
2187+
)
2188+
covariates_processed = X
2189+
else:
2190+
covariates_processed = self._covariate_preprocessor.transform(X)
21272191

21282192
# Update covariates to include propensities if requested
21292193
if self.propensity_covariate == "none":
2130-
X_combined = X
2194+
X_combined = covariates_processed
21312195
else:
2132-
X_combined = np.c_[X, propensity]
2196+
X_combined = np.c_[covariates_processed, propensity]
21332197

21342198
# Forest dataset
21352199
forest_dataset_test = Dataset()

stochtree/forest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,17 @@ def node_leaf_values(
662662
"""
663663
return self.forest_container_cpp.NodeLeafValues(forest_num, tree_num, node_id)
664664

665+
def num_samples(self) -> int:
666+
"""
667+
Number of forest samples in the ``ForestContainer``.
668+
669+
Returns
670+
-------
671+
int
672+
Total number of forest samples.
673+
"""
674+
return self.forest_container_cpp.NumSamples()
675+
665676
def num_nodes(self, forest_num: int, tree_num: int) -> int:
666677
"""
667678
Number of nodes in a given tree in a given forest in the ``ForestContainer``.

stochtree/kernel.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22

33
import pandas as pd
44
import numpy as np
5-
from stochtree import BARTModel, BCFModel, ForestContainer
5+
from stochtree_cpp import cppComputeForestContainerLeafIndices, cppComputeForestMaxLeafIndex
66

7-
from .data import Residual
8-
from .sampler import RNG
7+
from .bart import BARTModel
8+
from .bcf import BCFModel
9+
from .forest import ForestContainer
910

1011

1112
def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestContainer], covariates: Union[np.array, pd.DataFrame], forest_type: str = None, forest_inds: Union[int, np.ndarray] = None):
@@ -44,4 +45,56 @@ def compute_forest_leaf_indices(model_object: Union[BARTModel, BCFModel, ForestC
4445
-------
4546
Numpy array with dimensions `num_obs` by `num_trees`, where `num_obs` is the number of rows in `covaritates` and `num_trees` is the number of trees in the relevant forest of `model_object`.
4647
"""
47-
pass
48+
# Extract relevant forest container
49+
if not isinstance(model_object, BARTModel) and not isinstance(model_object, BCFModel) and not isinstance(model_object, ForestContainer):
50+
raise ValueError("model_object must be one of BARTModel, BCFModel, or ForestContainer")
51+
if isinstance(model_object, BARTModel):
52+
model_type = "bart"
53+
if forest_type is None:
54+
raise ValueError("forest_type must be specified for a BARTModel model_type (either set to 'mean' or 'variance')")
55+
elif isinstance(model_object, BCFModel):
56+
model_type = "bcf"
57+
if forest_type is None:
58+
raise ValueError("forest_type must be specified for a BCFModel model_type (either set to 'prognostic', 'treatment' or 'variance')")
59+
else:
60+
model_type = "forest"
61+
if model_type == "bart":
62+
if forest_type == "mean":
63+
if not model_object.include_mean_forest:
64+
raise ValueError("Mean forest was not sampled for model_object, but requested by forest_type")
65+
forest_container = model_object.forest_container_mean
66+
else:
67+
if not model_object.include_variance_forest:
68+
raise ValueError("Variance forest was not sampled for model_object, but requested by forest_type")
69+
forest_container = model_object.forest_container_variance
70+
elif model_type == "bcf":
71+
if forest_type == "prognostic":
72+
forest_container = model_object.forest_container_mu
73+
elif forest_type == "treatment":
74+
forest_container = model_object.forest_container_tau
75+
else:
76+
if not model_object.include_variance_forest:
77+
raise ValueError("Variance forest was not sampled for model_object, but requested by forest_type")
78+
forest_container = model_object.forest_container_variance
79+
else:
80+
forest_container = model_object
81+
82+
if not isinstance(covariates, pd.DataFrame) and not isinstance(covariates, np.ndarray):
83+
raise ValueError("covariates must be a matrix or dataframe")
84+
85+
# Preprocess covariates
86+
if model_type == "bart" or model_type == "bcf":
87+
covariates_processed = model_object._covariate_preprocessor.transform(covariates)
88+
else:
89+
covariates_processed = covariates
90+
covariates_processed = np.asfortranarray(covariates_processed)
91+
92+
# Preprocess forest indices
93+
num_forests = forest_container.num_samples()
94+
if forest_inds is None:
95+
forest_inds = np.arange(num_forests)
96+
else:
97+
if not np.all(forest_inds >= 0) or not np.all(forest_inds < num_forests):
98+
raise ValueError("The indices in forest_inds must be >= 0 and < the total number of samples in a forest container")
99+
100+
return cppComputeForestContainerLeafIndices(forest_container.forest_container_cpp, covariates_processed, forest_inds)

0 commit comments

Comments
 (0)