@@ -471,8 +471,8 @@ template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatCon
471
471
static inline void EvaluateAllPossibleSplits (
472
472
ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, LeafModel& leaf_model, double global_variance, int tree_num, int split_node_id,
473
473
std::vector<double >& log_cutpoint_evaluations, std::vector<int >& cutpoint_features, std::vector<double >& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
474
- data_size_t & valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector< double >& variable_weights ,
475
- std::vector<FeatureType>& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args
474
+ data_size_t & valid_cutpoint_count, std::vector< data_size_t >& feature_cutpoint_counts, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end,
475
+ std::vector<double >& variable_weights, std::vector< FeatureType>& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args
476
476
) {
477
477
// Initialize sufficient statistics
478
478
LeafSuffStat node_suff_stat = LeafSuffStat (leaf_suff_stat_args...);
@@ -496,6 +496,7 @@ static inline void EvaluateAllPossibleSplits(
496
496
int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf ();
497
497
498
498
// Compute sufficient statistics for each possible split
499
+ data_size_t feature_cutpoints;
499
500
data_size_t num_cutpoints = 0 ;
500
501
bool valid_split = false ;
501
502
data_size_t node_row_iter;
@@ -509,6 +510,8 @@ static inline void EvaluateAllPossibleSplits(
509
510
double log_split_eval = 0.0 ;
510
511
double split_log_ml;
511
512
for (int j = 0 ; j < covariates.cols (); j++) {
513
+ // Reset feature cutpoint counter
514
+ feature_cutpoints = 0 ;
512
515
513
516
if (std::abs (variable_weights.at (j)) > kEpsilon ) {
514
517
// Enumerate cutpoint strides
@@ -542,6 +545,7 @@ static inline void EvaluateAllPossibleSplits(
542
545
valid_split = (left_suff_stat.SampleGreaterThanEqual (min_samples_in_leaf) &&
543
546
right_suff_stat.SampleGreaterThanEqual (min_samples_in_leaf));
544
547
if (valid_split) {
548
+ feature_cutpoints++;
545
549
num_cutpoints++;
546
550
// Add to split rule vector
547
551
cutpoint_feature_types.push_back (feature_type);
@@ -553,7 +557,8 @@ static inline void EvaluateAllPossibleSplits(
553
557
}
554
558
}
555
559
}
556
-
560
+ // Add feature_cutpoints to feature_cutpoint_counts
561
+ feature_cutpoint_counts.push_back (feature_cutpoints);
557
562
}
558
563
559
564
// Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper)
@@ -570,16 +575,38 @@ template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatCon
570
575
static inline void EvaluateCutpoints (Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior,
571
576
std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end,
572
577
std::vector<double >& log_cutpoint_evaluations, std::vector<int >& cutpoint_features, std::vector<double >& cutpoint_values,
573
- std::vector<FeatureType>& cutpoint_feature_types, data_size_t & valid_cutpoint_count, std::vector<double >& variable_weights,
574
- std::vector<FeatureType>& feature_types, CutpointGridContainer& cutpoint_grid_container, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
578
+ std::vector<FeatureType>& cutpoint_feature_types, data_size_t & valid_cutpoint_count, std::vector<StochTree::data_size_t >& feature_cutpoint_counts,
579
+ std::vector<double >& variable_weights, std::vector<FeatureType>& feature_types, CutpointGridContainer& cutpoint_grid_container,
580
+ LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
575
581
// Evaluate all possible cutpoints according to the leaf node model,
576
582
// recording their log-likelihood and other split information in a series of vectors.
577
583
// The last element of these vectors concerns the "no-split" option.
578
584
EvaluateAllPossibleSplits<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
579
585
dataset, tracker, residual, tree_prior, leaf_model, global_variance, tree_num, node_id, log_cutpoint_evaluations,
580
- cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container,
586
+ cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container,
581
587
node_begin, node_end, variable_weights, feature_types, leaf_suff_stat_args...
582
588
);
589
+
590
+ // Compute weighting adjustments for low-cardinality categorical features
591
+ // Check if the dataset has continuous features, ignore this adjustment if not
592
+ bool has_continuous_features = false ;
593
+ int max_feature_cutpoint_count = 0 ;
594
+ for (int j = 0 ; j < feature_types.size (); j++) {
595
+ if (feature_types.at (j) == FeatureType::kNumeric ) {
596
+ has_continuous_features = true ;
597
+ if (feature_cutpoint_counts[j] > max_feature_cutpoint_count) max_feature_cutpoint_count = feature_cutpoint_counts[j];
598
+ }
599
+ }
600
+ if (has_continuous_features) {
601
+ double feature_weight;
602
+ for (data_size_t i = 0 ; i < valid_cutpoint_count; i++) {
603
+ // Determine whether the feature is categorical (and thus needs to be re-weighted)
604
+ if ((cutpoint_feature_types[i] == FeatureType::kOrderedCategorical ) || (cutpoint_feature_types[i] == FeatureType::kUnorderedCategorical )) {
605
+ feature_weight = ((double ) max_feature_cutpoint_count) / ((double ) feature_cutpoint_counts[cutpoint_features[i]]);
606
+ log_cutpoint_evaluations[i] += std::log (feature_weight);
607
+ }
608
+ }
609
+ }
583
610
584
611
// Compute an adjustment to reflect the no split prior probability and the number of cutpoints
585
612
double bart_prior_no_split_adj;
@@ -614,12 +641,13 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel
614
641
std::vector<double > cutpoint_values;
615
642
std::vector<FeatureType> cutpoint_feature_types;
616
643
StochTree::data_size_t valid_cutpoint_count;
644
+ std::vector<StochTree::data_size_t > feature_cutpoint_counts;
617
645
CutpointGridContainer cutpoint_grid_container (dataset.GetCovariates (), residual.GetData (), cutpoint_grid_size);
618
646
EvaluateCutpoints<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
619
647
tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance,
620
648
cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features,
621
- cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types ,
622
- cutpoint_grid_container, leaf_suff_stat_args...
649
+ cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, variable_weights ,
650
+ feature_types, cutpoint_grid_container, leaf_suff_stat_args...
623
651
);
624
652
// TODO: maybe add some checks here?
625
653
0 commit comments