Skip to content

Commit 1ea27d7

Browse files
committed
Adjusted C++ unit tests to reflect updated API
1 parent ddf9ea6 commit 1ea27d7

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

test/cpp/test_model.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ TEST(LeafConstantModel, FullEnumeration) {
4444
std::vector<double> cutpoint_values;
4545
std::vector<StochTree::FeatureType> cutpoint_feature_types;
4646
StochTree::data_size_t valid_cutpoint_count = 0;
47+
std::vector<StochTree::data_size_t> feature_cutpoint_counts;
4748
StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size);
4849

4950
// Initialize a leaf model
@@ -52,7 +53,7 @@ TEST(LeafConstantModel, FullEnumeration) {
5253
// Evaluate all possible cutpoints
5354
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
5455
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
55-
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
56+
cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types
5657
);
5758

5859
// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
@@ -109,7 +110,7 @@ TEST(LeafConstantModel, CutpointThinning) {
109110
StochTree::GaussianConstantLeafModel leaf_model = StochTree::GaussianConstantLeafModel(tau);
110111

111112
// Evaluate all possible cutpoints
112-
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
113+
StochTree::<StochTree::GaussianConstantLeafModel, StochTree::GaussianConstantSuffStat>(
113114
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
114115
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
115116
);
@@ -162,6 +163,7 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) {
162163
std::vector<double> cutpoint_values;
163164
std::vector<StochTree::FeatureType> cutpoint_feature_types;
164165
StochTree::data_size_t valid_cutpoint_count = 0;
166+
std::vector<StochTree::data_size_t> feature_cutpoint_counts;
165167
StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size);
166168

167169
// Initialize a leaf model
@@ -170,7 +172,7 @@ TEST(LeafUnivariateRegressionModel, FullEnumeration) {
170172
// Evaluate all possible cutpoints
171173
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(
172174
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
173-
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
175+
cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types
174176
);
175177

176178
// Check that there are (n - 2*min_samples_leaf + 1)*p + 1 cutpoints considered
@@ -222,6 +224,7 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) {
222224
std::vector<double> cutpoint_values;
223225
std::vector<StochTree::FeatureType> cutpoint_feature_types;
224226
StochTree::data_size_t valid_cutpoint_count = 0;
227+
std::vector<StochTree::data_size_t> feature_cutpoint_counts;
225228
StochTree::CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size);
226229

227230
// Initialize a leaf model
@@ -230,7 +233,7 @@ TEST(LeafUnivariateRegressionModel, CutpointThinning) {
230233
// Evaluate all possible cutpoints
231234
StochTree::EvaluateAllPossibleSplits<StochTree::GaussianUnivariateRegressionLeafModel, StochTree::GaussianUnivariateRegressionSuffStat>(
232235
dataset, tracker, residual, tree_prior, leaf_model, global_variance, 0, 0, log_cutpoint_evaluations, cutpoint_features, cutpoint_values,
233-
cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container, 0, n, variable_weights, feature_types
236+
cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container, 0, n, variable_weights, feature_types
234237
);
235238

236239

0 commit comments

Comments
 (0)