 _log = logging.getLogger("pymc")
 
 
+class ParticleTree:
+    """
+    Particle tree
+    """
+
+    def __init__(self, tree, log_weight, likelihood):
+        self.tree = tree.copy()  # keeps the tree we care about at the moment
+        self.expansion_nodes = [0]
+        self.log_weight = log_weight
+        self.old_likelihood_logp = likelihood
+        self.used_variates = []
+
+    def sample_tree_sequential(
+        self,
+        ssv,
+        available_predictors,
+        prior_prob_leaf_node,
+        X,
+        missing_data,
+        sum_trees_output,
+        mean,
+        m,
+        normal,
+        mu_std,
+    ):
+        tree_grew = False
+        if self.expansion_nodes:
+            index_leaf_node = self.expansion_nodes.pop(0)
+            # Probability that this node will remain a leaf node
+            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
+
+            if prob_leaf < np.random.random():
+                tree_grew, index_selected_predictor = grow_tree(
+                    self.tree,
+                    index_leaf_node,
+                    ssv,
+                    available_predictors,
+                    X,
+                    missing_data,
+                    sum_trees_output,
+                    mean,
+                    m,
+                    normal,
+                    mu_std,
+                )
+                if tree_grew:
+                    new_indexes = self.tree.idx_leaf_nodes[-2:]
+                    self.expansion_nodes.extend(new_indexes)
+                    self.used_variates.append(index_selected_predictor)
+
+        return tree_grew
+
+
 class PGBART(ArrayStepShared):
     """
     Particle Gibbs BART sampling step
@@ -152,7 +205,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
         for t in range(self.max_stages):
             # Sample each particle (try to grow each tree), except for the first one.
             for p in particles[1:]:
-                clp = p.sample_tree_sequential(
+                tree_grew = p.sample_tree_sequential(
                     self.ssv,
                     self.available_predictors,
                     self.prior_prob_leaf_node,
@@ -164,7 +217,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
                     self.normal,
                     self.mu_std,
                 )
-                if clp:  # update weights only if p has changed from the previous iteration
+                if tree_grew:
                     self.update_weight(p)
             # Normalize weights
             W_t, normalized_weights = self.normalize(particles)
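
Note: the resampling step that follows this normalization is elided between the hunks. As a rough sketch of the usual Particle Gibbs pattern (keep the reference particle, redraw the rest in proportion to their normalized weights), with the helper name and copying strategy being illustrative assumptions rather than pymc's actual code:

```python
import numpy as np
from copy import deepcopy

def resample_sketch(particles, normalized_weights, rng=np.random.default_rng()):
    # Particle Gibbs keeps the reference particle (index 0) fixed and
    # redraws the remaining slots in proportion to the normalized weights.
    idx = rng.choice(len(particles), size=len(particles) - 1, p=normalized_weights)
    return [particles[0]] + [deepcopy(particles[i]) for i in idx]
```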
@@ -222,7 +275,7 @@ def competence(var, has_grad):
             return Competence.IDEAL
         return Competence.INCOMPATIBLE
 
-    def normalize(self, particles):
+    def normalize(self, particles: List[ParticleTree]) -> Tuple[float, np.ndarray]:
         """
         Use logsumexp trick to get W_t and softmax to get normalized_weights
         """
@@ -238,7 +291,7 @@ def normalize(self, particles):
 
         return W_t, normalized_weights
 
-    def init_particles(self, tree_id):
+    def init_particles(self, tree_id: int) -> np.ndarray:
         """
         Initialize particles
         """
@@ -259,12 +312,12 @@ def init_particles(self, tree_id):
 
         return np.array(particles)
 
-    def update_weight(self, particle):
+    def update_weight(self, particle: ParticleTree) -> None:
         """
         Update the weight of a particle
 
         Since the prior is used as the proposal, the weights are updated additively as the ratio of
-        the new and old log_likelihoods.
+        the new and old log-likelihoods.
         """
         new_likelihood = self.likelihood_logp(
             self.sum_trees_output_noi + particle.tree.predict_output()
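
The update itself falls outside this hunk. Taking the docstring literally, the additive update in log space should amount to a one-liner along these lines (an assumption, since the elided code is not shown):

```python
# The ratio of new to old likelihood becomes a difference of log-likelihoods.
particle.log_weight += new_likelihood - particle.old_likelihood_logp
```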
@@ -273,59 +326,6 @@ def update_weight(self, particle):
         particle.old_likelihood_logp = new_likelihood
 
 
-class ParticleTree:
-    """
-    Particle tree
-    """
-
-    def __init__(self, tree, log_weight, likelihood):
-        self.tree = tree.copy()  # keeps the tree that we care at the moment
-        self.expansion_nodes = [0]
-        self.log_weight = log_weight
-        self.old_likelihood_logp = likelihood
-        self.used_variates = []
-
-    def sample_tree_sequential(
-        self,
-        ssv,
-        available_predictors,
-        prior_prob_leaf_node,
-        X,
-        missing_data,
-        sum_trees_output,
-        mean,
-        m,
-        normal,
-        mu_std,
-    ):
-        clp = False
-        if self.expansion_nodes:
-            index_leaf_node = self.expansion_nodes.pop(0)
-            # Probability that this node will remain a leaf node
-            prob_leaf = prior_prob_leaf_node[self.tree[index_leaf_node].depth]
-
-            if prob_leaf < np.random.random():
-                clp, index_selected_predictor = grow_tree(
-                    self.tree,
-                    index_leaf_node,
-                    ssv,
-                    available_predictors,
-                    X,
-                    missing_data,
-                    sum_trees_output,
-                    mean,
-                    m,
-                    normal,
-                    mu_std,
-                )
-                if clp:
-                    new_indexes = self.tree.idx_leaf_nodes[-2:]
-                    self.expansion_nodes.extend(new_indexes)
-                    self.used_variates.append(index_selected_predictor)
-
-        return clp
-
-
 def preprocess_XY(X, Y):
     if isinstance(Y, (Series, DataFrame)):
         Y = Y.to_numpy()