Skip to content

Commit c2f6538

Browse files
authored
Merge pull request #99 from StochasticTree/kernel_updates
Simplifying and streamlining the forest kernel interface
2 parents fc62e0c + e4004ed commit c2f6538

22 files changed

+401
-961
lines changed

DESCRIPTION

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,17 @@ RoxygenNote: 7.3.1
1717
LinkingTo:
1818
cpp11
1919
Suggests:
20+
doParallel,
21+
foreach,
22+
ggplot2,
2023
knitr,
21-
rmarkdown,
24+
latex2exp,
2225
Matrix,
23-
tgp,
2426
MASS,
2527
mvtnorm,
26-
ggplot2,
27-
latex2exp,
28+
rmarkdown,
2829
testthat (>= 3.0.0),
29-
foreach,
30-
doParallel
30+
tgp
3131
VignetteBuilder: knitr
3232
SystemRequirements: C++17
3333
Imports:

NAMESPACE

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ S3method(predict,bcf)
77
export(bart)
88
export(bcf)
99
export(calibrate_inverse_gamma_error_variance)
10-
export(computeForestKernels)
1110
export(computeForestLeafIndices)
11+
export(computeMaxLeafIndex)
1212
export(convertBARTModelToJson)
1313
export(convertBCFModelToJson)
1414
export(createBARTModelFromCombinedJson)
@@ -26,7 +26,6 @@ export(createForestContainer)
2626
export(createForestCovariates)
2727
export(createForestCovariatesFromMetadata)
2828
export(createForestDataset)
29-
export(createForestKernel)
3029
export(createForestModel)
3130
export(createOutcome)
3231
export(createRNG)

R/cpp11.R

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -340,32 +340,12 @@ predict_forest_raw_single_forest_cpp <- function(forest_samples, dataset, forest
340340
.Call(`_stochtree_predict_forest_raw_single_forest_cpp`, forest_samples, dataset, forest_num)
341341
}
342342

