Skip to content

Commit a28af8b

Browse files
authored
Merge pull request #105 from StochasticTree/single-tree-predict
Added ability to predict from a single tree for a given forest
2 parents 225e28f + 021a9b2 commit a28af8b

File tree

5 files changed

+63
-0
lines changed

5 files changed

+63
-0
lines changed

include/stochtree/container.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ class ForestContainer {
3333
std::vector<double> Predict(ForestDataset& dataset);
3434
std::vector<double> PredictRaw(ForestDataset& dataset);
3535
std::vector<double> PredictRaw(ForestDataset& dataset, int forest_num);
36+
/*! \brief Allocate a vector of NumObservations() * output_dimension_ entries and fill it with the raw predictions of tree `tree_num` in forest `forest_num` (delegates to PredictRawSingleTreeInPlace). */
std::vector<double> PredictRawSingleTree(ForestDataset& dataset, int forest_num, int tree_num);
3637
void PredictInPlace(ForestDataset& dataset, std::vector<double>& output);
3738
void PredictRawInPlace(ForestDataset& dataset, std::vector<double>& output);
3839
void PredictRawInPlace(ForestDataset& dataset, int forest_num, std::vector<double>& output);
40+
/*! \brief Write the raw predictions of tree `tree_num` in forest `forest_num` into `output`, which must already hold exactly NumObservations() * output_dimension_ entries (enforced with CHECK_EQ). */
void PredictRawSingleTreeInPlace(ForestDataset& dataset, int forest_num, int tree_num, std::vector<double>& output);
3941
void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
4042
Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
4143
std::vector<int>& forest_indices, int num_trees, data_size_t n);

src/container.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,14 @@ std::vector<double> ForestContainer::PredictRaw(ForestDataset& dataset, int fore
9090
return output;
9191
}
9292

93+
/*!
 * \brief Raw predictions from one tree of one forest, returned by value.
 *
 * \param dataset    Data to predict on; its observation count sizes the result.
 * \param forest_num Index of the forest within this container.
 * \param tree_num   Index of the tree within that forest.
 * \return Vector with NumObservations() * output_dimension_ entries, filled by
 *         PredictRawSingleTreeInPlace.
 */
