@@ -802,46 +802,55 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM
802
802
double no_split_log_marginal_likelihood = std::get<1 >(split_eval);
803
803
int32_t left_n = std::get<2 >(split_eval);
804
804
int32_t right_n = std::get<3 >(split_eval);
805
-
806
- // Determine probability of growing the split node and its two new left and right nodes
807
- double pg = tree_prior.GetAlpha () * std::pow (1 +leaf_depth, -tree_prior.GetBeta ());
808
- double pgl = tree_prior.GetAlpha () * std::pow (1 +leaf_depth+1 , -tree_prior.GetBeta ());
809
- double pgr = tree_prior.GetAlpha () * std::pow (1 +leaf_depth+1 , -tree_prior.GetBeta ());
810
-
811
- // Determine whether a "grow" move is possible from the newly formed tree
812
- // in order to compute the probability of choosing "prune" from the new tree
813
- // (which is always possible by construction)
814
- bool non_constant = NodesNonConstantAfterSplit (dataset, tracker, split, tree_num, leaf_chosen, var_chosen);
815
- bool min_samples_left_check = left_n >= 2 *tree_prior.GetMinSamplesLeaf ();
816
- bool min_samples_right_check = right_n >= 2 *tree_prior.GetMinSamplesLeaf ();
817
- double prob_prune_new;
818
- if (non_constant && (min_samples_left_check || min_samples_right_check)) {
819
- prob_prune_new = 0.5 ;
820
- } else {
821
- prob_prune_new = 1.0 ;
822
- }
823
805
824
- // Determine the number of leaves in the current tree and leaf parents in the proposed tree
825
- int num_leaf_parents = tree->NumLeafParents ();
826
- double p_leaf = 1 /static_cast <double >(num_leaves);
827
- double p_leaf_parent = 1 /static_cast <double >(num_leaf_parents+1 );
806
+ // Reject the split if either the left or the right node is smaller than tree_prior.GetMinSamplesLeaf()
807
+ bool left_node_sample_cutoff = left_n >= tree_prior.GetMinSamplesLeaf ();
808
+ bool right_node_sample_cutoff = right_n >= tree_prior.GetMinSamplesLeaf ();
809
+ if ((left_node_sample_cutoff) && (right_node_sample_cutoff)) {
810
+
811
+ // Determine probability of growing the split node and its two new left and right nodes
812
+ double pg = tree_prior.GetAlpha () * std::pow (1 +leaf_depth, -tree_prior.GetBeta ());
813
+ double pgl = tree_prior.GetAlpha () * std::pow (1 +leaf_depth+1 , -tree_prior.GetBeta ());
814
+ double pgr = tree_prior.GetAlpha () * std::pow (1 +leaf_depth+1 , -tree_prior.GetBeta ());
815
+
816
+ // Determine whether a "grow" move is possible from the newly formed tree
817
+ // in order to compute the probability of choosing "prune" from the new tree
818
+ // (which is always possible by construction)
819
+ bool non_constant = NodesNonConstantAfterSplit (dataset, tracker, split, tree_num, leaf_chosen, var_chosen);
820
+ bool min_samples_left_check = left_n >= 2 *tree_prior.GetMinSamplesLeaf ();
821
+ bool min_samples_right_check = right_n >= 2 *tree_prior.GetMinSamplesLeaf ();
822
+ double prob_prune_new;
823
+ if (non_constant && (min_samples_left_check || min_samples_right_check)) {
824
+ prob_prune_new = 0.5 ;
825
+ } else {
826
+ prob_prune_new = 1.0 ;
827
+ }
828
+
829
+ // Determine the number of leaves in the current tree and leaf parents in the proposed tree
830
+ int num_leaf_parents = tree->NumLeafParents ();
831
+ double p_leaf = 1 /static_cast <double >(num_leaves);
832
+ double p_leaf_parent = 1 /static_cast <double >(num_leaf_parents+1 );
833
+
834
+ // Compute the final MH ratio
835
+ double log_mh_ratio = (
836
+ std::log (pg) + std::log (1 -pgl) + std::log (1 -pgr) - std::log (1 -pg) + std::log (prob_prune_new) +
837
+ std::log (p_leaf_parent) - std::log (prob_grow_old) - std::log (p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood
838
+ );
839
+ // Threshold at 0
840
+ if (log_mh_ratio > 0 ) {
841
+ log_mh_ratio = 0 ;
842
+ }
828
843
829
- // Compute the final MH ratio
830
- double log_mh_ratio = (
831
- std::log (pg) + std::log ( 1 -pgl) + std::log (1 -pgr) - std::log ( 1 -pg) + std::log (prob_prune_new) +
832
- std::log (p_leaf_parent) - std::log (prob_grow_old) - std::log (p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood
833
- ) ;
834
- // Threshold at 0
835
- if (log_mh_ratio > 0 ) {
836
- log_mh_ratio = 0 ;
837
- }
844
+ // Draw a uniform random variable and accept/reject the proposal on this basis
845
+ std::uniform_real_distribution< double > mh_accept ( 0.0 , 1.0 );
846
+ double log_acceptance_prob = std::log (mh_accept (gen));
847
+ if (log_acceptance_prob <= log_mh_ratio) {
848
+ accept = true ;
849
+ AddSplitToModel (tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false );
850
+ } else {
851
+ accept = false ;
852
+ }
838
853
839
- // Draw a uniform random variable and accept/reject the proposal on this basis
840
- std::uniform_real_distribution<double > mh_accept (0.0 , 1.0 );
841
- double log_acceptance_prob = std::log (mh_accept (gen));
842
- if (log_acceptance_prob <= log_mh_ratio) {
843
- accept = true ;
844
- AddSplitToModel (tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false );
845
854
} else {
846
855
accept = false ;
847
856
}
0 commit comments