343-
forest_kernel_cpp <- function() {
344-
.Call(`_stochtree_forest_kernel_cpp`)
343+
forest_container_get_max_leaf_index_cpp <- function(forest_container, forest_num) {
344+
.Call(`_stochtree_forest_container_get_max_leaf_index_cpp`, forest_container, forest_num)
345345
}
346346

347-
forest_kernel_compute_leaf_indices_train_cpp <- function(forest_kernel, covariates_train, forest_container, forest_num) {
348-
invisible(.Call(`_stochtree_forest_kernel_compute_leaf_indices_train_cpp`, forest_kernel, covariates_train, forest_container, forest_num))
349-
}
350-
351-
forest_kernel_compute_leaf_indices_train_test_cpp <- function(forest_kernel, covariates_train, covariates_test, forest_container, forest_num) {
352-
invisible(.Call(`_stochtree_forest_kernel_compute_leaf_indices_train_test_cpp`, forest_kernel, covariates_train, covariates_test, forest_container, forest_num))
353-
}
354-
355-
forest_kernel_get_train_leaf_indices_cpp <- function(forest_kernel) {
356-
.Call(`_stochtree_forest_kernel_get_train_leaf_indices_cpp`, forest_kernel)
357-
}
358-
359-
forest_kernel_get_test_leaf_indices_cpp <- function(forest_kernel) {
360-
.Call(`_stochtree_forest_kernel_get_test_leaf_indices_cpp`, forest_kernel)
361-
}
362-
363-
forest_kernel_compute_kernel_train_cpp <- function(forest_kernel, covariates_train, forest_container, forest_num) {
364-
.Call(`_stochtree_forest_kernel_compute_kernel_train_cpp`, forest_kernel, covariates_train, forest_container, forest_num)
365-
}
366-
367-
forest_kernel_compute_kernel_train_test_cpp <- function(forest_kernel, covariates_train, covariates_test, forest_container, forest_num) {
368-
.Call(`_stochtree_forest_kernel_compute_kernel_train_test_cpp`, forest_kernel, covariates_train, covariates_test, forest_container, forest_num)
347+
compute_leaf_indices_cpp <- function(forest_container, covariates, forest_nums) {
348+
.Call(`_stochtree_compute_leaf_indices_cpp`, forest_container, covariates, forest_nums)
369349
}
370350

371351
sample_gfr_one_iteration_cpp <- function(data, residual, forest_samples, tracker, split_prior, rng, feature_types, cutpoint_grid_size, leaf_model_scale_input, variable_weights, a_forest, b_forest, global_variance, leaf_model_int, pre_initialized) {

R/kernel.R

Lines changed: 141 additions & 229 deletions
Large diffs are not rendered by default.

_pkgdown.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,13 @@ reference:
8181
- createForestModel
8282
- ForestSamples
8383
- createForestContainer
84-
- ForestKernel
85-
- createForestKernel
8684
- CppRNG
8785
- createRNG
8886
- calibrate_inverse_gamma_error_variance
8987
- preprocessBartParams
9088
- preprocessBcfParams
89+
- computeForestLeafIndices
90+
- computeMaxLeafIndex
9191

9292
- subtitle: Random Effects
9393
desc: >
@@ -105,8 +105,6 @@ reference:
105105
- sample_sigma2_one_iteration
106106
- sample_tau_one_iteration
107107
- sample_tau_one_iteration
108-
- computeForestKernels
109-
- computeForestLeafIndices
110108

111109
- title: Package info
112110
desc: >

include/stochtree/container.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ class ForestContainer {
3636
void PredictInPlace(ForestDataset& dataset, std::vector<double>& output);
3737
void PredictRawInPlace(ForestDataset& dataset, std::vector<double>& output);
3838
void PredictRawInPlace(ForestDataset& dataset, int forest_num, std::vector<double>& output);
39+
void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
40+
Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
41+
std::vector<int>& forest_indices, int num_trees, data_size_t n);
3942

4043
inline TreeEnsemble* GetEnsemble(int i) {return forests_[i].get();}
4144
inline int32_t NumSamples() {return num_samples_;}

include/stochtree/ensemble.h

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,20 @@ class TreeEnsemble {
241241
}
242242
}
243243

244+
/*!
245+
* \brief Obtain an exclusive upper bound on the 0-based leaf indices of an ensemble (the largest index plus one), which is equivalent to the sum of the
246+
* number of leaves in each tree. This is used in conjunction with `PredictLeafIndicesInplace`,
247+
* which returns an observation-specific leaf index for every observation-tree pair.
248+
*/
249+
int GetMaxLeafIndex() {
250+
int max_leaf = 0;
251+
for (int j = 0; j < num_trees_; j++) {
252+
auto &tree = *trees_[j];
253+
max_leaf += tree.NumLeaves();
254+
}
255+
return max_leaf;
256+
}
257+
244258
/*!
245259
* \brief Obtain a 0-based leaf index for every tree in an ensemble and for each
246260
* observation in a ForestDataset. Internally, trees are stored as essentially
@@ -274,7 +288,7 @@ class TreeEnsemble {
274288
*
275289
* Note: this assumes the creation of a vector of column indices of size
276290
* `dataset.NumObservations()` x `ensemble.NumTrees()`
277-
* \param ForestDataset Dataset with which to predict leaf indices from the tree
291+
* \param covariates Matrix of covariates
278292
* \param output Vector of length num_trees*n which stores the leaf node prediction
279293
* \param num_trees Number of trees in an ensemble
280294
* \param n Size of dataset
@@ -292,6 +306,39 @@ class TreeEnsemble {
292306
}
293307
}
294308

309+
/*!
310+
* \brief Obtain a 0-based leaf index for every tree in an ensemble and for each
311+
* observation in a matrix of covariates. Internally, trees are stored as essentially
312+
* vectors of node information, and the leaves_ vector gives us node IDs for every
313+
* leaf in the tree. Here, we would like to know, for every observation in a dataset,
314+
* which leaf number it is mapped to. Since the leaf numbers themselves
315+
* do not carry any information, we renumber them from 0 to `leaves_.size()-1`.
316+
* We compute this at the tree-level and coordinate this computation at the
317+
* ensemble level.
318+
*
319+
* Note: this assumes the creation of a matrix of column indices with `num_trees*n` rows
320+
* and as many columns as forests that were requested from R / Python
321+
* \param covariates Matrix of covariates
322+
* \param output Matrix with num_trees*n rows and as many columns as forests that were requested from R / Python
323+
* \param column_ind Index of column in `output` into which the result should be unpacked
324+
* \param num_trees Number of trees in an ensemble
325+
* \param n Size of dataset
326+
*/
327+
void PredictLeafIndicesInplace(Eigen::Map<Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& covariates,
328+
Eigen::Map<Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic, Eigen::ColMajor>>& output,
329+
int column_ind, int num_trees, data_size_t n) {
330+
CHECK_GE(output.size(), num_trees*n);
331+
int offset = 0;
332+
int max_leaf = 0;
333+
for (int j = 0; j < num_trees; j++) {
334+
auto &tree = *trees_[j];
335+
int num_leaves = tree.NumLeaves();
336+
tree.PredictLeafIndexInplace(covariates, output, column_ind, offset, max_leaf);
337+
offset += n;
338+
max_leaf += num_leaves;
339+
}
340+
}
341+
295342
/*!
296343
* \brief Obtain a 0-based leaf index for every tree in an ensemble and for each
297344
* observation in a ForestDataset. Internally, trees are stored as essentially

0 commit comments

Comments
 (0)