@@ -68,38 +68,29 @@ def __init__(self, s):
 class NormalProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
-        return rng.normal(scale=self.s)
+        return (rng or nr).normal(scale=self.s)


 class UniformProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
-        return rng.uniform(low=-self.s, high=self.s, size=len(self.s))
+        return (rng or nr).uniform(low=-self.s, high=self.s, size=len(self.s))


 class CauchyProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
-        return rng.standard_cauchy(size=np.size(self.s)) * self.s
+        return (rng or nr).standard_cauchy(size=np.size(self.s)) * self.s


 class LaplaceProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
         size = np.size(self.s)
-        return (rng.standard_exponential(size=size) - rng.standard_exponential(size=size)) * self.s
+        r = rng or nr
+        return (r.standard_exponential(size=size) - r.standard_exponential(size=size)) * self.s


 class PoissonProposal(Proposal):
     def __call__(self, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
-        return rng.poisson(lam=self.s, size=np.size(self.s)) - self.s
+        return (rng or nr).poisson(lam=self.s, size=np.size(self.s)) - self.s


 class MultivariateNormalProposal(Proposal):
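Note: the `(rng or nr)` shorthand above is safe because `np.random.Generator` instances are always truthy; the expression falls back to the global `numpy.random` module (aliased `nr` in this file) only when `rng` is None. A minimal standalone sketch of the same fallback:

    import numpy as np
    import numpy.random as nr

    def draw(scale, rng=None):
        # Falls back to the legacy global RNG only when rng is None.
        return (rng or nr).normal(scale=scale)

    draw(0.5)                                 # global RNG state
    draw(0.5, rng=np.random.default_rng(42))  # reproducible draw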
@@ -111,13 +102,12 @@ def __init__(self, s):
         self.chol = scipy.linalg.cholesky(s, lower=True)

     def __call__(self, num_draws=None, rng: Optional[np.random.Generator] = None):
-        if rng is None:
-            rng = nr
+        rng_ = rng or nr
         if num_draws is not None:
-            b = rng.normal(size=(self.n, num_draws))
+            b = rng_.normal(size=(self.n, num_draws))
             return np.dot(self.chol, b).T
         else:
-            b = rng.normal(size=self.n)
+            b = rng_.normal(size=self.n)
             return np.dot(self.chol, b)

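Note: `MultivariateNormalProposal` draws correlated samples via the standard Cholesky trick: if `b ~ N(0, I)` and `s = chol @ chol.T`, then `chol @ b ~ N(0, s)`. A self-contained sketch with a hypothetical 2x2 covariance:

    import numpy as np
    import scipy.linalg

    rng = np.random.default_rng(123)
    s = np.array([[2.0, 0.6], [0.6, 1.0]])  # hypothetical covariance
    chol = scipy.linalg.cholesky(s, lower=True)
    b = rng.normal(size=(2, 10_000))         # b ~ N(0, I)
    draws = np.dot(chol, b).T                # shape (10_000, 2)
    print(np.cov(draws.T))                   # approximately s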
@@ -247,7 +237,7 @@ def reset_tuning(self):
     def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:

         point_map_info = q0.point_map_info
-        q0 = q0.data
+        q0d = q0.data

         if not self.steps_until_tune and self.tune:
             # Tune scaling parameter
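Note: the `q0 -> q0d` rename here (and in the other `astep` hunks below) keeps the `RaveledVars` argument and its raw `np.ndarray` payload as two distinct names rather than rebinding `q0` to a different type:

    point_map_info = q0.point_map_info  # q0 remains a RaveledVars
    q0d = q0.data                       # q0d is the raw ndarray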
@@ -261,30 +251,30 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
         if self.any_discrete:
             if self.all_discrete:
                 delta = np.round(delta, 0).astype("int64")
-                q0 = q0.astype("int64")
-                q = (q0 + delta).astype("int64")
+                q0d = q0d.astype("int64")
+                q = (q0d + delta).astype("int64")
             else:
                 delta[self.discrete] = np.round(delta[self.discrete], 0)
-                q = q0 + delta
+                q = q0d + delta
         else:
-            q = floatX(q0 + delta)
+            q = floatX(q0d + delta)

         if self.elemwise_update:
-            q_temp = q0.copy()
+            q_temp = q0d.copy()
             # Shuffle order of updates (probably we don't need to do this in every step)
             np.random.shuffle(self.enum_dims)
             for i in self.enum_dims:
                 q_temp[i] = q[i]
-                accept_rate_i = self.delta_logp(q_temp, q0)
-                q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0)
+                accept_rate_i = self.delta_logp(q_temp, q0d)
+                q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0d)
                 q_temp[i] = q_temp_[i]
                 self.accept_rate_iter[i] = accept_rate_i
                 self.accepted_iter[i] = accepted_i
                 self.accepted_sum[i] += accepted_i
             q = q_temp
         else:
-            accept_rate = self.delta_logp(q, q0)
-            q, accepted = metrop_select(accept_rate, q, q0)
+            accept_rate = self.delta_logp(q, q0d)
+            q, accepted = metrop_select(accept_rate, q, q0d)
             self.accept_rate_iter = accept_rate
             self.accepted_iter = accepted
             self.accepted_sum += accepted
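Note: both branches funnel through `metrop_select`, the standard Metropolis accept/reject rule on the log scale. A minimal sketch of that rule (pymc's actual helper also handles array-valued rates and edge cases):

    import numpy as np

    def metrop_select_sketch(log_accept_rate, q, q0, rng=None):
        # Accept q with probability min(1, exp(log_accept_rate)); else keep q0.
        rng = rng or np.random.default_rng()
        if np.isfinite(log_accept_rate) and np.log(rng.uniform()) < log_accept_rate:
            return q, True
        return q0, False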
@@ -399,11 +389,11 @@ def __init__(self, vars, scaling=1.0, tune=True, tune_interval=100, model=None):

         super().__init__(vars, [model.compile_logp()])

-    def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
+    def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
         logp = args[0]
-        logp_q0 = logp(q0)
-        point_map_info = q0.point_map_info
-        q0 = q0.data
+        logp_q0 = logp(apoint)
+        point_map_info = apoint.point_map_info
+        q0 = apoint.data

         # Convert adaptive_scale_factor to a jump probability
         p_jump = 1.0 - 0.5 ** self.scaling
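Note: `p_jump = 1.0 - 0.5 ** scaling` maps the tuned `scaling` onto (0, 1), so tuning raises or lowers the per-dimension flip probability. Spot-checking a few values:

    for scaling in (0.5, 1.0, 2.0):
        print(scaling, 1.0 - 0.5 ** scaling)
    # 0.5 -> 0.2928..., 1.0 -> 0.5, 2.0 -> 0.75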
@@ -425,9 +415,7 @@ def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
             "p_jump": p_jump,
         }

-        q_new = RaveledVars(q_new, point_map_info)
-
-        return q_new, [stats]
+        return RaveledVars(q_new, point_map_info), [stats]

     @staticmethod
     def competence(var):
@@ -501,13 +489,13 @@ def __init__(self, vars, order="random", transit_p=0.8, model=None):

         super().__init__(vars, [model.compile_logp()])

-    def astep(self, q0: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
+    def astep(self, apoint: RaveledVars, *args) -> Tuple[RaveledVars, StatsType]:
         logp: Callable[[RaveledVars], np.ndarray] = args[0]
         order = self.order
         if self.shuffle_dims:
             nr.shuffle(order)

-        q = RaveledVars(np.copy(q0.data), q0.point_map_info)
+        q = RaveledVars(np.copy(apoint.data), apoint.point_map_info)

         logp_curr = logp(q)

@@ -805,7 +793,7 @@ def __init__(
     def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:

         point_map_info = q0.point_map_info
-        q0 = q0.data
+        q0d = q0.data

         if not self.steps_until_tune and self.tune:
             if self.tune == "scaling":
@@ -824,10 +812,10 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
         r1 = DictToArrayBijection.map(self.population[ir1])
         r2 = DictToArrayBijection.map(self.population[ir2])
         # propose a jump
-        q = floatX(q0 + self.lamb * (r1.data - r2.data) + epsilon)
+        q = floatX(q0d + self.lamb * (r1.data - r2.data) + epsilon)

-        accept = self.delta_logp(q, q0)
-        q_new, accepted = metrop_select(accept, q, q0)
+        accept = self.delta_logp(q, q0d)
+        q_new, accepted = metrop_select(accept, q, q0d)
         self.accepted += accepted

         self.steps_until_tune -= 1
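Note: the proposal above is the differential-evolution move (ter Braak 2006): a scaled difference of two other chains' states plus a small jitter `epsilon`. A standalone sketch with hypothetical stand-in values, using the 2.38/sqrt(2d) scaling this sampler defaults to:

    import numpy as np

    rng = np.random.default_rng(0)
    d = 3                                 # hypothetical dimensionality
    lamb = 2.38 / np.sqrt(2 * d)
    q0d = np.zeros(d)                     # current chain's state
    r1, r2 = rng.normal(size=(2, d))      # stand-ins for two other chains
    epsilon = rng.normal(scale=1e-3, size=d)
    q = q0d + lamb * (r1 - r2) + epsilon  # proposed point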
@@ -840,9 +828,7 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
             "accepted": accepted,
         }

-        q_new = RaveledVars(q_new, point_map_info)
-
-        return q_new, [stats]
+        return RaveledVars(q_new, point_map_info), [stats]

     @staticmethod
     def competence(var, has_grad):
@@ -948,7 +934,7 @@ def __init__(
         self.accepted = 0

         # cache local history for the Z-proposals
-        self._history = []
+        self._history: List[np.ndarray] = []
         # remember initial settings before tuning so they can be reset
         self._untuned_settings = dict(
             scaling=self.scaling,
@@ -974,7 +960,7 @@ def reset_tuning(self):
     def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:

         point_map_info = q0.point_map_info
-        q0 = q0.data
+        q0d = q0.data

         # same tuning scheme as DEMetropolis
         if not self.steps_until_tune and self.tune:
@@ -1001,13 +987,13 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
             z1 = self._history[iz1]
             z2 = self._history[iz2]
             # propose a jump
-            q = floatX(q0 + self.lamb * (z1 - z2) + epsilon)
+            q = floatX(q0d + self.lamb * (z1 - z2) + epsilon)
         else:
             # propose just with noise in the first 2 iterations
-            q = floatX(q0 + epsilon)
+            q = floatX(q0d + epsilon)

-        accept = self.delta_logp(q, q0)
-        q_new, accepted = metrop_select(accept, q, q0)
+        accept = self.delta_logp(q, q0d)
+        q_new, accepted = metrop_select(accept, q, q0d)
         self.accepted += accepted
         self._history.append(q_new)

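Note: DEMetropolisZ takes the difference pair from the chain's own past (`self._history`) instead of from parallel chains, which is why `astep` appends `q_new` above. A sketch of the indexing with a hypothetical stand-in history and lambda:

    import numpy as np

    rng = np.random.default_rng(1)
    history = [rng.normal(size=3) for _ in range(10)]  # stand-in past states
    iz1, iz2 = rng.choice(len(history), 2, replace=False)
    z1, z2 = history[iz1], history[iz2]
    q = history[-1] + 0.59 * (z1 - z2)                 # lamb * (z1 - z2) jump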
@@ -1021,9 +1007,7 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
             "accepted": accepted,
         }

-        q_new = RaveledVars(q_new, point_map_info)
-
-        return q_new, [stats]
+        return RaveledVars(q_new, point_map_info), [stats]

     def stop_tuning(self):
         """At the end of the tuning phase, this method removes the first x% of the history