Skip to content

Commit 6ab0c03

Browse files
michaelosthege authored and twiecki committed
Fix BaseTrace and MultiTrace typing; remove add_values/remove_values
The `add_values`/`remove_values` methods accessed attributes that are present only on `NDArray`, not on `BaseTrace`, thereby violating the type signature.
1 parent 6c4d4eb commit 6ab0c03

File tree

2 files changed

+36
-112
lines changed

2 files changed

+36
-112
lines changed

pymc/backends/base.py

Lines changed: 36 additions & 93 deletions
Original file line number | Diff line number | Diff line change
@@ -21,10 +21,9 @@
2121
import warnings
2222

2323
from abc import ABC
24-
from typing import List, Sequence, Tuple, cast
24+
from typing import Dict, List, Optional, Sequence, Set, Tuple, Union, cast
2525

2626
import numpy as np
27-
import pytensor.tensor as at
2827

2928
from pymc.backends.report import SamplerReport
3029
from pymc.model import modelcontext
@@ -210,18 +209,18 @@ def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
210209
"""Get sampler statistics."""
211210
raise NotImplementedError()
212211

213-
def _slice(self, idx):
212+
def _slice(self, idx: Union[int, slice]):
214213
"""Slice trace object."""
215214
raise NotImplementedError()
216215

217-
def point(self, idx):
216+
def point(self, idx: int) -> Dict[str, np.ndarray]:
218217
"""Return dictionary of point values at `idx` for current chain
219218
with variables names as keys.
220219
"""
221220
raise NotImplementedError()
222221

223222
@property
224-
def stat_names(self):
223+
def stat_names(self) -> Set[str]:
225224
names = set()
226225
for vars in self.sampler_vars or []:
227226
names.update(vars.keys())
@@ -280,12 +279,10 @@ class MultiTrace:
280279
List of variable names in the trace(s)
281280
"""
282281

283-
def __init__(self, straces):
284-
self._straces = {}
285-
for strace in straces:
286-
if strace.chain in self._straces:
287-
raise ValueError("Chains are not unique.")
288-
self._straces[strace.chain] = strace
282+
def __init__(self, straces: Sequence[BaseTrace]):
283+
if len({t.chain for t in straces}) != len(straces):
284+
raise ValueError("Chains are not unique.")
285+
self._straces = {t.chain: t for t in straces}
289286

290287
self._report = SamplerReport()
291288

@@ -294,15 +291,15 @@ def __repr__(self):
294291
return template.format(self.__class__.__name__, self.nchains, len(self), len(self.varnames))
295292

296293
@property
297-
def nchains(self):
294+
def nchains(self) -> int:
298295
return len(self._straces)
299296

300297
@property
301-
def chains(self):
298+
def chains(self) -> List[int]:
302299
return list(sorted(self._straces.keys()))
303300

304301
@property
305-
def report(self):
302+
def report(self) -> SamplerReport:
306303
return self._report
307304

308305
def __iter__(self):
@@ -367,12 +364,12 @@ def __len__(self):
367364
return len(self._straces[chain])
368365

369366
@property
370-
def varnames(self):
367+
def varnames(self) -> List[str]:
371368
chain = self.chains[-1]
372369
return self._straces[chain].varnames
373370

374371
@property
375-
def stat_names(self):
372+
def stat_names(self) -> Set[str]:
376373
if not self._straces:
377374
return set()
378375
sampler_vars = [s.sampler_vars for s in self._straces.values()]
@@ -386,74 +383,15 @@ def stat_names(self):
386383
names.update(vars.keys())
387384
return names
388385

389-
def add_values(self, vals, overwrite=False) -> None:
390-
"""Add variables to traces.
391-
392-
Parameters
393-
----------
394-
vals: dict (str: array-like)
395-
The keys should be the names of the new variables. The values are expected to be
396-
array-like objects. For traces with more than one chain the length of each value
397-
should match the number of total samples already in the trace `(chains * iterations)`,
398-
otherwise a warning is raised.
399-
overwrite: bool
400-
If `False` (default) a ValueError is raised if the variable already exists.
401-
Change to `True` to overwrite the values of variables
402-
403-
Returns
404-
-------
405-
None.
406-
"""
407-
for k, v in vals.items():
408-
new_var = 1
409-
if k in self.varnames:
410-
if overwrite:
411-
self.varnames.remove(k)
412-
new_var = 0
413-
else:
414-
raise ValueError(f"Variable name {k} already exists.")
415-
416-
self.varnames.append(k)
417-
418-
chains = self._straces
419-
l_samples = len(self) * len(self.chains)
420-
l_v = len(v)
421-
if l_v != l_samples:
422-
warnings.warn(
423-
"The length of the values you are trying to "
424-
"add ({}) does not match the number ({}) of "
425-
"total samples in the trace "
426-
"(chains * iterations)".format(l_v, l_samples)
427-
)
428-
429-
v = np.squeeze(v.reshape(len(chains), len(self), -1))
430-
431-
for idx, chain in enumerate(chains.values()):
432-
if new_var:
433-
dummy = at.as_tensor_variable([], k)
434-
chain.vars.append(dummy)
435-
chain.samples[k] = v[idx]
436-
437-
def remove_values(self, name):
438-
"""remove variables from traces.
439-
440-
Parameters
441-
----------
442-
name: str
443-
Name of the variable to remove. Raises KeyError if the variable is not present
444-
"""
445-
varnames = self.varnames
446-
if name not in varnames:
447-
raise KeyError(f"Unknown variable {name}")
448-
self.varnames.remove(name)
449-
chains = self._straces
450-
for chain in chains.values():
451-
for va in chain.vars:
452-
if va.name == name:
453-
chain.vars.remove(va)
454-
del chain.samples[name]
455-
456-
def get_values(self, varname, burn=0, thin=1, combine=True, chains=None, squeeze=True):
386+
def get_values(
387+
self,
388+
varname: str,
389+
burn: int = 0,
390+
thin: int = 1,
391+
combine: bool = True,
392+
chains: Optional[Union[int, Sequence[int]]] = None,
393+
squeeze: bool = True,
394+
) -> List[np.ndarray]:
457395
"""Get values from traces.
458396
459397
Parameters
@@ -479,13 +417,20 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None, squeeze
479417
if chains is None:
480418
chains = self.chains
481419
varname = get_var_name(varname)
482-
try:
483-
results = [self._straces[chain].get_values(varname, burn, thin) for chain in chains]
484-
except TypeError: # Single chain passed.
485-
results = [self._straces[chains].get_values(varname, burn, thin)]
420+
if isinstance(chains, int):
421+
chains = [chains]
422+
results = [self._straces[chain].get_values(varname, burn, thin) for chain in chains]
486423
return _squeeze_cat(results, combine, squeeze)
487424

