Skip to content

Commit ddf9ea6

Browse files
committed
Initial implementation of categorical feature re-weighting in the GFR algorithm
1 parent f55bbb4 commit ddf9ea6

File tree

2 files changed

+167
-8
lines changed

2 files changed

+167
-8
lines changed

include/stochtree/tree_sampler.h

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -471,8 +471,8 @@ template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatCon
471471
static inline void EvaluateAllPossibleSplits(
472472
ForestDataset& dataset, ForestTracker& tracker, ColumnVector& residual, TreePrior& tree_prior, LeafModel& leaf_model, double global_variance, int tree_num, int split_node_id,
473473
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values, std::vector<FeatureType>& cutpoint_feature_types,
474-
data_size_t& valid_cutpoint_count, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end, std::vector<double>& variable_weights,
475-
std::vector<FeatureType>& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args
474+
data_size_t& valid_cutpoint_count, std::vector<data_size_t>& feature_cutpoint_counts, CutpointGridContainer& cutpoint_grid_container, data_size_t node_begin, data_size_t node_end,
475+
std::vector<double>& variable_weights, std::vector<FeatureType>& feature_types, LeafSuffStatConstructorArgs&... leaf_suff_stat_args
476476
) {
477477
// Initialize sufficient statistics
478478
LeafSuffStat node_suff_stat = LeafSuffStat(leaf_suff_stat_args...);
@@ -496,6 +496,7 @@ static inline void EvaluateAllPossibleSplits(
496496
int32_t min_samples_in_leaf = tree_prior.GetMinSamplesLeaf();
497497

498498
// Compute sufficient statistics for each possible split
499+
data_size_t feature_cutpoints;
499500
data_size_t num_cutpoints = 0;
500501
bool valid_split = false;
501502
data_size_t node_row_iter;
@@ -509,6 +510,8 @@ static inline void EvaluateAllPossibleSplits(
509510
double log_split_eval = 0.0;
510511
double split_log_ml;
511512
for (int j = 0; j < covariates.cols(); j++) {
513+
// Reset feature cutpoint counter
514+
feature_cutpoints = 0;
512515

513516
if (std::abs(variable_weights.at(j)) > kEpsilon) {
514517
// Enumerate cutpoint strides
@@ -542,6 +545,7 @@ static inline void EvaluateAllPossibleSplits(
542545
valid_split = (left_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf) &&
543546
right_suff_stat.SampleGreaterThanEqual(min_samples_in_leaf));
544547
if (valid_split) {
548+
feature_cutpoints++;
545549
num_cutpoints++;
546550
// Add to split rule vector
547551
cutpoint_feature_types.push_back(feature_type);
@@ -553,7 +557,8 @@ static inline void EvaluateAllPossibleSplits(
553557
}
554558
}
555559
}
556-
560+
// Add feature_cutpoints to feature_cutpoint_counts
561+
feature_cutpoint_counts.push_back(feature_cutpoints);
557562
}
558563

559564
// Add the log marginal likelihood of the "no-split" option (adjusted for tree prior and cutpoint size per the XBART paper)
@@ -570,16 +575,38 @@ template <typename LeafModel, typename LeafSuffStat, typename... LeafSuffStatCon
570575
static inline void EvaluateCutpoints(Tree* tree, ForestTracker& tracker, LeafModel& leaf_model, ForestDataset& dataset, ColumnVector& residual, TreePrior& tree_prior,
571576
std::mt19937& gen, int tree_num, double global_variance, int cutpoint_grid_size, int node_id, data_size_t node_begin, data_size_t node_end,
572577
std::vector<double>& log_cutpoint_evaluations, std::vector<int>& cutpoint_features, std::vector<double>& cutpoint_values,
573-
std::vector<FeatureType>& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector<double>& variable_weights,
574-
std::vector<FeatureType>& feature_types, CutpointGridContainer& cutpoint_grid_container, LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
578+
std::vector<FeatureType>& cutpoint_feature_types, data_size_t& valid_cutpoint_count, std::vector<StochTree::data_size_t>& feature_cutpoint_counts,
579+
std::vector<double>& variable_weights, std::vector<FeatureType>& feature_types, CutpointGridContainer& cutpoint_grid_container,
580+
LeafSuffStatConstructorArgs&... leaf_suff_stat_args) {
575581
// Evaluate all possible cutpoints according to the leaf node model,
576582
// recording their log-likelihood and other split information in a series of vectors.
577583
// The last element of these vectors concerns the "no-split" option.
578584
EvaluateAllPossibleSplits<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
579585
dataset, tracker, residual, tree_prior, leaf_model, global_variance, tree_num, node_id, log_cutpoint_evaluations,
580-
cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, cutpoint_grid_container,
586+
cutpoint_features, cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, cutpoint_grid_container,
581587
node_begin, node_end, variable_weights, feature_types, leaf_suff_stat_args...
582588
);
589+
590+
// Compute weighting adjustments for low-cardinality categorical features
591+
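// Check whether the dataset has any numeric (continuous) features and skip the adjustment if it has none. NOTE(review): if no numeric feature has a valid cutpoint in this node, max_feature_cutpoint_count stays 0 and the weight below becomes log(0) = -inf for every categorical cutpoint -- confirm this is intended.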
592+
bool has_continuous_features = false;
593+
int max_feature_cutpoint_count = 0;
594+
for (int j = 0; j < feature_types.size(); j++) {
595+
if (feature_types.at(j) == FeatureType::kNumeric) {
596+
has_continuous_features = true;
597+
if (feature_cutpoint_counts[j] > max_feature_cutpoint_count) max_feature_cutpoint_count = feature_cutpoint_counts[j];
598+
}
599+
}
600+
if (has_continuous_features) {
601+
double feature_weight;
602+
for (data_size_t i = 0; i < valid_cutpoint_count; i++) {
603+
// Determine whether the feature is categorical (and thus needs to be re-weighted)
604+
if ((cutpoint_feature_types[i] == FeatureType::kOrderedCategorical) || (cutpoint_feature_types[i] == FeatureType::kUnorderedCategorical)) {
605+
feature_weight = ((double) max_feature_cutpoint_count) / ((double) feature_cutpoint_counts[cutpoint_features[i]]);
606+
log_cutpoint_evaluations[i] += std::log(feature_weight);
607+
}
608+
}
609+
}
583610

584611
// Compute an adjustment to reflect the no split prior probability and the number of cutpoints
585612
double bart_prior_no_split_adj;
@@ -614,12 +641,13 @@ static inline void SampleSplitRule(Tree* tree, ForestTracker& tracker, LeafModel
614641
std::vector<double> cutpoint_values;
615642
std::vector<FeatureType> cutpoint_feature_types;
616643
StochTree::data_size_t valid_cutpoint_count;
644+
std::vector<StochTree::data_size_t> feature_cutpoint_counts;
617645
CutpointGridContainer cutpoint_grid_container(dataset.GetCovariates(), residual.GetData(), cutpoint_grid_size);
618646
EvaluateCutpoints<LeafModel, LeafSuffStat, LeafSuffStatConstructorArgs...>(
619647
tree, tracker, leaf_model, dataset, residual, tree_prior, gen, tree_num, global_variance,
620648
cutpoint_grid_size, node_id, node_begin, node_end, log_cutpoint_evaluations, cutpoint_features,
621-
cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, variable_weights, feature_types,
622-
cutpoint_grid_container, leaf_suff_stat_args...
649+
cutpoint_values, cutpoint_feature_types, valid_cutpoint_count, feature_cutpoint_counts, variable_weights,
650+
feature_types, cutpoint_grid_container, leaf_suff_stat_args...
623651
);
624652
// TODO: maybe add some checks here?
625653

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
################################################################################
2+
## Comparison of GFR / warm start with pure MCMC on datasets with a
3+
## mix of numeric features and low-cardinality categorical features.
4+
################################################################################
5+
6+
# Load libraries
7+
library(stochtree)
8+
9+
# Generate data
# Covariates: p_continuous uniform(0, 1) columns, p_binary Bernoulli(0.5)
# columns and p_ordered_cat columns drawn uniformly from the integers 1 to 5.
n <- 500
p_continuous <- 5
p_binary <- 2
p_ordered_cat <- 2
p <- p_continuous + p_binary + p_ordered_cat

# Minimum column counts required by the outcome functions defined later in
# this script (they index continuous columns 1-3, binary columns 1-2 and
# ordered categorical column 1).
stopifnot(p_continuous >= 3)
stopifnot(p_binary >= 2)
stopifnot(p_ordered_cat >= 1)

x_continuous <- matrix(
    runif(n*p_continuous),
    ncol = p_continuous
)
x_binary <- matrix(
    rbinom(n*p_binary, size = 1, prob = 0.5),
    ncol = p_binary
)
# TRUE is spelled out: `T` is an ordinary variable that can be reassigned.
x_ordered_cat <- matrix(
    sample(1:5, size = n*p_ordered_cat, replace = TRUE),
    ncol = p_ordered_cat
)

# Assemble the covariates as a matrix and as a data frame with columns x1..xp
X_matrix <- cbind(x_continuous, x_binary, x_ordered_cat)
X_df <- as.data.frame(X_matrix)
colnames(X_df) <- paste0("x", 1:p)

# Mark every non-continuous column (binary and ordered categorical) as an
# ordered factor. The upper bound equals p by construction.
for (i in (p_continuous+1):p) {
    X_df[,i] <- factor(X_df[,i], ordered = TRUE)
}
36+
# Conditional mean components, one per feature group.
# NOTE(review): x_continuous is drawn with runif() on [0, 1] earlier in this
# script, so (x_continuous[,2] < 0) is never TRUE and abs() has no effect;
# the step and |x| terms look designed for zero-centred covariates -- confirm.
f_x_cont <- (2 + 4*x_continuous[,1] - 6*(x_continuous[,2] < 0) +
             6*(x_continuous[,2] >= 0) + 5*(abs(x_continuous[,3]) - sqrt(2/pi)))
f_x_binary <- -1.5 + 1*x_binary[,1] + 2*x_binary[,2]
f_x_ordered_cat <- 3 - 1*x_ordered_cat[,1]

# Target share of variance for each component
pct_var_cont <- 1/3
pct_var_binary <- 1/3
pct_var_ordered_cat <- 1/3
# Compare with a tolerance: exact `==` on doubles fails for many valid
# combinations of shares (e.g. 0.1 + 0.2 + 0.7) because of rounding error.
stopifnot(isTRUE(all.equal(
    pct_var_cont + pct_var_binary + pct_var_ordered_cat, 1.0
)))

# Rescale each component so that its variance equals its target share of
# the variance of the unscaled sum
total_var <- var(f_x_cont+f_x_binary+f_x_ordered_cat)
f_x_cont_rescaled <- f_x_cont * sqrt(
    pct_var_cont / (var(f_x_cont) / total_var)
)
f_x_binary_rescaled <- f_x_binary * sqrt(
    pct_var_binary / (var(f_x_binary) / total_var)
)
f_x_ordered_cat_rescaled <- f_x_ordered_cat * sqrt(
    pct_var_ordered_cat / (var(f_x_ordered_cat) / total_var)
)
E_y <- f_x_cont_rescaled + f_x_binary_rescaled + f_x_ordered_cat_rescaled
# Diagnostics (uncomment to inspect the realised variance shares):
# var(f_x_cont_rescaled) / var(E_y)
# var(f_x_binary_rescaled) / var(E_y)
# var(f_x_ordered_cat_rescaled) / var(E_y)

# Add Gaussian noise whose standard deviation is sd(E_y) / snr
snr <- 3
epsilon <- rnorm(n, 0, 1) * sd(E_y) / snr
y <- E_y + epsilon
61+
# Jittered copies of the discrete covariates: each binary / ordered
# categorical value is perturbed by independent uniform noise on
# [-jitter_eps, jitter_eps]. The noise vector has one entry per matrix cell
# and is added element-wise in column-major order, which keeps the dimensions
# of the original matrix.
jitter_eps <- 0.1
x_binary_jitter <- x_binary + runif(
    length(x_binary), min = -jitter_eps, max = jitter_eps
)
x_ordered_cat_jitter <- x_ordered_cat + runif(
    length(x_ordered_cat), min = -jitter_eps, max = jitter_eps
)

# Same column layout and names (x1..xp) as the un-jittered covariates
X_matrix_jitter <- cbind(
    x_continuous,
    x_binary_jitter,
    x_ordered_cat_jitter
)
X_df_jitter <- as.data.frame(X_matrix_jitter)
names(X_df_jitter) <- paste0("x", 1:p)
71+
72+
# Hold out a random test_set_pct share of the rows as a test set and use the
# remaining rows for training; both index vectors are in increasing order.
test_set_pct <- 0.2
n_test <- round(test_set_pct * n)
n_train <- n - n_test
test_inds <- sort(sample(1:n, size = n_test, replace = FALSE))
train_inds <- setdiff(1:n, test_inds)

# Training partition
X_df_train <- X_df[train_inds, ]
X_df_jitter_train <- X_df_jitter[train_inds, ]
y_train <- y[train_inds]

# Test partition
X_df_test <- X_df[test_inds, ]
X_df_jitter_test <- X_df_jitter[test_inds, ]
y_test <- y[test_inds]
84+
85+
# Fit four BART models: {original, jittered} covariates crossed with
# {15 GFR warm-start iterations and no burn-in, no GFR and 2000 burn-in
# iterations}. Every fit retains 100 MCMC samples.

# 1. Warm start, original covariates
ws_bart_fit <- bart(
    X_train = X_df_train,
    y_train = y_train,
    X_test = X_df_test,
    num_gfr = 15,
    num_burnin = 0,
    num_mcmc = 100
)

# 2. MCMC only, original covariates
bart_fit <- bart(
    X_train = X_df_train,
    y_train = y_train,
    X_test = X_df_test,
    num_gfr = 0,
    num_burnin = 2000,
    num_mcmc = 100
)

# 3. Warm start, jittered covariates
ws_bart_jitter_fit <- bart(
    X_train = X_df_jitter_train,
    y_train = y_train,
    X_test = X_df_jitter_test,
    num_gfr = 15,
    num_burnin = 0,
    num_mcmc = 100
)

# 4. MCMC only, jittered covariates
bart_jitter_fit <- bart(
    X_train = X_df_jitter_train,
    y_train = y_train,
    X_test = X_df_jitter_test,
    num_gfr = 0,
    num_burnin = 2000,
    num_mcmc = 100
)
104+
105+
# Compare the variable split counts of the four fits
# (order: warm start, MCMC only, warm start jittered, MCMC only jittered)
ws_bart_fit$mean_forests$get_aggregate_split_counts(p)
bart_fit$mean_forests$get_aggregate_split_counts(p)
ws_bart_jitter_fit$mean_forests$get_aggregate_split_counts(p)
bart_jitter_fit$mean_forests$get_aggregate_split_counts(p)

# Out-of-sample RMSE of each fit, using the row means of y_hat_test as the
# point prediction (same order as above)
sqrt(mean((y_test - rowMeans(ws_bart_fit$y_hat_test))^2))
sqrt(mean((y_test - rowMeans(bart_fit$y_hat_test))^2))
sqrt(mean((y_test - rowMeans(ws_bart_jitter_fit$y_hat_test))^2))
sqrt(mean((y_test - rowMeans(bart_jitter_fit$y_hat_test))^2))
116+
117+
# Compare sigma^2 traceplots
# All four traces share one y-axis range, padded by 0.1 on each side.
sigma_min <- min(c(ws_bart_fit$sigma2_global_samples,
                   bart_fit$sigma2_global_samples,
                   ws_bart_jitter_fit$sigma2_global_samples,
                   bart_jitter_fit$sigma2_global_samples))
sigma_max <- max(c(ws_bart_fit$sigma2_global_samples,
                   bart_fit$sigma2_global_samples,
                   ws_bart_jitter_fit$sigma2_global_samples,
                   bart_jitter_fit$sigma2_global_samples))

# Colours: black = warm start, blue = MCMC only, green = warm start on
# jittered data, red = MCMC only on jittered data.
# `type` must be a one-character code; "l" draws lines. The previous value
# "line" is not a valid code and is truncated by R with a warning.
plot(ws_bart_fit$sigma2_global_samples,
     ylim = c(sigma_min - 0.1, sigma_max + 0.1),
     type = "l", col = "black")
lines(bart_fit$sigma2_global_samples, col = "blue")
lines(ws_bart_jitter_fit$sigma2_global_samples, col = "green")
lines(bart_jitter_fit$sigma2_global_samples, col = "red")

0 commit comments

Comments
 (0)