File tree Expand file tree Collapse file tree 2 files changed +14
-5
lines changed Expand file tree Collapse file tree 2 files changed +14
-5
lines changed Original file line number Diff line number Diff line change 34
34
from pymc3 .blocking import DictToArrayBijection
35
35
from pymc3 .model import Point , modelcontext
36
36
from pymc3 .sampling import sample_prior_predictive
37
+ from pymc3 .step_methods .metropolis import MultivariateNormalProposal
37
38
from pymc3 .vartypes import discrete_types
38
39
39
40
@@ -449,10 +450,7 @@ def setup_kernel(self):
449
450
Dimension specific scaling is provided by self.proposal_scales and set in self.tune()
450
451
"""
451
452
ndim = self .tempered_posterior .shape [1 ]
452
- self .proposal_dist = multivariate_normal (
453
- mean = np .zeros (ndim ),
454
- cov = np .eye (ndim ),
455
- )
453
+ self .proposal_dist = MultivariateNormalProposal (np .eye (ndim ))
456
454
self .proposal_scales = np .full (self .draws , min (1 , 2.38 ** 2 / ndim ))
457
455
458
456
def resample (self ):
@@ -477,7 +475,7 @@ def mutate(self):
477
475
for n_step in range (self .n_steps ):
478
476
proposal = floatX (
479
477
self .tempered_posterior
480
- + self .proposal_dist . rvs ( size = self .draws ) * self .proposal_scales [:, None ]
478
+ + self .proposal_dist ( num_draws = self .draws ) * self .proposal_scales [:, None ]
481
479
)
482
480
ll = np .array ([self .likelihood_logp_func (prop ) for prop in proposal ])
483
481
pl = np .array ([self .prior_logp_func (prop ) for prop in proposal ])
Original file line number Diff line number Diff line change @@ -343,3 +343,14 @@ def test_normal_model(self):
343
343
post = idata .posterior .stack (sample = ("chain" , "draw" ))
344
344
assert np .abs (post ["mu" ].mean () - 10 ) < 0.1
345
345
assert np .abs (post ["sigma" ].mean () - 0.5 ) < 0.05
346
+
347
+ def test_proposal_dist_shape (self ):
348
+ with pm .Model () as m :
349
+ x = pm .Normal ("x" , 0 , 1 )
350
+ y = pm .Normal ("y" , x , 1 , observed = 0 )
351
+ trace = pm .sample_smc (
352
+ draws = 10 ,
353
+ chains = 1 ,
354
+ kernel = pm .smc .MH ,
355
+ return_inferencedata = False ,
356
+ )
You can’t perform that action at this time.
0 commit comments