std::vector<double> ForestContainer::PredictRawSingleTree(ForestDataset& dataset, int forest_num, int tree_num) {
  // Size the buffer up front: the in-place overload checks the length and never resizes.
  const data_size_t num_obs = dataset.NumObservations();
  const data_size_t flat_size = num_obs * output_dimension_;
  std::vector<double> predictions(flat_size);

  PredictRawSingleTreeInPlace(dataset, forest_num, tree_num, predictions);
  return predictions;
}
100+
93101
void ForestContainer::PredictInPlace(ForestDataset& dataset, std::vector<double>& output) {
94102
data_size_t n = dataset.NumObservations();
95103
data_size_t total_output_size = n*num_samples_;
@@ -123,6 +131,14 @@ void ForestContainer::PredictRawInPlace(ForestDataset& dataset, int forest_num,
123131
forests_[forest_num]->PredictRawInplace(dataset, output, 0, num_trees, offset);
124132
}
125133

134+
/*!
 * \brief Raw predictions from one tree of one forest, written into a caller-owned buffer.
 *
 * \param dataset    Data to predict on.
 * \param forest_num Index into forests_; not range-checked here.
 * \param tree_num   Index of the tree to evaluate; passed on as the range
 *                   [tree_num, tree_num + 1).
 * \param output     Destination buffer; must already contain exactly
 *                   NumObservations() * output_dimension_ elements.
 */
void ForestContainer::PredictRawSingleTreeInPlace(ForestDataset& dataset, int forest_num, int tree_num, std::vector<double>& output) {
  // The buffer length must match the prediction count exactly.
  const data_size_t num_obs = dataset.NumObservations();
  const data_size_t expected_size = num_obs * output_dimension_;
  CHECK_EQ(expected_size, output.size());

  // Only one forest is written, so results start at the beginning of the buffer.
  data_size_t write_offset = 0;
  const int tree_end = tree_num + 1;
  forests_[forest_num]->PredictRawInplace(dataset, output, tree_num, tree_end, write_offset);
}
141+
126142
void ForestContainer::PredictLeafIndicesInplace(
127143
Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
128144
Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,

src/forest.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,3 +422,21 @@ cpp11::writable::doubles_matrix<> predict_forest_raw_single_forest_cpp(cpp11::ex
422422

423423
return output;
424424
}
425+
426+
[[cpp11::register]]
427+
cpp11::writable::doubles_matrix<> predict_forest_raw_single_tree_cpp(cpp11::external_pointer<StochTree::ForestContainer> forest_samples, cpp11::external_pointer<StochTree::ForestDataset> dataset, int forest_num, int tree_num) {
428+
// Predict from the sampled forests
429+
std::vector<double> output_raw = forest_samples->PredictRawSingleTree(*dataset, forest_num, tree_num);
430+
431+
// Convert result to a matrix
432+
int n = dataset->GetCovariates().rows();
433+
int output_dimension = forest_samples->OutputDimension();
434+
cpp11::writable::doubles_matrix<> output(n, output_dimension);
435+
for (size_t i = 0; i < n; i++) {
436+
for (int j = 0; j < output_dimension; j++) {
437+
output(i, j) = output_raw[i*output_dimension + j];
438+
}
439+
}
440+
441+
return output;
442+
}

src/py_stochtree.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,29 @@ class ForestContainerCpp {
235235
return result;
236236
}
237237

238+
// Python-facing wrapper: raw predictions from a single tree (`tree_num`) of a
// single forest (`forest_num`), evaluated on `dataset`.
// Returns a 2-D array of shape (dataset.NumRows(), OutputDimension()).
py::array_t<double> PredictRawSingleTree(ForestDatasetCpp& dataset, int forest_num, int tree_num) {
  // Predict from the forest container
  data_size_t n = dataset.NumRows();
  int output_dim = this->OutputDimension();
  StochTree::ForestDataset* data_ptr = dataset.GetDataset();
  std::vector<double> output_raw = forest_samples_->PredictRawSingleTree(*data_ptr, forest_num, tree_num);

  // Convert result to a matrix
  auto result = py::array_t<double>(py::detail::any_container<py::ssize_t>({n, output_dim}));
  auto accessor = result.mutable_unchecked<2>();

  // Index with the same signed type as `n`; the original compared a size_t index against it.
  for (data_size_t i = 0; i < n; i++) {
    // Widen before multiplying so the flat offset is still computed in size_t.
    const size_t row_start = static_cast<size_t>(i) * output_dim;
    for (int j = 0; j < output_dim; j++) {
      // output_raw is read observation-major: entry (i, j) is at i * output_dim + j
      accessor(i, j) = output_raw[row_start + j];
    }
  }

  return result;
}
260+
238261
// Forwards `leaf_value` to forest_samples_->InitializeRoot.
// NOTE(review): `forest_num` is accepted but never used in this body — confirm
// whether the underlying call is meant to target a specific forest.
void SetRootValue(int forest_num, double leaf_value) {
239262
forest_samples_->InitializeRoot(leaf_value);
240263
}

stochtree/forest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ def predict_raw_single_forest(self, dataset: Dataset, forest_num: int) -> np.arr
2828
# Predict raw leaf values for a specific forest (indexed by forest_num) from Dataset
2929
return self.forest_container_cpp.PredictRawSingleForest(dataset.dataset_cpp, forest_num)
3030

31+
def predict_raw_single_tree(self, dataset: Dataset, forest_num: int, tree_num: int) -> np.ndarray:
    """
    Predict raw leaf values from a single tree of a single forest.

    Parameters
    ----------
    dataset : Dataset
        Data to predict on; its ``dataset_cpp`` handle is passed to the C++ container.
    forest_num : int
        Index of the forest within the container.
    tree_num : int
        Index of the tree within that forest.

    Returns
    -------
    np.ndarray
        Array returned by the C++ ``PredictRawSingleTree`` binding.
    """
    return self.forest_container_cpp.PredictRawSingleTree(dataset.dataset_cpp, forest_num, tree_num)
34+
3135
def set_root_leaves(self, forest_num: int, leaf_value: Union[float, np.array]) -> None:
3236
# Set the root leaf value(s) for a specific forest (indexed by forest_num)
3337
if not isinstance(leaf_value, np.ndarray) and not isinstance(leaf_value, float):

0 commit comments

Comments
 (0)