@@ -108,9 +108,9 @@ def __init__(self, vars=None, num_particles=10, max_stages=100, chunk="auto", mo
         if self.chunk == "auto":
             self.chunk = max(1, int(self.m * 0.1))
-        self.num_particles = num_particles
         self.log_num_particles = np.log(num_particles)
         self.indices = list(range(1, num_particles))
+        self.len_indices = len(self.indices)
         self.max_stages = max_stages
 
         shared = make_shared_replacements(initial_values, vars, model)
@@ -137,22 +137,20 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         if self.idx == self.m:
             self.idx = 0
 
-        for idx in range(self.idx, self.idx + self.chunk):
-            if idx >= self.m:
+        for tree_id in range(self.idx, self.idx + self.chunk):
+            if tree_id >= self.m:
                 break
-            tree = self.all_particles[idx].tree
-            sum_trees_output_noi = sum_trees_output - tree.predict_output()
-            self.idx += 1
             # Generate an initial set of SMC particles
             # at the end of the algorithm we return one of these particles as the new tree
-            particles = self.init_particles(tree.tree_id)
+            particles = self.init_particles(tree_id)
+            # Compute the sum of trees without the tree we are attempting to replace
+            self.sum_trees_output_noi = sum_trees_output - particles[0].tree.predict_output()
+            self.idx += 1
 
+            # The old tree is not growing so we update the weights only once.
+            self.update_weight(particles[0])
             for t in range(self.max_stages):
-                # Get old particle at stage t
-                if t > 0:
-                    particles[0] = self.get_old_tree_particle(tree.tree_id, t)
-                # sample each particle (try to grow each tree)
-                compute_logp = [True]
+                # Sample each particle (try to grow each tree), except for the first one.
                 for p in particles[1:]:
                     clp = p.sample_tree_sequential(
                         self.ssv,
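The loop above replaces one tree of the ensemble at a time while keeping a running sum of all tree outputs: the tree being replaced is subtracted out once, and the chosen replacement is added back at the end. A minimal standalone sketch of that bookkeeping, using plain NumPy arrays in place of the ParticleTree objects (all names here are illustrative, not from this diff):

import numpy as np

# Illustrative stand-ins: m trees, each predicting over n data points.
m, n = 3, 5
tree_outputs = [np.full(n, 0.1 * (i + 1)) for i in range(m)]
sum_trees_output = np.sum(tree_outputs, axis=0)

tree_id = 1
# Sum of all trees except the one being replaced (sum_trees_output_noi in the diff).
sum_trees_output_noi = sum_trees_output - tree_outputs[tree_id]

# After the SMC loop picks a new tree, the running sum is rebuilt from the residual.
new_tree_output = np.full(n, 0.25)  # hypothetical replacement prediction
sum_trees_output = sum_trees_output_noi + new_tree_output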
@@ -166,22 +164,14 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                         self.normal,
                         self.mu_std,
                     )
-                    compute_logp.append(clp)
-                # Update weights. Since the prior is used as the proposal,the weights
-                # are updated additively as the ratio of the new and old log_likelihoods
-                for clp, p in zip(compute_logp, particles):
-                    if clp:  # Compute the likelihood when p has changed from the previous iteration
-                        new_likelihood = self.likelihood_logp(
-                            sum_trees_output_noi + p.tree.predict_output()
-                        )
-                        p.log_weight += new_likelihood - p.old_likelihood_logp
-                        p.old_likelihood_logp = new_likelihood
+                    if clp:  # update weights only if p has changed from the previous iteration
+                        self.update_weight(p)
                 # Normalize weights
                 W_t, normalized_weights = self.normalize(particles)
 
                 # Resample all but first particle
                 re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
-                new_indices = np.random.choice(self.indices, size=len(self.indices), p=re_n_w)
+                new_indices = np.random.choice(self.indices, size=self.len_indices, p=re_n_w)
                 particles[1:] = particles[new_indices]
 
                 # Set the new weights
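A small sketch of the normalize-and-resample step above, keeping the first particle (the unmodified old tree) fixed and resampling the rest in proportion to their weights. The log weights and the plain softmax here are illustrative stand-ins for the particle objects and the normalize() method:

import numpy as np

log_weights = np.array([-1.0, -0.5, -2.0, -1.5])  # hypothetical particle log weights
num_particles = len(log_weights)
indices = list(range(1, num_particles))

# Normalize on the probability scale.
w = np.exp(log_weights - log_weights.max())
normalized_weights = w / w.sum()

# Resample all but the first particle.
re_n_w = normalized_weights[1:] / normalized_weights[1:].sum()
new_indices = np.random.choice(indices, size=len(indices), p=re_n_w)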
@@ -200,8 +190,8 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             new_particle = np.random.choice(particles, p=normalized_weights)
             new_tree = new_particle.tree
             new_particle.log_weight = new_particle.old_likelihood_logp - self.log_num_particles
-            self.all_particles[tree.tree_id] = new_particle
-            sum_trees_output = sum_trees_output_noi + new_tree.predict_output()
+            self.all_particles[tree_id] = new_particle
+            sum_trees_output = self.sum_trees_output_noi + new_tree.predict_output()
 
             if self.tune:
                 for index in new_particle.used_variates:
@@ -248,16 +238,11 @@ def normalize(self, particles):
         return W_t, normalized_weights
 
-    def get_old_tree_particle(self, tree_id, t):
-        old_tree_particle = self.all_particles[tree_id]
-        old_tree_particle.set_particle_to_step(t)
-        return old_tree_particle
-
     def init_particles(self, tree_id):
         """
         Initialize particles
         """
-        p = self.get_old_tree_particle(tree_id, 0)
+        p = self.all_particles[tree_id]
         p.log_weight = self.init_log_weight
         p.old_likelihood_logp = self.init_likelihood
         particles = [p]
@@ -274,6 +259,19 @@ def init_particles(self, tree_id):
         return np.array(particles)
 
+    def update_weight(self, particle):
+        """
+        Update the weight of a particle
+
+        Since the prior is used as the proposal, the weights are updated additively as the
+        ratio of the new and old log-likelihoods.
+        """
+        new_likelihood = self.likelihood_logp(
+            self.sum_trees_output_noi + particle.tree.predict_output()
+        )
+        particle.log_weight += new_likelihood - particle.old_likelihood_logp
+        particle.old_likelihood_logp = new_likelihood
+
 
 class ParticleTree:
     """
@@ -283,8 +281,6 @@ class ParticleTree:
     def __init__(self, tree, log_weight, likelihood):
         self.tree = tree.copy()  # keeps the tree that we care at the moment
         self.expansion_nodes = [0]
-        self.tree_history = [self.tree]
-        self.expansion_nodes_history = [self.expansion_nodes]
         self.log_weight = log_weight
         self.old_likelihood_logp = likelihood
         self.used_variates = []
@@ -327,16 +323,8 @@ def sample_tree_sequential(
             self.expansion_nodes.extend(new_indexes)
             self.used_variates.append(index_selected_predictor)
 
-            self.tree_history.append(self.tree)
-            self.expansion_nodes_history.append(self.expansion_nodes)
         return clp
 
-    def set_particle_to_step(self, t):
-        if len(self.tree_history) <= t:
-            t = -1
-        self.tree = self.tree_history[t]
-        self.expansion_nodes = self.expansion_nodes_history[t]
-
 
 def preprocess_XY(X, Y):
     if isinstance(Y, (Series, DataFrame)):