Skip to content

Commit 8bffc8c

Browse files
authored
Merge pull request #101 from StochasticTree/min_samples_hotfix
Properly enforcing min_samples_leaf cutoff in MCMC sampler
2 parents 5cab0f0 + 8c6d083 commit 8bffc8c

File tree

1 file changed

+46
-37
lines changed

1 file changed

+46
-37
lines changed

include/stochtree/tree_sampler.h

Lines changed: 46 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -802,46 +802,55 @@ static inline void MCMCGrowTreeOneIter(Tree* tree, ForestTracker& tracker, LeafM
802802
double no_split_log_marginal_likelihood = std::get<1>(split_eval);
803803
int32_t left_n = std::get<2>(split_eval);
804804
int32_t right_n = std::get<3>(split_eval);
805-
806-
// Determine probability of growing the split node and its two new left and right nodes
807-
double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta());
808-
double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
809-
double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
810-
811-
// Determine whether a "grow" move is possible from the newly formed tree
812-
// in order to compute the probability of choosing "prune" from the new tree
813-
// (which is always possible by construction)
814-
bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen);
815-
bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf();
816-
bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf();
817-
double prob_prune_new;
818-
if (non_constant && (min_samples_left_check || min_samples_right_check)) {
819-
prob_prune_new = 0.5;
820-
} else {
821-
prob_prune_new = 1.0;
822-
}
823805

824-
// Determine the number of leaves in the current tree and leaf parents in the proposed tree
825-
int num_leaf_parents = tree->NumLeafParents();
826-
double p_leaf = 1/static_cast<double>(num_leaves);
827-
double p_leaf_parent = 1/static_cast<double>(num_leaf_parents+1);
806+
// Reject the split if either the left or the right node would contain fewer than tree_prior.GetMinSamplesLeaf() observations
807+
bool left_node_sample_cutoff = left_n >= tree_prior.GetMinSamplesLeaf();
808+
bool right_node_sample_cutoff = right_n >= tree_prior.GetMinSamplesLeaf();
809+
if ((left_node_sample_cutoff) && (right_node_sample_cutoff)) {
810+
811+
// Determine probability of growing the split node and its two new left and right nodes
812+
double pg = tree_prior.GetAlpha() * std::pow(1+leaf_depth, -tree_prior.GetBeta());
813+
double pgl = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
814+
double pgr = tree_prior.GetAlpha() * std::pow(1+leaf_depth+1, -tree_prior.GetBeta());
815+
816+
// Determine whether a "grow" move is possible from the newly formed tree
817+
// in order to compute the probability of choosing "prune" from the new tree
818+
// (which is always possible by construction)
819+
bool non_constant = NodesNonConstantAfterSplit(dataset, tracker, split, tree_num, leaf_chosen, var_chosen);
820+
bool min_samples_left_check = left_n >= 2*tree_prior.GetMinSamplesLeaf();
821+
bool min_samples_right_check = right_n >= 2*tree_prior.GetMinSamplesLeaf();
822+
double prob_prune_new;
823+
if (non_constant && (min_samples_left_check || min_samples_right_check)) {
824+
prob_prune_new = 0.5;
825+
} else {
826+
prob_prune_new = 1.0;
827+
}
828+
829+
// Determine the number of leaves in the current tree and leaf parents in the proposed tree
830+
int num_leaf_parents = tree->NumLeafParents();
831+
double p_leaf = 1/static_cast<double>(num_leaves);
832+
double p_leaf_parent = 1/static_cast<double>(num_leaf_parents+1);
833+
834+
// Compute the final MH ratio
835+
double log_mh_ratio = (
836+
std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) +
837+
std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood
838+
);
839+
// Cap the log MH ratio at 0 (i.e. the acceptance probability cannot exceed 1)
840+
if (log_mh_ratio > 0) {
841+
log_mh_ratio = 0;
842+
}
828843

829-
// Compute the final MH ratio
830-
double log_mh_ratio = (
831-
std::log(pg) + std::log(1-pgl) + std::log(1-pgr) - std::log(1-pg) + std::log(prob_prune_new) +
832-
std::log(p_leaf_parent) - std::log(prob_grow_old) - std::log(p_leaf) - no_split_log_marginal_likelihood + split_log_marginal_likelihood
833-
);
834-
// Threshold at 0
835-
if (log_mh_ratio > 0) {
836-
log_mh_ratio = 0;
837-
}
844+
// Draw a uniform random variable and accept/reject the proposal on this basis
845+
std::uniform_real_distribution<double> mh_accept(0.0, 1.0);
846+
double log_acceptance_prob = std::log(mh_accept(gen));
847+
if (log_acceptance_prob <= log_mh_ratio) {
848+
accept = true;
849+
AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false);
850+
} else {
851+
accept = false;
852+
}
838853

839-
// Draw a uniform random variable and accept/reject the proposal on this basis
840-
std::uniform_real_distribution<double> mh_accept(0.0, 1.0);
841-
double log_acceptance_prob = std::log(mh_accept(gen));
842-
if (log_acceptance_prob <= log_mh_ratio) {
843-
accept = true;
844-
AddSplitToModel(tracker, dataset, tree_prior, split, gen, tree, tree_num, leaf_chosen, var_chosen, false);
845854
} else {
846855
accept = false;
847856
}

0 commit comments

Comments
 (0)