21
21
import warnings
22
22
23
23
from abc import ABC
24
- from typing import List , Sequence , Tuple , cast
24
+ from typing import Dict , List , Optional , Sequence , Set , Tuple , Union , cast
25
25
26
26
import numpy as np
27
- import pytensor .tensor as at
28
27
29
28
from pymc .backends .report import SamplerReport
30
29
from pymc .model import modelcontext
@@ -210,18 +209,18 @@ def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
210
209
"""Get sampler statistics."""
211
210
raise NotImplementedError ()
212
211
213
- def _slice (self , idx ):
212
+ def _slice (self , idx : Union [ int , slice ] ):
214
213
"""Slice trace object."""
215
214
raise NotImplementedError ()
216
215
217
- def point (self , idx ) :
216
+ def point (self , idx : int ) -> Dict [ str , np . ndarray ] :
218
217
"""Return dictionary of point values at `idx` for current chain
219
218
with variables names as keys.
220
219
"""
221
220
raise NotImplementedError ()
222
221
223
222
@property
224
- def stat_names (self ):
223
+ def stat_names (self ) -> Set [ str ] :
225
224
names = set ()
226
225
for vars in self .sampler_vars or []:
227
226
names .update (vars .keys ())
@@ -280,12 +279,10 @@ class MultiTrace:
280
279
List of variable names in the trace(s)
281
280
"""
282
281
283
- def __init__ (self , straces ):
284
- self ._straces = {}
285
- for strace in straces :
286
- if strace .chain in self ._straces :
287
- raise ValueError ("Chains are not unique." )
288
- self ._straces [strace .chain ] = strace
282
+ def __init__ (self , straces : Sequence [BaseTrace ]):
283
+ if len ({t .chain for t in straces }) != len (straces ):
284
+ raise ValueError ("Chains are not unique." )
285
+ self ._straces = {t .chain : t for t in straces }
289
286
290
287
self ._report = SamplerReport ()
291
288
@@ -294,15 +291,15 @@ def __repr__(self):
294
291
return template .format (self .__class__ .__name__ , self .nchains , len (self ), len (self .varnames ))
295
292
296
293
@property
297
- def nchains (self ):
294
+ def nchains (self ) -> int :
298
295
return len (self ._straces )
299
296
300
297
@property
301
- def chains (self ):
298
+ def chains (self ) -> List [ int ] :
302
299
return list (sorted (self ._straces .keys ()))
303
300
304
301
@property
305
- def report (self ):
302
+ def report (self ) -> SamplerReport :
306
303
return self ._report
307
304
308
305
def __iter__ (self ):
@@ -367,12 +364,12 @@ def __len__(self):
367
364
return len (self ._straces [chain ])
368
365
369
366
@property
370
- def varnames (self ):
367
+ def varnames (self ) -> List [ str ] :
371
368
chain = self .chains [- 1 ]
372
369
return self ._straces [chain ].varnames
373
370
374
371
@property
375
- def stat_names (self ):
372
+ def stat_names (self ) -> Set [ str ] :
376
373
if not self ._straces :
377
374
return set ()
378
375
sampler_vars = [s .sampler_vars for s in self ._straces .values ()]
@@ -386,74 +383,15 @@ def stat_names(self):
386
383
names .update (vars .keys ())
387
384
return names
388
385
389
- def add_values (self , vals , overwrite = False ) -> None :
390
- """Add variables to traces.
391
-
392
- Parameters
393
- ----------
394
- vals: dict (str: array-like)
395
- The keys should be the names of the new variables. The values are expected to be
396
- array-like objects. For traces with more than one chain the length of each value
397
- should match the number of total samples already in the trace `(chains * iterations)`,
398
- otherwise a warning is raised.
399
- overwrite: bool
400
- If `False` (default) a ValueError is raised if the variable already exists.
401
- Change to `True` to overwrite the values of variables
402
-
403
- Returns
404
- -------
405
- None.
406
- """
407
- for k , v in vals .items ():
408
- new_var = 1
409
- if k in self .varnames :
410
- if overwrite :
411
- self .varnames .remove (k )
412
- new_var = 0
413
- else :
414
- raise ValueError (f"Variable name { k } already exists." )
415
-
416
- self .varnames .append (k )
417
-
418
- chains = self ._straces
419
- l_samples = len (self ) * len (self .chains )
420
- l_v = len (v )
421
- if l_v != l_samples :
422
- warnings .warn (
423
- "The length of the values you are trying to "
424
- "add ({}) does not match the number ({}) of "
425
- "total samples in the trace "
426
- "(chains * iterations)" .format (l_v , l_samples )
427
- )
428
-
429
- v = np .squeeze (v .reshape (len (chains ), len (self ), - 1 ))
430
-
431
- for idx , chain in enumerate (chains .values ()):
432
- if new_var :
433
- dummy = at .as_tensor_variable ([], k )
434
- chain .vars .append (dummy )
435
- chain .samples [k ] = v [idx ]
436
-
437
- def remove_values (self , name ):
438
- """remove variables from traces.
439
-
440
- Parameters
441
- ----------
442
- name: str
443
- Name of the variable to remove. Raises KeyError if the variable is not present
444
- """
445
- varnames = self .varnames
446
- if name not in varnames :
447
- raise KeyError (f"Unknown variable { name } " )
448
- self .varnames .remove (name )
449
- chains = self ._straces
450
- for chain in chains .values ():
451
- for va in chain .vars :
452
- if va .name == name :
453
- chain .vars .remove (va )
454
- del chain .samples [name ]
455
-
456
- def get_values (self , varname , burn = 0 , thin = 1 , combine = True , chains = None , squeeze = True ):
386
+ def get_values (
387
+ self ,
388
+ varname : str ,
389
+ burn : int = 0 ,
390
+ thin : int = 1 ,
391
+ combine : bool = True ,
392
+ chains : Optional [Union [int , Sequence [int ]]] = None ,
393
+ squeeze : bool = True ,
394
+ ) -> List [np .ndarray ]:
457
395
"""Get values from traces.
458
396
459
397
Parameters
@@ -479,13 +417,20 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None, squeeze
479
417
if chains is None :
480
418
chains = self .chains
481
419
varname = get_var_name (varname )
482
- try :
483
- results = [self ._straces [chain ].get_values (varname , burn , thin ) for chain in chains ]
484
- except TypeError : # Single chain passed.
485
- results = [self ._straces [chains ].get_values (varname , burn , thin )]
420
+ if isinstance (chains , int ):
421
+ chains = [chains ]
422
+ results = [self ._straces [chain ].get_values (varname , burn , thin ) for chain in chains ]
486
423
return _squeeze_cat (results , combine , squeeze )
487
424
488
- def get_sampler_stats (self , stat_name , burn = 0 , thin = 1 , combine = True , chains = None , squeeze = True ):
425
+ def get_sampler_stats (
426
+ self ,
427
+ stat_name : str ,
428
+ burn : int = 0 ,
429
+ thin : int = 1 ,
430
+ combine : bool = True ,
431
+ chains : Optional [Union [int , Sequence [int ]]] = None ,
432
+ squeeze : bool = True ,
433
+ ):
489
434
"""Get sampler statistics from the trace.
490
435
491
436
Parameters
@@ -508,9 +453,7 @@ def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True, chains=None
508
453
509
454
if chains is None :
510
455
chains = self .chains
511
- try :
512
- chains = iter (chains )
513
- except TypeError :
456
+ if isinstance (chains , int ):
514
457
chains = [chains ]
515
458
516
459
results = [
@@ -526,7 +469,7 @@ def _slice(self, slice):
526
469
trace ._report = self ._report ._slice (* idxs )
527
470
return trace
528
471
529
- def point (self , idx , chain = None ):
472
+ def point (self , idx : int , chain : Optional [ int ] = None ) -> Dict [ str , np . ndarray ] :
530
473
"""Return a dictionary of point values at `idx`.
531
474
532
475
Parameters
0 commit comments