488-
def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True, chains=None, squeeze=True):
425+
def get_sampler_stats(
426+
self,
427+
stat_name: str,
428+
burn: int = 0,
429+
thin: int = 1,
430+
combine: bool = True,
431+
chains: Optional[Union[int, Sequence[int]]] = None,
432+
squeeze: bool = True,
433+
):
489434
"""Get sampler statistics from the trace.
490435
491436
Parameters
@@ -508,9 +453,7 @@ def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True, chains=None
508453

509454
if chains is None:
510455
chains = self.chains
511-
try:
512-
chains = iter(chains)
513-
except TypeError:
456+
if isinstance(chains, int):
514457
chains = [chains]
515458

516459
results = [
@@ -526,7 +469,7 @@ def _slice(self, slice):
526469
trace._report = self._report._slice(*idxs)
527470
return trace
528471

529-
def point(self, idx, chain=None):
472+
def point(self, idx: int, chain: Optional[int] = None) -> Dict[str, np.ndarray]:
530473
"""Return a dictionary of point values at `idx`.
531474
532475
Parameters

pymc/tests/backends/test_ndarray.py

Lines changed: 0 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -124,25 +124,6 @@ def test_multitrace_nonunique(self):
124124
base.MultiTrace([self.strace0, self.strace1])
125125

126126

127-
class TestMultiTrace_add_remove_values(bf.ModelBackendSampledTestCase):
128-
name = None
129-
backend = ndarray.NDArray
130-
shape = ()
131-
132-
def test_add_values(self):
133-
mtrace = self.mtrace
134-
orig_varnames = list(mtrace.varnames)
135-
name = "new_var"
136-
vals = mtrace[orig_varnames[0]]
137-
mtrace.add_values({name: vals})
138-
assert len(orig_varnames) == len(mtrace.varnames) - 1
139-
assert name in mtrace.varnames
140-
assert np.all(mtrace[orig_varnames[0]] == mtrace[name])
141-
mtrace.remove_values(name)
142-
assert len(orig_varnames) == len(mtrace.varnames)
143-
assert name not in mtrace.varnames
144-
145-
146127
class TestSqueezeCat:
147128
def setup_method(self):
148129
self.x = np.arange(10)

0 commit comments

Comments
 (0)