@@ -44,6 +44,7 @@ TEST(LeafConstantModel, FullEnumeration) {
44
44
std::vector<double > cutpoint_values;
45
45
std::vector<StochTree::FeatureType> cutpoint_feature_types;
46
46
StochTree::data_size_t valid_cutpoint_count = 0 ;
47
+ std::vector<StochTree::data_size_t > feature_cutpoint_counts;
47
48
StochTree::CutpointGridContainer cutpoint_grid_container (dataset.GetCovariates (), residual.GetData (), cutpoint_grid_size);
48
49
49
50
// Initialize a leaf model
@@ -52,7 +53,7 @@ TEST(LeafConstantModel, FullEnumeration) {
52
53
// Evaluate all possible cutpoints
53
54
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
54
55
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0 , 0 , log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
55
- cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0 , n, variable_weights, feature_types
56
+ cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0 , n, variable_weights, feature_types
56
57
);
57
58
58
59
// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
@@ -109,7 +110,7 @@ TEST(LeafConstantModel, CutpointThinning) {
109
110
StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel (tau);
110
111
111
112
// Evaluate all possible cutpoints
112
- StochTree::EvaluateAllPossibleSplits <StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
113
+ StochTree::<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
113
114
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0 , 0 , log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
114
115
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0 , n, variable_weights, feature_types
115
116
);
@@ -162,6 +163,7 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) {
162
163
std::vector<double > cutpoint_values;
163
164
std::vector<StochTree::FeatureType> cutpoint_feature_types;
164
165
StochTree::data_size_t valid_cutpoint_count = 0 ;
166
+ std::vector<StochTree::data_size_t > feature_cutpoint_counts;
165
167
StochTree::CutpointGridContainer cutpoint_grid_container (dataset.GetCovariates (), residual.GetData (), cutpoint_grid_size);
166
168
167
169
// Initialize a leaf model
@@ -170,7 +172,7 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) {
170
172
// Evaluate all possible cutpoints
171
173
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(
172
174
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0 , 0 , log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
173
- cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0 , n, variable_weights, feature_types
175
+ cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0 , n, variable_weights, feature_types
174
176
);
175
177
176
178
// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
@@ -222,6 +224,7 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) {
222
224
std::vector<double > cutpoint_values;
223
225
std::vector<StochTree::FeatureType> cutpoint_feature_types;
224
226
StochTree::data_size_t valid_cutpoint_count = 0 ;
227
+ std::vector<StochTree::data_size_t > feature_cutpoint_counts;
225
228
StochTree::CutpointGridContainer cutpoint_grid_container (dataset.GetCovariates (), residual.GetData (), cutpoint_grid_size);
226
229
227
230
// Initialize a leaf model
@@ -230,7 +233,7 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) {
230
233
// Evaluate all possible cutpoints
231
234
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(
232
235
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0 , 0 , log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
233
- cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0 , n, variable_weights, feature_types
236
+ cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0 , n, variable_weights, feature_types
234
237
);
235
238
236
239
0 commit comments