Skip to content

Commit a16afa7

Browse files
aloctavodiaricardoV94
authored andcommitted
SMC: add option to tune n_steps in MH kernel
1 parent 129c1c8 commit a16afa7

File tree

2 files changed

+43
-3
lines changed

2 files changed

+43
-3
lines changed

pymc3/smc/smc.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def __init__(self, *args, n_steps=25, tune_steps=True, p_acc_rate=0.85, **kwargs
352352
self.p_acc_rate = p_acc_rate
353353

354354
self.max_steps = n_steps
355-
self.proposed = self.proposed = self.draws * self.n_steps
355+
self.proposed = self.draws * self.n_steps
356356
self.proposal_dist = None
357357
self.acc_rate = None
358358

@@ -431,16 +431,26 @@ def sample_settings(self):
431431
class MH(SMC_KERNEL):
432432
"""Metropolis-Hastings SMC kernel"""
433433

434-
def __init__(self, *args, n_steps=25, **kwargs):
434+
def __init__(self, *args, n_steps=25, tune_steps=True, p_acc_rate=0.85, **kwargs):
435435
"""
436436
Parameters
437437
----------
438438
n_steps: int
439439
The number of steps of each Markov Chain.
440+
tune_steps: bool
441+
Whether to compute the number of steps automatically or not. Defaults to True
442+
p_acc_rate: float
443+
Used to compute ``n_steps`` when ``tune_steps == True``. The higher the value of
444+
``p_acc_rate`` the higher the number of steps computed automatically. Defaults to 0.85.
445+
It should be between 0 and 1.
440446
"""
441447
super().__init__(*args, **kwargs)
442448
self.n_steps = n_steps
449+
self.tune_steps = tune_steps
450+
self.p_acc_rate = p_acc_rate
443451

452+
self.max_steps = n_steps
453+
self.proposed = self.draws * self.n_steps
444454
self.proposal_dist = None
445455
self.proposal_scales = None
446456
self.chain_acc_rate = None
@@ -460,13 +470,21 @@ def resample(self):
460470
self.chain_acc_rate = self.chain_acc_rate[self.resampling_indexes]
461471

462472
def tune(self):
463-
"""Update proposal scales for each particle dimension"""
473+
"""Update proposal scales for each particle dimension and update number of MH steps"""
464474
if self.iteration > 1:
465475
# Rescale based on distance to 0.234 acceptance rate
466476
chain_scales = np.exp(np.log(self.proposal_scales) + (self.chain_acc_rate - 0.234))
467477
# Interpolate between individual and population scales
468478
self.proposal_scales = 0.5 * chain_scales + 0.5 * chain_scales.mean()
469479

480+
if self.tune_steps:
481+
acc_rate = max(1.0 / self.proposed, self.chain_acc_rate.mean())
482+
self.n_steps = min(
483+
self.max_steps,
484+
max(2, int(np.log(1 - self.p_acc_rate) / np.log(1 - acc_rate))),
485+
)
486+
self.proposed = self.draws * self.n_steps
487+
470488
def mutate(self):
471489
"""Metropolis-Hastings perturbation."""
472490
ac_ = np.empty((self.n_steps, self.draws))
@@ -506,6 +524,8 @@ def sample_settings(self):
506524
stats.update(
507525
{
508526
"_n_tune": self.n_steps, # Default property name used in `SamplerReport`
527+
"tune_steps": self.tune_steps,
528+
"p_acc_rate": self.p_acc_rate,
509529
}
510530
)
511531
return stats

pymc3/tests/test_smc.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,33 @@ def test_kernel_kwargs(self):
149149
tune_steps=False,
150150
p_acc_rate=0.5,
151151
return_inferencedata=False,
152+
kernel=pm.smc.IMH,
152153
)
154+
153155
assert trace.report.threshold == 0.7
154156
assert trace.report.n_draws == 10
155157
assert trace.report.n_tune == 15
156158
assert trace.report.tune_steps is False
157159
assert trace.report.p_acc_rate == 0.5
158160

161+
with self.fast_model:
162+
trace = pm.sample_smc(
163+
draws=10,
164+
chains=1,
165+
threshold=0.95,
166+
n_steps=15,
167+
tune_steps=False,
168+
p_acc_rate=0.5,
169+
return_inferencedata=False,
170+
kernel=pm.smc.MH,
171+
)
172+
173+
assert trace.report.threshold == 0.95
174+
assert trace.report.n_draws == 10
175+
assert trace.report.n_tune == 15
176+
assert trace.report.tune_steps is False
177+
assert trace.report.p_acc_rate == 0.5
178+
159179
@pytest.mark.parametrize("chains", (1, 2))
160180
def test_return_datatype(self, chains):
161181
draws = 10

0 commit comments

Comments
 (0)