30
30
from .report import SamplerReport , merge_reports
31
31
from ..util import get_var_name
32
32
33
- logger = logging .getLogger (' pymc3' )
33
+ logger = logging .getLogger (" pymc3" )
34
34
35
35
36
36
class BackendError (Exception ):
@@ -75,10 +75,8 @@ def __init__(self, name, model=None, vars=None, test_point=None):
75
75
test_point_ .update (test_point )
76
76
test_point = test_point_
77
77
var_values = list (zip (self .varnames , self .fn (test_point )))
78
- self .var_shapes = {var : value .shape
79
- for var , value in var_values }
80
- self .var_dtypes = {var : value .dtype
81
- for var , value in var_values }
78
+ self .var_shapes = {var : value .shape for var , value in var_values }
79
+ self .var_dtypes = {var : value .dtype for var , value in var_values }
82
80
self .chain = None
83
81
self ._is_base_setup = False
84
82
self .sampler_vars = None
@@ -104,8 +102,7 @@ def _set_sampler_vars(self, sampler_vars):
104
102
for stats in sampler_vars :
105
103
for key , dtype in stats .items ():
106
104
if dtypes .setdefault (key , dtype ) != dtype :
107
- raise ValueError ("Sampler statistic %s appears with "
108
- "different types." % key )
105
+ raise ValueError ("Sampler statistic %s appears with " "different types." % key )
109
106
110
107
self .sampler_vars = sampler_vars
111
108
@@ -155,7 +152,7 @@ def __getitem__(self, idx):
155
152
try :
156
153
return self .point (int (idx ))
157
154
except (ValueError , TypeError ): # Passed variable or variable name.
158
- raise ValueError (' Can only index with slice or integer' )
155
+ raise ValueError (" Can only index with slice or integer" )
159
156
160
157
def __len__ (self ):
161
158
raise NotImplementedError
@@ -199,13 +196,13 @@ def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
199
196
if sampler_idx is not None :
200
197
return self ._get_sampler_stats (stat_name , sampler_idx , burn , thin )
201
198
202
- sampler_idxs = [i for i , s in enumerate (self .sampler_vars )
203
- if stat_name in s ]
199
+ sampler_idxs = [i for i , s in enumerate (self .sampler_vars ) if stat_name in s ]
204
200
if not sampler_idxs :
205
201
raise KeyError ("Unknown sampler stat %s" % stat_name )
206
202
207
- vals = np .stack ([self ._get_sampler_stats (stat_name , i , burn , thin )
208
- for i in sampler_idxs ], axis = - 1 )
203
+ vals = np .stack (
204
+ [self ._get_sampler_stats (stat_name , i , burn , thin ) for i in sampler_idxs ], axis = - 1
205
+ )
209
206
if vals .shape [- 1 ] == 1 :
210
207
return vals [..., 0 ]
211
208
else :
@@ -296,13 +293,12 @@ def __init__(self, straces):
296
293
297
294
self ._report = SamplerReport ()
298
295
for strace in straces :
299
- if hasattr (strace , ' _warnings' ):
296
+ if hasattr (strace , " _warnings" ):
300
297
self ._report ._add_warnings (strace ._warnings , strace .chain )
301
298
302
299
def __repr__ (self ):
303
- template = '<{}: {} chains, {} iterations, {} variables>'
304
- return template .format (self .__class__ .__name__ ,
305
- self .nchains , len (self ), len (self .varnames ))
300
+ template = "<{}: {} chains, {} iterations, {} variables>"
301
+ return template .format (self .__class__ .__name__ , self .nchains , len (self ), len (self .varnames ))
306
302
307
303
@property
308
304
def nchains (self ):
@@ -339,16 +335,17 @@ def __getitem__(self, idx):
339
335
var = get_var_name (var )
340
336
if var in self .varnames :
341
337
if var in self .stat_names :
342
- warnings .warn ("Attribute access on a trace object is ambigous. "
343
- "Sampler statistic and model variable share a name. Use "
344
- "trace.get_values or trace.get_sampler_stats." )
338
+ warnings .warn (
339
+ "Attribute access on a trace object is ambigous. "
340
+ "Sampler statistic and model variable share a name. Use "
341
+ "trace.get_values or trace.get_sampler_stats."
342
+ )
345
343
return self .get_values (var , burn = burn , thin = thin )
346
344
if var in self .stat_names :
347
345
return self .get_sampler_stats (var , burn = burn , thin = thin )
348
346
raise KeyError ("Unknown variable %s" % var )
349
347
350
- _attrs = {'_straces' , 'varnames' , 'chains' , 'stat_names' ,
351
- 'supports_sampler_stats' , '_report' }
348
+ _attrs = {"_straces" , "varnames" , "chains" , "stat_names" , "supports_sampler_stats" , "_report" }
352
349
353
350
def __getattr__ (self , name ):
354
351
# Avoid infinite recursion when called before __init__
@@ -359,14 +356,15 @@ def __getattr__(self, name):
359
356
name = get_var_name (name )
360
357
if name in self .varnames :
361
358
if name in self .stat_names :
362
- warnings .warn ("Attribute access on a trace object is ambigous. "
363
- "Sampler statistic and model variable share a name. Use "
364
- "trace.get_values or trace.get_sampler_stats." )
359
+ warnings .warn (
360
+ "Attribute access on a trace object is ambigous. "
361
+ "Sampler statistic and model variable share a name. Use "
362
+ "trace.get_values or trace.get_sampler_stats."
363
+ )
365
364
return self .get_values (name )
366
365
if name in self .stat_names :
367
366
return self .get_sampler_stats (name )
368
- raise AttributeError ("'{}' object has no attribute '{}'" .format (
369
- type (self ).__name__ , name ))
367
+ raise AttributeError ("'{}' object has no attribute '{}'" .format (type (self ).__name__ , name ))
370
368
371
369
def __len__ (self ):
372
370
chain = self .chains [- 1 ]
@@ -425,10 +423,12 @@ def add_values(self, vals, overwrite=False) -> None:
425
423
l_samples = len (self ) * len (self .chains )
426
424
l_v = len (v )
427
425
if l_v != l_samples :
428
- warnings .warn ("The length of the values you are trying to "
429
- "add ({}) does not match the number ({}) of "
430
- "total samples in the trace "
431
- "(chains * iterations)" .format (l_v , l_samples ))
426
+ warnings .warn (
427
+ "The length of the values you are trying to "
428
+ "add ({}) does not match the number ({}) of "
429
+ "total samples in the trace "
430
+ "(chains * iterations)" .format (l_v , l_samples )
431
+ )
432
432
433
433
v = np .squeeze (v .reshape (len (chains ), len (self ), - 1 ))
434
434
@@ -457,8 +457,7 @@ def remove_values(self, name):
457
457
chain .vars .remove (va )
458
458
del chain .samples [name ]
459
459
460
- def get_values (self , varname , burn = 0 , thin = 1 , combine = True , chains = None ,
461
- squeeze = True ):
460
+ def get_values (self , varname , burn = 0 , thin = 1 , combine = True , chains = None , squeeze = True ):
462
461
"""Get values from traces.
463
462
464
463
Parameters
@@ -485,14 +484,12 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
485
484
chains = self .chains
486
485
varname = get_var_name (varname )
487
486
try :
488
- results = [self ._straces [chain ].get_values (varname , burn , thin )
489
- for chain in chains ]
487
+ results = [self ._straces [chain ].get_values (varname , burn , thin ) for chain in chains ]
490
488
except TypeError : # Single chain passed.
491
489
results = [self ._straces [chains ].get_values (varname , burn , thin )]
492
490
return _squeeze_cat (results , combine , squeeze )
493
491
494
- def get_sampler_stats (self , stat_name , burn = 0 , thin = 1 , combine = True ,
495
- chains = None , squeeze = True ):
492
+ def get_sampler_stats (self , stat_name , burn = 0 , thin = 1 , combine = True , chains = None , squeeze = True ):
496
493
"""Get sampler statistics from the trace.
497
494
498
495
Parameters
@@ -520,8 +517,9 @@ def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True,
520
517
except TypeError :
521
518
chains = [chains ]
522
519
523
- results = [self ._straces [chain ].get_sampler_stats (stat_name , None , burn , thin )
524
- for chain in chains ]
520
+ results = [
521
+ self ._straces [chain ].get_sampler_stats (stat_name , None , burn , thin ) for chain in chains
522
+ ]
525
523
return _squeeze_cat (results , combine , squeeze )
526
524
527
525
def _slice (self , slice ):
@@ -582,7 +580,9 @@ def merge_traces(mtraces: List[MultiTrace]) -> MultiTrace:
582
580
base_mtrace = mtraces [0 ]
583
581
chain_len = len (base_mtrace )
584
582
# check base trace
585
- if any (len (st ) != chain_len for _ , st in base_mtrace ._straces .items ()): # pylint: disable=line-too-long
583
+ if any (
584
+ len (st ) != chain_len for _ , st in base_mtrace ._straces .items ()
585
+ ): # pylint: disable=line-too-long
586
586
raise ValueError ("Chains are of different lengths." )
587
587
for new_mtrace in mtraces [1 :]:
588
588
for new_chain , strace in new_mtrace ._straces .items ():
0 commit comments