Skip to content

Commit bf9958d

Browse files
committed
clamp first particle to old full tree
1 parent 641b278 commit bf9958d

File tree

1 file changed

+29
-41
lines changed

1 file changed

+29
-41
lines changed

pymc/step_methods/pgbart.py

Lines changed: 29 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
108108

109109
if self.chunk == "auto":
110110
self.chunk = max(1, int(self.m * 0.1))
111-
self.num_particles = num_particles
112111
self.log_num_particles = np.log(num_particles)
113112
self.indices = list(range(1, num_particles))
113+
self.len_indices = len(self.indices)
114114
self.max_stages = max_stages
115115

116116
shared = make_shared_replacements(initial_values, vars, model)
@@ -137,22 +137,20 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
137137
if self.idx == self.m:
138138
self.idx = 0
139139

140-
for idx in range(self.idx, self.idx + self.chunk):
141-
if idx >= self.m:
140+
for tree_id in range(self.idx, self.idx + self.chunk):
141+
if tree_id >= self.m:
142142
break
143-
tree = self.all_particles[idx].tree
144-
sum_trees_output_noi = sum_trees_output - tree.predict_output()
145-
self.idx += 1
146143
# Generate an initial set of SMC particles
147144
# at the end of the algorithm we return one of these particles as the new tree
148-
particles = self.init_particles(tree.tree_id)
145+
particles = self.init_particles(tree_id)
146+
# Compute the sum of trees without the tree we are attempting to replace
147+
self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()
148+
self.idx += 1
149149

150+
# The old tree is not growing so we update the weights only once.
151+
self.update_weight(particles[0])
150152
for t in range(self.max_stages):
151-
# Get old particle at stage t
152-
if t > 0:
153-
particles[0] = self.get_old_tree_particle(tree.tree_id, t)
154-
# sample each particle (try to grow each tree)
155-
compute_logp = [True]
153+
# Sample each particle (try to grow each tree), except for the first one.
156154
for p in particles[1:]:
157155
clp = p.sample_tree_sequential(
158156
self.ssv,
@@ -166,22 +164,14 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
166164
self.normal,
167165
self.mu_std,
168166
)
169-
compute_logp.append(clp)
170-
# Update weights. Since the prior is used as the proposal, the weights
171-
# are updated additively as the ratio of the new and old log_likelihoods
172-
for clp, p in zip(compute_logp, particles):
173-
if clp: # Compute the likelihood when p has changed from the previous iteration
174-
new_likelihood = self.likelihood_logp(
175-
sum_trees_output_noi + p.tree.predict_output()
176-
)
177-
p.log_weight += new_likelihood - p.old_likelihood_logp
178-
p.old_likelihood_logp = new_likelihood
167+
if clp: # update weights only if p has changed from the previous iteration
168+
self.update_weight(p)
179169
# Normalize weights
180170
W_t, normalized_weights = self.normalize(particles)
181171

182172
# Resample all but first particle
183173
re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
184-
new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
174+
new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
185175
particles[1:] = particles[new_indices]
186176

187177
# Set the new weights
@@ -200,8 +190,8 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
200190
new_particle = np.random.choice(particles, p=normalized_weights)
201191
new_tree = new_particle.tree
202192
new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
203-
self.all_particles[tree.tree_id] = new_particle
204-
sum_trees_output = sum_trees_output_noi + new_tree.predict_output()
193+
self.all_particles[tree_id] = new_particle
194+
sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()
205195

206196
if self.tune:
207197
for index in new_particle.used_variates:
@@ -248,16 +238,11 @@ def normalize(self, particles):
248238

249239
return W_t, normalized_weights
250240

251-
def get_old_tree_particle(self, tree_id, t):
252-
old_tree_particle = self.all_particles[tree_id]
253-
old_tree_particle.set_particle_to_step(t)
254-
return old_tree_particle
255-
256241
def init_particles(self, tree_id):
257242
"""
258243
Initialize particles
259244
"""
260-
p = self.get_old_tree_particle(tree_id, 0)
245+
p = self.all_particles[tree_id]
261246
p.log_weight = self.init_log_weight
262247
p.old_likelihood_logp = self.init_likelihood
263248
particles = [p]
@@ -274,6 +259,19 @@ def init_particles(self, tree_id):
274259

275260
return np.array(particles)
276261

262+
def update_weight(self, particle):
263+
"""
264+
Update the weight of a particle
265+
266+
Since the prior is used as the proposal, the log-weights are updated additively by the
267+
difference between the new and old log-likelihoods (i.e. the log of the likelihood ratio).
268+
"""
269+
new_likelihood = self.likelihood_logp(
270+
self.sum_trees_output_noi + particle.tree.predict_output()
271+
)
272+
particle.log_weight += new_likelihood - particle.old_likelihood_logp
273+
particle.old_likelihood_logp = new_likelihood
274+
277275

278276
class ParticleTree:
279277
"""
@@ -283,8 +281,6 @@ class ParticleTree:
283281
def __init__(self, tree, log_weight, likelihood):
284282
self.tree = tree.copy() # keeps the tree that we care at the moment
285283
self.expansion_nodes = [0]
286-
self.tree_history = [self.tree]
287-
self.expansion_nodes_history = [self.expansion_nodes]
288284
self.log_weight = log_weight
289285
self.old_likelihood_logp = likelihood
290286
self.used_variates = []
@@ -327,16 +323,8 @@ def sample_tree_sequential(
327323
self.expansion_nodes.extend(new_indexes)
328324
self.used_variates.append(index_selected_predictor)
329325

330-
self.tree_history.append(self.tree)
331-
self.expansion_nodes_history.append(self.expansion_nodes)
332326
return clp
333327

334-
def set_particle_to_step(self, t):
335-
if len(self.tree_history) <= t:
336-
t = -1
337-
self.tree = self.tree_history[t]
338-
self.expansion_nodes = self.expansion_nodes_history[t]
339-
340328

341329
def preprocess_XY(X, Y):
342330
if isinstance(Y, (Series, DataFrame)):

0 commit comments

Comments
 (0)