Commit 2c3f544

Fix remaining type problems in metropolis.py

1 parent 0121315 commit 2c3f544

4 files changed (+42, -57 lines)

pymc/aesaraf.py

Lines changed: 2 additions & 2 deletions

@@ -35,7 +35,7 @@
 from aeppl.logprob import CheckParameterValue
 from aeppl.transforms import RVTransform
 from aesara import scalar
-from aesara.compile.mode import Mode, get_mode
+from aesara.compile import Function, Mode, get_mode
 from aesara.gradient import grad
 from aesara.graph import node_rewriter, rewrite_graph
 from aesara.graph.basic import (
@@ -1044,7 +1044,7 @@ def compile_pymc(
     random_seed: SeedSequenceSeed = None,
     mode=None,
     **kwargs,
-) -> Callable[..., Union[np.ndarray, List[np.ndarray]]]:
+) -> Function:
     """Use ``aesara.function`` with specialized pymc rewrites always enabled.
 
     This function also ensures shared RandomState/Generator used by RandomVariables
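Worth noting: the narrower return annotation exposes the attributes of a compiled aesara `Function` to the type checker, where the old `Callable` annotation hid them. A minimal usage sketch, assuming only the public `compile_pymc` signature shown above (the scalar graph is illustrative, not part of this commit):

    import aesara.tensor as at
    from pymc.aesaraf import compile_pymc

    x = at.scalar("x")
    # The result is now typed as aesara.compile.Function rather than a bare Callable.
    fn = compile_pymc([x], x + 1.0)
    print(fn(41.0))  # -> array(42.)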

pymc/step_methods/arraystep.py

Lines changed: 1 addition & 1 deletion

@@ -273,7 +273,7 @@ def step(self, point) -> Tuple[PointType, StatsType]:
         return super().step(point)
 
 
-def metrop_select(mr, q, q0):
+def metrop_select(mr: np.ndarray, q: np.ndarray, q0: np.ndarray) -> Tuple[np.ndarray, bool]:
     """Perform rejection/acceptance step for Metropolis class samplers.
 
     Returns the new sample q if a uniform random number is less than the
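The body of `metrop_select` is unchanged by this commit; only the annotations are new. For context, a hedged sketch of the accept/reject rule its docstring describes (not copied from the source):

    import numpy as np

    def metrop_select_sketch(mr, q, q0):
        # Accept the proposal q when log(U) < mr for U ~ Uniform(0, 1), i.e. when
        # a uniform draw falls below the exponentiated log acceptance rate mr.
        if np.isfinite(mr) and np.log(np.random.uniform()) < mr:
            return q, True
        return q0, False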

pymc/step_methods/metropolis.py

Lines changed: 38 additions & 54 deletions

@@ -68,38 +68,29 @@ def __init__(self, s):
 
 class NormalProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
-        return rng.normal(scale=self.s)
+        return (rng or nr).normal(scale=self.s)
 
 
 class UniformProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
-        return rng.uniform(low=-self.s, high=self.s, size=len(self.s))
+        return (rng or nr).uniform(low=-self.s, high=self.s, size=len(self.s))
 
 
 class CauchyProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
-        return rng.standard_cauchy(size=np.size(self.s)) * self.s
+        return (rng or nr).standard_cauchy(size=np.size(self.s)) * self.s
 
 
 class LaplaceProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
         size = np.size(self.s)
-        return (rng.standard_exponential(size=size) - rng.standard_exponential(size=size)) * self.s
+        r = rng or nr
+        return (r.standard_exponential(size=size) - r.standard_exponential(size=size)) * self.s
 
 
 class PoissonProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
-        return rng.poisson(lam=self.s, size=np.size(self.s)) - self.s
+        return (rng or nr).poisson(lam=self.s, size=np.size(self.s)) - self.s
 
 
 class MultivariateNormalProposal(Proposal):
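The `(rng or nr)` idiom works because the `numpy.random` module (imported as `nr`) exposes the same sampling methods as a `Generator` instance, so the module serves as the fallback when no generator is passed. A self-contained illustration of the pattern (the `draw` helper is hypothetical):

    import numpy as np
    import numpy.random as nr

    def draw(scale, rng=None):
        # Fall back to numpy's module-level global RNG when rng is None.
        return (rng or nr).normal(scale=scale)

    draw(2.0)                                 # module-level RNG
    draw(2.0, rng=np.random.default_rng(42))  # explicit Generator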
@@ -111,13 +102,12 @@ def __init__(self, s):
         self.chol = scipy.linalg.cholesky(s, lower=True)
 
     def __call__(self, num_draws=None, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
+        rng_ = rng or nr
         if num_draws is not None:
-            b = rng.normal(size=(self.n, num_draws))
+            b = rng_.normal(size=(self.n, num_draws))
             return np.dot(self.chol, b).T
         else:
-            b = rng.normal(size=self.n)
+            b = rng_.normal(size=self.n)
             return np.dot(self.chol, b)
 
 
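The Cholesky factor is what makes this proposal correct: if `S = L @ L.T` and `z ~ N(0, I)`, then `L @ z ~ N(0, S)`. A quick numerical check under that assumption (the covariance matrix is arbitrary):

    import numpy as np
    import scipy.linalg

    S = np.array([[2.0, 0.5], [0.5, 1.0]])  # proposal covariance
    L = scipy.linalg.cholesky(S, lower=True)
    z = np.random.default_rng(0).normal(size=(2, 100_000))
    draws = (L @ z).T                        # shape (num_draws, n), as in __call__
    print(np.cov(draws, rowvar=False))       # approximately S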
@@ -247,7 +237,7 @@ def reset_tuning(self):
     def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
 
         point_map_info = q0.point_map_info
-        q0 = q0.data
+        q0d = q0.data
 
         if not self.steps_until_tune and self.tune:
             # Tune scaling parameter
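This rename is the recurring fix in the commit: rebinding a parameter annotated as `RaveledVars` to its `ndarray` payload makes mypy flag an incompatible assignment, while binding the payload to a fresh name type-checks cleanly. Schematically (mypy message paraphrased):

    from pymc.blocking import RaveledVars

    def astep_before(q0: RaveledVars):
        q0 = q0.data   # mypy: incompatible assignment (ndarray vs RaveledVars)

    def astep_after(q0: RaveledVars):
        q0d = q0.data  # fine: a fresh name may carry the narrower ndarray type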
@@ -261,30 +251,30 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
         if self.any_discrete:
             if self.all_discrete:
                 delta = np.round(delta, 0).astype("int64")
-                q0 = q0.astype("int64")
-                q = (q0 + delta).astype("int64")
+                q0d = q0d.astype("int64")
+                q = (q0d + delta).astype("int64")
             else:
                 delta[self.discrete] = np.round(delta[self.discrete], 0)
-                q = q0 + delta
+                q = q0d + delta
         else:
-            q = floatX(q0 + delta)
+            q = floatX(q0d + delta)
 
         if self.elemwise_update:
-            q_temp = q0.copy()
+            q_temp = q0d.copy()
             # Shuffle order of updates (probably we don't need to do this in every step)
             np.random.shuffle(self.enum_dims)
             for i in self.enum_dims:
                 q_temp[i] = q[i]
-                accept_rate_i = self.delta_logp(q_temp, q0)
-                q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0)
+                accept_rate_i = self.delta_logp(q_temp, q0d)
+                q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0d)
                 q_temp[i] = q_temp_[i]
                 self.accept_rate_iter[i] = accept_rate_i
                 self.accepted_iter[i] = accepted_i
                 self.accepted_sum[i] += accepted_i
             q = q_temp
         else:
-            accept_rate = self.delta_logp(q, q0)
-            q, accepted = metrop_select(accept_rate, q, q0)
+            accept_rate = self.delta_logp(q, q0d)
+            q, accepted = metrop_select(accept_rate, q, q0d)
             self.accept_rate_iter = accept_rate
             self.accepted_iter = accepted
             self.accepted_sum += accepted
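Read as plain code, the element-wise branch proposes a full vector but accepts or rejects one coordinate at a time, reverting rejected coordinates before moving on. A condensed, hedged sketch of that loop, with `delta_logp` and `metrop_select` as stand-ins for the real callables:

    import numpy as np

    def elemwise_update_sketch(q0d, q, dims, delta_logp, metrop_select):
        q_temp = q0d.copy()
        np.random.shuffle(dims)
        for i in dims:
            q_temp[i] = q[i]                  # try updating coordinate i only
            rate_i = delta_logp(q_temp, q0d)  # log acceptance ratio
            q_sel, accepted_i = metrop_select(rate_i, q_temp, q0d)
            q_temp[i] = q_sel[i]              # keep or revert coordinate i
        return q_temp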
@@ -399,11 +389,11 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):
 
         super().__init__(vars, [model.compile_logp()])
 
-    def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
+    def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
         logp = args[0]
-        logp_q0 = logp(q0)
-        point_map_info = q0.point_map_info
-        q0 = q0.data
+        logp_q0 = logp(apoint)
+        point_map_info = apoint.point_map_info
+        q0 = apoint.data
 
         # Convert adaptive_scale_factor to a jump probability
         p_jump = 1.0 - 0.5**self.scaling
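The jump probability is a fixed transformation of the tuned scale factor: larger scaling flips bits more aggressively, while p_jump stays strictly below 1. For instance:

    for s in (0.5, 1.0, 2.0, 4.0):
        print(s, 1.0 - 0.5**s)  # 0.2929..., 0.5, 0.75, 0.9375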
@@ -425,9 +415,7 @@ def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
             "p_jump": p_jump,
         }
 
-        q_new = RaveledVars(q_new, point_map_info)
-
-        return q_new, [stats]
+        return RaveledVars(q_new, point_map_info), [stats]
 
     @staticmethod
     def competence(var):
@@ -501,13 +489,13 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):
 
         super().__init__(vars, [model.compile_logp()])
 
-    def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
+    def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
         logp: Callable[[RaveledVars], np.ndarray] = args[0]
         order = self.order
         if self.shuffle_dims:
             nr.shuffle(order)
 
-        q = RaveledVars(np.copy(q0.data), q0.point_map_info)
+        q = RaveledVars(np.copy(apoint.data), apoint.point_map_info)
 
         logp_curr = logp(q)
 
@@ -805,7 +793,7 @@ def __init__(
     def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
 
         point_map_info = q0.point_map_info
-        q0 = q0.data
+        q0d = q0.data
 
         if not self.steps_until_tune and self.tune:
             if self.tune == "scaling":
@@ -824,10 +812,10 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
         r1 = DictToArrayBijection.map(self.population[ir1])
         r2 = DictToArrayBijection.map(self.population[ir2])
         # propose a jump
-        q = floatX(q0 + self.lamb * (r1.data - r2.data) + epsilon)
+        q = floatX(q0d + self.lamb * (r1.data - r2.data) + epsilon)
 
-        accept = self.delta_logp(q, q0)
-        q_new, accepted = metrop_select(accept, q, q0)
+        accept = self.delta_logp(q, q0d)
+        q_new, accepted = metrop_select(accept, q, q0d)
         self.accepted += accepted
 
         self.steps_until_tune -= 1
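The proposal is the standard differential-evolution move: jump along the difference of two other chains' current states, scaled by `lamb`, plus a little noise. A hedged stand-alone sketch (parameter names mirror the diff; nothing here is copied from the source):

    import numpy as np

    rng = np.random.default_rng(1)

    def de_jump(q0d, population, lamb, eps_scale):
        # Pick two distinct chains and move along their difference vector.
        i1, i2 = rng.choice(len(population), size=2, replace=False)
        epsilon = rng.normal(scale=eps_scale, size=q0d.shape)
        return q0d + lamb * (population[i1] - population[i2]) + epsilon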
@@ -840,9 +828,7 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
             "accepted": accepted,
         }
 
-        q_new = RaveledVars(q_new, point_map_info)
-
-        return q_new, [stats]
+        return RaveledVars(q_new, point_map_info), [stats]
 
     @staticmethod
     def competence(var, has_grad):
@@ -948,7 +934,7 @@ def __init__(
         self.accepted = 0
 
         # cache local history for the Z-proposals
-        self._history = []
+        self._history: List[np.ndarray] = []
         # remember initial settings before tuning so they can be reset
         self._untuned_settings = dict(
             scaling=self.scaling,
@@ -974,7 +960,7 @@ def reset_tuning(self):
     def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
 
         point_map_info = q0.point_map_info
-        q0 = q0.data
+        q0d = q0.data
 
         # same tuning scheme as DEMetropolis
         if not self.steps_until_tune and self.tune:
@@ -1001,13 +987,13 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
             z1 = self._history[iz1]
             z2 = self._history[iz2]
             # propose a jump
-            q = floatX(q0 + self.lamb * (z1 - z2) + epsilon)
+            q = floatX(q0d + self.lamb * (z1 - z2) + epsilon)
         else:
             # propose just with noise in the first 2 iterations
-            q = floatX(q0 + epsilon)
+            q = floatX(q0d + epsilon)
 
-        accept = self.delta_logp(q, q0)
-        q_new, accepted = metrop_select(accept, q, q0)
+        accept = self.delta_logp(q, q0d)
+        q_new, accepted = metrop_select(accept, q, q0d)
         self.accepted += accepted
         self._history.append(q_new)
 
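DEMetropolis(Z) applies the same move but draws the difference vector from the chain's own past states (`self._history`), falling back to pure noise while the history is still too short. A hedged sketch under the same assumptions as above:

    import numpy as np

    rng = np.random.default_rng(2)

    def dez_jump(q0d, history, lamb, eps_scale):
        epsilon = rng.normal(scale=eps_scale, size=q0d.shape)
        if len(history) > 1:
            iz1, iz2 = rng.choice(len(history), size=2, replace=False)
            return q0d + lamb * (history[iz1] - history[iz2]) + epsilon
        return q0d + epsilon  # noise-only proposal for the first iterations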
@@ -1021,9 +1007,7 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
             "accepted": accepted,
         }
 
-        q_new = RaveledVars(q_new, point_map_info)
-
-        return q_new, [stats]
+        return RaveledVars(q_new, point_map_info), [stats]
 
     def stop_tuning(self):
         """At the end of the tuning phase, this method removes the first x% of the history

scripts/run_mypy.py

Lines changed: 1 addition & 0 deletions

@@ -64,6 +64,7 @@
 pymc/step_methods/__init__.py
 pymc/step_methods/arraystep.py
 pymc/step_methods/compound.py
+pymc/step_methods/metropolis.py
 pymc/step_methods/hmc/__init__.py
 pymc/step_methods/hmc/base_hmc.py
 pymc/step_methods/hmc/hmc.py
