Skip to content

Commit 129c1c8

Browse files
committed
SMC: Fix proposal_dist shape in MH kernel
1 parent 9e15b20 commit 129c1c8

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

pymc3/smc/smc.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from pymc3.blocking import DictToArrayBijection
3535
from pymc3.model import Point, modelcontext
3636
from pymc3.sampling import sample_prior_predictive
37+
from pymc3.step_methods.metropolis import MultivariateNormalProposal
3738
from pymc3.vartypes import discrete_types
3839

3940

@@ -449,10 +450,7 @@ def setup_kernel(self):
449450
Dimension specific scaling is provided by self.proposal_scales and set in self.tune()
450451
"""
451452
ndim = self.tempered_posterior.shape[1]
452-
self.proposal_dist = multivariate_normal(
453-
mean=np.zeros(ndim),
454-
cov=np.eye(ndim),
455-
)
453+
self.proposal_dist = MultivariateNormalProposal(np.eye(ndim))
456454
self.proposal_scales = np.full(self.draws, min(1, 2.38 ** 2 / ndim))
457455

458456
def resample(self):
@@ -477,7 +475,7 @@ def mutate(self):
477475
for n_step in range(self.n_steps):
478476
proposal = floatX(
479477
self.tempered_posterior
480-
+ self.proposal_dist.rvs(size=self.draws) * self.proposal_scales[:, None]
478+
+ self.proposal_dist(num_draws=self.draws) * self.proposal_scales[:, None]
481479
)
482480
ll = np.array([self.likelihood_logp_func(prop) for prop in proposal])
483481
pl = np.array([self.prior_logp_func(prop) for prop in proposal])

pymc3/tests/test_smc.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,3 +343,14 @@ def test_normal_model(self):
343343
post = idata.posterior.stack(sample=("chain", "draw"))
344344
assert np.abs(post["mu"].mean() - 10) < 0.1
345345
assert np.abs(post["sigma"].mean() - 0.5) < 0.05
346+
347+
def test_proposal_dist_shape(self):
348+
with pm.Model() as m:
349+
x = pm.Normal("x", 0, 1)
350+
y = pm.Normal("y", x, 1, observed=0)
351+
trace = pm.sample_smc(
352+
draws=10,
353+
chains=1,
354+
kernel=pm.smc.MH,
355+
return_inferencedata=False,
356+
)

0 commit comments

Comments
 (0)