Skip to content

Commit 38252f5

Browse files
committed
descriptive variable name, add type hints
1 parent c3dd43b commit 38252f5

File tree

1 file changed

+59
-59
lines changed

1 file changed

+59
-59
lines changed

pymc/step_methods/pgbart.py

Lines changed: 59 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,59 @@
3232
_log = logging.getLogger("pymc")
3333

3434

35+
class ParticleTree:
    """
    A single SMC particle: a decision tree plus its log-weight and the
    bookkeeping needed to grow it one leaf at a time.
    """

    def __init__(self, tree, log_weight, likelihood):
        # Copy so the shared initial tree is never mutated by this particle.
        self.tree = tree.copy()
        # Leaf indexes still waiting to be expanded; the root (0) starts pending.
        self.expansion_nodes = [0]
        self.log_weight = log_weight
        self.old_likelihood_logp = likelihood
        # Predictor indexes actually used by splits in this tree.
        self.used_variates = []

    def sample_tree_sequential(
        self,
        ssv,
        available_predictors,
        prior_prob_leaf_node,
        X,
        missing_data,
        sum_trees_output,
        mean,
        m,
        normal,
        mu_std,
    ):
        """Attempt to grow the tree by expanding one pending leaf.

        Pops the oldest pending leaf, keeps it as a leaf with the
        depth-dependent prior probability, otherwise delegates to
        ``grow_tree``. Returns True if the tree grew, False otherwise.
        """
        if not self.expansion_nodes:
            # Nothing left to expand: the tree is finished.
            return False

        leaf_index = self.expansion_nodes.pop(0)
        # Probability that this node will remain a leaf node.
        stay_leaf_prob = prior_prob_leaf_node[self.tree[leaf_index].depth]

        # Single random draw, exactly as in the original: grow only when
        # the draw exceeds the stay-a-leaf probability.
        if not (stay_leaf_prob < np.random.random()):
            return False

        grew, selected_predictor = grow_tree(
            self.tree,
            leaf_index,
            ssv,
            available_predictors,
            X,
            missing_data,
            sum_trees_output,
            mean,
            m,
            normal,
            mu_std,
        )
        if grew:
            # The split replaced one leaf with two new ones; queue them.
            self.expansion_nodes.extend(self.tree.idx_leaf_nodes[-2:])
            self.used_variates.append(selected_predictor)

        return grew
86+
87+
3588
class PGBART(ArrayStepShared):
3689
"""
3790
Particle Gibbs BART sampling step
@@ -152,7 +205,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
152205
for t in range(self.max_stages):
153206
# Sample each particle (try to grow each tree), except for the first one.
154207
for p in particles[1:]:
155-
clp = p.sample_tree_sequential(
208+
tree_grew = p.sample_tree_sequential(
156209
self.ssv,
157210
self.available_predictors,
158211
self.prior_prob_leaf_node,
@@ -164,7 +217,7 @@ def astep(self, q: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
164217
self.normal,
165218
self.mu_std,
166219
)
167-
if clp: # update weights only if p has changed from the previous iteration
220+
if tree_grew:
168221
self.update_weight(p)
169222
# Normalize weights
170223
W_t, normalized_weights = self.normalize(particles)
@@ -222,7 +275,7 @@ def competence(var, has_grad):
222275
return Competence.IDEAL
223276
return Competence.INCOMPATIBLE
224277

225-
def normalize(self, particles):
278+
def normalize(self, particles: List[ParticleTree]) -> Tuple[float, np.ndarray]:
226279
"""
227280
Use logsumexp trick to get W_t and softmax to get normalized_weights
228281
"""
@@ -238,7 +291,7 @@ def normalize(self, particles):
238291

239292
return W_t, normalized_weights
240293

241-
def init_particles(self, tree_id):
294+
def init_particles(self, tree_id: int) -> np.ndarray:
242295
"""
243296
Initialize particles
244297
"""
@@ -259,12 +312,12 @@ def init_particles(self, tree_id):
259312

260313
return np.array(particles)
261314

262-
def update_weight(self, particle):
315+
def update_weight(self, particle: ParticleTree) -> None:
263316
"""
264317
Update the weight of a particle
265318
266319
Since the prior is used as the proposal, the weights are updated additively as the ratio of
267-
the new and old log_likelihoods.
320+
the new and old log-likelihoods.
268321
"""
269322
new_likelihood = self.likelihood_logp(
270323
self.sum_trees_output_noi + particle.tree.predict_output()
@@ -273,59 +326,6 @@ def update_weight(self, particle):
273326
particle.old_likelihood_logp = new_likelihood
274327

275328

276-
class ParticleTree:
    """
    Particle tree: one SMC particle holding a tree, its log-weight, and the
    state needed to grow the tree leaf by leaf.
    """

    def __init__(self, tree, log_weight, likelihood):
        self.tree = tree.copy()  # keeps the tree that we care at the moment
        self.expansion_nodes = [0]  # the root is the first candidate for expansion
        self.log_weight = log_weight
        self.old_likelihood_logp = likelihood
        self.used_variates = []

    def sample_tree_sequential(
        self,
        ssv,
        available_predictors,
        prior_prob_leaf_node,
        X,
        missing_data,
        sum_trees_output,
        mean,
        m,
        normal,
        mu_std,
    ):
        """Try to expand one pending leaf of the tree.

        Returns True when ``grow_tree`` actually split the leaf, otherwise
        False (no pending leaves, or the leaf was kept by the prior).
        """
        did_grow = False
        if self.expansion_nodes:
            idx = self.expansion_nodes.pop(0)
            # Probability (by depth) that this node remains a leaf.
            leaf_prob = prior_prob_leaf_node[self.tree[idx].depth]

            if leaf_prob < np.random.random():
                did_grow, chosen_predictor = grow_tree(
                    self.tree,
                    idx,
                    ssv,
                    available_predictors,
                    X,
                    missing_data,
                    sum_trees_output,
                    mean,
                    m,
                    normal,
                    mu_std,
                )
                if did_grow:
                    # Queue the two leaves created by the split.
                    self.expansion_nodes.extend(self.tree.idx_leaf_nodes[-2:])
                    self.used_variates.append(chosen_predictor)

        return did_grow
327-
328-
329329
def preprocess_XY(X, Y):
330330
if isinstance(Y, (Series, DataFrame)):
331331
Y = Y.to_numpy()

0 commit comments

Comments
 (0)