@@ -352,7 +352,7 @@ def __init__(self, *args, n_steps=25, tune_steps=True, p_acc_rate=0.85, **kwargs
352
352
self .p_acc_rate = p_acc_rate
353
353
354
354
self .max_steps = n_steps
355
- self .proposed = self .proposed = self . draws * self .n_steps
355
+ self .proposed = self .draws * self .n_steps
356
356
self .proposal_dist = None
357
357
self .acc_rate = None
358
358
@@ -431,16 +431,26 @@ def sample_settings(self):
431
431
class MH (SMC_KERNEL ):
432
432
"""Metropolis-Hastings SMC kernel"""
433
433
434
- def __init__ (self , * args , n_steps = 25 , ** kwargs ):
434
+ def __init__ (self , * args , n_steps = 25 , tune_steps = True , p_acc_rate = 0.85 , ** kwargs ):
435
435
"""
436
436
Parameters
437
437
----------
438
438
n_steps: int
439
439
The number of steps of each Markov Chain.
440
+ tune_steps: bool
441
+ Whether to compute the number of steps automatically or not. Defaults to True
442
+ p_acc_rate: float
443
+ Used to compute ``n_steps`` when ``tune_steps == True``. The higher the value of
444
+ ``p_acc_rate`` the higher the number of steps computed automatically. Defaults to 0.85.
445
+ It should be between 0 and 1.
440
446
"""
441
447
super ().__init__ (* args , ** kwargs )
442
448
self .n_steps = n_steps
449
+ self .tune_steps = tune_steps
450
+ self .p_acc_rate = p_acc_rate
443
451
452
+ self .max_steps = n_steps
453
+ self .proposed = self .draws * self .n_steps
444
454
self .proposal_dist = None
445
455
self .proposal_scales = None
446
456
self .chain_acc_rate = None
@@ -460,13 +470,21 @@ def resample(self):
460
470
self .chain_acc_rate = self .chain_acc_rate [self .resampling_indexes ]
461
471
462
472
def tune (self ):
463
- """Update proposal scales for each particle dimension"""
473
+ """Update proposal scales for each particle dimension and update number of MH steps """
464
474
if self .iteration > 1 :
465
475
# Rescale based on distance to 0.234 acceptance rate
466
476
chain_scales = np .exp (np .log (self .proposal_scales ) + (self .chain_acc_rate - 0.234 ))
467
477
# Interpolate between individual and population scales
468
478
self .proposal_scales = 0.5 * chain_scales + 0.5 * chain_scales .mean ()
469
479
480
+ if self .tune_steps :
481
+ acc_rate = max (1.0 / self .proposed , self .chain_acc_rate .mean ())
482
+ self .n_steps = min (
483
+ self .max_steps ,
484
+ max (2 , int (np .log (1 - self .p_acc_rate ) / np .log (1 - acc_rate ))),
485
+ )
486
+ self .proposed = self .draws * self .n_steps
487
+
470
488
def mutate (self ):
471
489
"""Metropolis-Hastings perturbation."""
472
490
ac_ = np .empty ((self .n_steps , self .draws ))
@@ -506,6 +524,8 @@ def sample_settings(self):
506
524
stats .update (
507
525
{
508
526
"_n_tune" : self .n_steps , # Default property name used in `SamplerReport`
527
+ "tune_steps" : self .tune_steps ,
528
+ "p_acc_rate" : self .p_acc_rate ,
509
529
}
510
530
)
511
531
return stats
0 commit comments