BART: clamp first particle to old full tree #5011

Merged: 4 commits, Sep 29, 2021

This PR replaces the stage-by-stage replay of the previous tree (via tree_history and set_particle_to_step) with a first particle that is clamped to the old full tree: its weight is updated once, and only the remaining particles are grown and resampled.
178 changes: 83 additions & 95 deletions pymc/step_methods/pgbart.py
@@ -32,6 +32,59 @@
_log = logging.getLogger("pymc")


class ParticleTree:
"""
Particle tree
"""

def __init__(self, tree, log_weight, likelihood):
self.tree = tree.copy() # keeps a copy of the tree we are currently working on
self.expansion_nodes = [0]
self.log_weight = log_weight
self.old_likelihood_logp = likelihood
self.used_variates = []

def sample_tree_sequential(
self,
ssv,
available_predictors,
prior_prob_leaf_node,
X,
missing_data,
sum_trees_output,
mean,
m,
normal,
mu_std,
):
tree_grew = False
if self.expansion_nodes:
index_leaf_node = self.expansion_nodes.pop(0)
# Probability that this node will remain a leaf node
prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]

if prob_leaf < np.random.random():
tree_grew, index_selected_predictor = grow_tree(
self.tree,
index_leaf_node,
ssv,
available_predictors,
X,
missing_data,
sum_trees_output,
mean,
m,
normal,
mu_std,
)
if tree_grew:
new_indexes = self.tree.idx_leaf_nodes[-2:]
self.expansion_nodes.extend(new_indexes)
self.used_variates.append(index_selected_predictor)

return tree_grew
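
For intuition, here is a standalone sketch of how `sample_tree_sequential` drives growth: each call pops one frontier node, keeps it as a leaf with a depth-dependent prior probability, and otherwise splits it into two new frontier nodes. The prior table, `alpha`, `MAX_DEPTH`, and the dict-based "tree" below are illustrative assumptions, not pymc's actual `prior_prob_leaf_node` or `Tree` API:

```python
import numpy as np

# Illustrative only: one common BART-style choice makes deeper nodes more
# likely to stay leaves, so growth eventually stops on its own.
alpha = 0.5
MAX_DEPTH = 10
prior_prob_leaf_node = 1.0 - alpha ** (np.arange(MAX_DEPTH) + 1)  # p(leaf | depth)

expansion_nodes = [0]   # frontier of candidate leaves, as in __init__
depth = {0: 0}          # toy stand-in for self.tree[node].depth
next_id = 1

while expansion_nodes:
    node = expansion_nodes.pop(0)
    prob_leaf = prior_prob_leaf_node[min(depth[node], MAX_DEPTH - 1)]
    if prob_leaf < np.random.random():   # grow with probability 1 - prob_leaf
        for _ in range(2):               # a successful split adds two children
            depth[next_id] = depth[node] + 1
            expansion_nodes.append(next_id)
            next_id += 1

print(f"grew a tree with {len(depth)} nodes")
```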


class PGBART(ArrayStepShared):
"""
Particle Gibbs BART sampling step
@@ -108,9 +161,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo

if self.chunk == "auto":
self.chunk = max(1, int(self.m * 0.1))
self.num_particles = num_particles
self.log_num_particles = np.log(num_particles)
self.indices = list(range(1, num_particles))
self.len_indices = len(self.indices)
self.max_stages = max_stages

shared = make_shared_replacements(initial_values, vars, model)
@@ -137,24 +190,22 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
if self.idx == self.m:
self.idx = 0

- for idx in range(self.idx, self.idx + self.chunk):
- if idx >= self.m:
+ for tree_id in range(self.idx, self.idx + self.chunk):
+ if tree_id >= self.m:
break
- tree = self.all_particles[idx].tree
- sum_trees_output_noi = sum_trees_output - tree.predict_output()
- self.idx += 1
# Generate an initial set of SMC particles
# at the end of the algorithm we return one of these particles as the new tree
- particles = self.init_particles(tree.tree_id)
+ particles = self.init_particles(tree_id)
+ # Compute the sum of trees without the tree we are attempting to replace
+ self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()
+ self.idx += 1

+ # The old tree is not growing so we update the weights only once.
+ self.update_weight(particles[0])
for t in range(self.max_stages):
- # Get old particle at stage t
- if t > 0:
- particles[0] = self.get_old_tree_particle(tree.tree_id, t)
- # sample each particle (try to grow each tree)
- compute_logp = [True]
+ # Sample each particle (try to grow each tree), except for the first one.
for p in particles[1:]:
- clp = p.sample_tree_sequential(
+ tree_grew = p.sample_tree_sequential(
self.ssv,
self.available_predictors,
self.prior_prob_leaf_node,
@@ -166,22 +217,14 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
self.normal,
self.mu_std,
)
- compute_logp.append(clp)
- # Update weights. Since the prior is used as the proposal,the weights
- # are updated additively as the ratio of the new and old log_likelihoods
- for clp, p in zip(compute_logp, particles):
- if clp: # Compute the likelihood when p has changed from the previous iteration
- new_likelihood = self.likelihood_logp(
- sum_trees_output_noi + p.tree.predict_output()
- )
- p.log_weight += new_likelihood - p.old_likelihood_logp
- p.old_likelihood_logp = new_likelihood
+ if tree_grew:
+ self.update_weight(p)
# Normalize weights
W_t, normalized_weights = self.normalize(particles)

# Resample all but first particle
re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
- new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
+ new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
particles[1:] = particles[new_indices]
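
The resampling step never touches index 0, so the old full tree always survives to the final selection. A toy, self-contained version (labels instead of real ParticleTree objects; weights are made-up values):

```python
import numpy as np

# Particle 0 -- the clamped old tree -- keeps its slot; the remaining slots are
# refilled by sampling with replacement from particles 1..N-1, using their
# renormalized weights.
rng = np.random.default_rng(0)

normalized_weights = np.array([0.4, 0.3, 0.2, 0.1])
indices = np.arange(1, 4)                                # like self.indices
re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
new_indices = rng.choice(indices, size=len(indices), p=re_n_w)

particles = np.array(["old_tree", "p1", "p2", "p3"], dtype=object)
particles[1:] = particles[new_indices]
print(particles)                                          # index 0 is always 'old_tree'
```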

# Set the new weights
@@ -200,8 +243,8 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
new_particle = np.random.choice(particles, p=normalized_weights)
new_tree = new_particle.tree
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
- self.all_particles[tree.tree_id] = new_particle
- sum_trees_output = sum_trees_output_noi + new_tree.predict_output()
+ self.all_particles[tree_id] = new_particle
+ sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()

if self.tune:
for index in new_particle.used_variates:
@@ -232,7 +275,7 @@ def competence(var, has_grad):
return Competence.IDEAL
return Competence.INCOMPATIBLE

- def normalize(self, particles):
+ def normalize(self, particles: List[ParticleTree]) -> Tuple[float, np.ndarray]:
"""
Use logsumexp trick to get W_t and softmax to get normalized_weights
"""
@@ -248,16 +291,11 @@ def normalize(self, particles):

return W_t, normalized_weights
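
As a reminder of why the docstring's "logsumexp trick" matters: exponentiating large negative log-weights directly underflows to zero, while shifting by the maximum first is stable. A minimal standalone illustration (toy values; the method's exact bookkeeping, e.g. any division by the number of particles, lives in the folded lines):

```python
import numpy as np

log_w = np.array([-760.0, -762.5, -759.8])  # np.exp(-760.0) underflows to 0.0 in float64

log_w_max = log_w.max()
w = np.exp(log_w - log_w_max)          # largest term becomes exp(0) = 1
W_t = log_w_max + np.log(w.sum())      # log(sum(exp(log_w))), computed stably
normalized_weights = w / w.sum()       # softmax of log_w

print(W_t, normalized_weights.sum())   # normalized weights sum to 1
```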

- def get_old_tree_particle(self, tree_id, t):
- old_tree_particle = self.all_particles[tree_id]
- old_tree_particle.set_particle_to_step(t)
- return old_tree_particle
-
- def init_particles(self, tree_id):
+ def init_particles(self, tree_id: int) -> np.ndarray:
"""
Initialize particles
"""
- p = self.get_old_tree_particle(tree_id, 0)
+ p = self.all_particles[tree_id]
p.log_weight = self.init_log_weight
p.old_likelihood_logp = self.init_likelihood
particles = [p]
@@ -274,68 +312,18 @@ def init_particles(self, tree_id):

return np.array(particles)
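
The folded body builds the remaining particles; the visible change is the heart of the PR: the first particle is simply the stored full tree from the previous sweep, no longer replayed stage by stage from tree_history. A toy model of that invariant (`ToyParticle` and this `init_particles` are illustrative assumptions, not the pymc API):

```python
from dataclasses import dataclass, field
from typing import List

# Particle 0 carries the old full tree and, because the growth/weight loops
# above iterate over particles[1:], it is never grown again; the others restart
# from a single root node and compete to replace it.

@dataclass
class ToyParticle:
    nodes: List[int]                                   # stand-in for a tree
    expansion_nodes: List[int] = field(default_factory=list)
    log_weight: float = 0.0

def init_particles(old_tree_nodes: List[int], num_particles: int) -> List[ToyParticle]:
    first = ToyParticle(nodes=list(old_tree_nodes))    # clamped old full tree
    rest = [ToyParticle(nodes=[0], expansion_nodes=[0])  # fresh root-only trees
            for _ in range(num_particles - 1)]
    return [first] + rest

particles = init_particles(old_tree_nodes=[0, 1, 2, 3, 4], num_particles=4)
print(particles[0].nodes, particles[1].nodes)          # [0, 1, 2, 3, 4] [0]
```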

-
- class ParticleTree:
- """
- Particle tree
- """
-
- def __init__(self, tree, log_weight, likelihood):
- self.tree = tree.copy() # keeps the tree that we care at the moment
- self.expansion_nodes = [0]
- self.tree_history = [self.tree]
- self.expansion_nodes_history = [self.expansion_nodes]
- self.log_weight = log_weight
- self.old_likelihood_logp = likelihood
- self.used_variates = []
-
- def sample_tree_sequential(
- self,
- ssv,
- available_predictors,
- prior_prob_leaf_node,
- X,
- missing_data,
- sum_trees_output,
- mean,
- m,
- normal,
- mu_std,
- ):
- clp = False
- if self.expansion_nodes:
- index_leaf_node = self.expansion_nodes.pop(0)
- # Probability that this node will remain a leaf node
- prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
-
- if prob_leaf < np.random.random():
- clp, index_selected_predictor = grow_tree(
- self.tree,
- index_leaf_node,
- ssv,
- available_predictors,
- X,
- missing_data,
- sum_trees_output,
- mean,
- m,
- normal,
- mu_std,
- )
- if clp:
- new_indexes = self.tree.idx_leaf_nodes[-2:]
- self.expansion_nodes.extend(new_indexes)
- self.used_variates.append(index_selected_predictor)
-
- self.tree_history.append(self.tree)
- self.expansion_nodes_history.append(self.expansion_nodes)
- return clp
-
- def set_particle_to_step(self, t):
- if len(self.tree_history) <= t:
- t = -1
- self.tree = self.tree_history[t]
- self.expansion_nodes = self.expansion_nodes_history[t]

+ def update_weight(self, particle: ParticleTree) -> None:
+ """
+ Update the weight of a particle
+
+ Since the prior is used as the proposal, the weights are updated additively as the ratio of
+ the new and old log-likelihoods.
+ """
+ new_likelihood = self.likelihood_logp(
+ self.sum_trees_output_noi + particle.tree.predict_output()
+ )
+ particle.log_weight += new_likelihood - particle.old_likelihood_logp
+ particle.old_likelihood_logp = new_likelihood
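
To make update_weight's additive update concrete: with the prior as the proposal, the SMC weight update w ← w · p_new / p_old becomes log w ← log w + (log p_new − log p_old). A toy trace (made-up numbers, standalone):

```python
# In log space, multiplying a weight by the likelihood ratio is an addition.
log_weight = -2.0
old_likelihood_logp = -310.0

new_likelihood_logp = -308.5                      # grown tree fits a bit better
log_weight += new_likelihood_logp - old_likelihood_logp
old_likelihood_logp = new_likelihood_logp

print(log_weight)                                 # -0.5: weight rose by the log-ratio
```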


def preprocess_XY(X, Y):