Skip to content

Commit 56ccabb

Browse files
authored
run black on 15 files (#4110)
1 parent f0b9577 commit 56ccabb

File tree

15 files changed

+446
-413
lines changed

15 files changed

+446
-413
lines changed

pymc3/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
def __set_compiler_flags():
3232
# Workarounds for Theano compiler problems on various platforms
3333
import theano
34+
3435
current = theano.config.gcc.cxxflags
3536
theano.config.gcc.cxxflags = f"{current} -Wno-c++11-narrowing"
3637

pymc3/backends/__init__.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,8 @@
136136
from ..backends.sqlite import SQLite
137137
from ..backends.hdf5 import HDF5
138138

139-
_shortcuts = {'text': {'backend': Text,
140-
'name': 'mcmc'},
141-
'sqlite': {'backend': SQLite,
142-
'name': 'mcmc.sqlite'},
143-
'hdf5': {'backend': HDF5,
144-
'name': 'mcmc.hdf5'}}
139+
_shortcuts = {
140+
"text": {"backend": Text, "name": "mcmc"},
141+
"sqlite": {"backend": SQLite, "name": "mcmc.sqlite"},
142+
"hdf5": {"backend": HDF5, "name": "mcmc.hdf5"},
143+
}

pymc3/backends/base.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from .report import SamplerReport, merge_reports
3131
from ..util import get_var_name
3232

33-
logger = logging.getLogger('pymc3')
33+
logger = logging.getLogger("pymc3")
3434

3535

3636
class BackendError(Exception):
@@ -75,10 +75,8 @@ def __init__(self, name, model=None, vars=None, test_point=None):
7575
test_point_.update(test_point)
7676
test_point = test_point_
7777
var_values = list(zip(self.varnames, self.fn(test_point)))
78-
self.var_shapes = {var: value.shape
79-
for var, value in var_values}
80-
self.var_dtypes = {var: value.dtype
81-
for var, value in var_values}
78+
self.var_shapes = {var: value.shape for var, value in var_values}
79+
self.var_dtypes = {var: value.dtype for var, value in var_values}
8280
self.chain = None
8381
self._is_base_setup = False
8482
self.sampler_vars = None
@@ -104,8 +102,7 @@ def _set_sampler_vars(self, sampler_vars):
104102
for stats in sampler_vars:
105103
for key, dtype in stats.items():
106104
if dtypes.setdefault(key, dtype) != dtype:
107-
raise ValueError("Sampler statistic %s appears with "
108-
"different types." % key)
105+
raise ValueError("Sampler statistic %s appears with " "different types." % key)
109106

110107
self.sampler_vars = sampler_vars
111108

@@ -155,7 +152,7 @@ def __getitem__(self, idx):
155152
try:
156153
return self.point(int(idx))
157154
except (ValueError, TypeError): # Passed variable or variable name.
158-
raise ValueError('Can only index with slice or integer')
155+
raise ValueError("Can only index with slice or integer")
159156

160157
def __len__(self):
161158
raise NotImplementedError
@@ -199,13 +196,13 @@ def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
199196
if sampler_idx is not None:
200197
return self._get_sampler_stats(stat_name, sampler_idx, burn, thin)
201198

202-
sampler_idxs = [i for i, s in enumerate(self.sampler_vars)
203-
if stat_name in s]
199+
sampler_idxs = [i for i, s in enumerate(self.sampler_vars) if stat_name in s]
204200
if not sampler_idxs:
205201
raise KeyError("Unknown sampler stat %s" % stat_name)
206202

207-
vals = np.stack([self._get_sampler_stats(stat_name, i, burn, thin)
208-
for i in sampler_idxs], axis=-1)
203+
vals = np.stack(
204+
[self._get_sampler_stats(stat_name, i, burn, thin) for i in sampler_idxs], axis=-1
205+
)
209206
if vals.shape[-1] == 1:
210207
return vals[..., 0]
211208
else:
@@ -296,13 +293,12 @@ def __init__(self, straces):
296293

297294
self._report = SamplerReport()
298295
for strace in straces:
299-
if hasattr(strace, '_warnings'):
296+
if hasattr(strace, "_warnings"):
300297
self._report._add_warnings(strace._warnings, strace.chain)
301298

302299
def __repr__(self):
303-
template = '<{}: {} chains, {} iterations, {} variables>'
304-
return template.format(self.__class__.__name__,
305-
self.nchains, len(self), len(self.varnames))
300+
template = "<{}: {} chains, {} iterations, {} variables>"
301+
return template.format(self.__class__.__name__, self.nchains, len(self), len(self.varnames))
306302

307303
@property
308304
def nchains(self):
@@ -339,16 +335,17 @@ def __getitem__(self, idx):
339335
var = get_var_name(var)
340336
if var in self.varnames:
341337
if var in self.stat_names:
342-
warnings.warn("Attribute access on a trace object is ambigous. "
343-
"Sampler statistic and model variable share a name. Use "
344-
"trace.get_values or trace.get_sampler_stats.")
338+
warnings.warn(
339+
"Attribute access on a trace object is ambigous. "
340+
"Sampler statistic and model variable share a name. Use "
341+
"trace.get_values or trace.get_sampler_stats."
342+
)
345343
return self.get_values(var, burn=burn, thin=thin)
346344
if var in self.stat_names:
347345
return self.get_sampler_stats(var, burn=burn, thin=thin)
348346
raise KeyError("Unknown variable %s" % var)
349347

350-
_attrs = {'_straces', 'varnames', 'chains', 'stat_names',
351-
'supports_sampler_stats', '_report'}
348+
_attrs = {"_straces", "varnames", "chains", "stat_names", "supports_sampler_stats", "_report"}
352349

353350
def __getattr__(self, name):
354351
# Avoid infinite recursion when called before __init__
@@ -359,14 +356,15 @@ def __getattr__(self, name):
359356
name = get_var_name(name)
360357
if name in self.varnames:
361358
if name in self.stat_names:
362-
warnings.warn("Attribute access on a trace object is ambigous. "
363-
"Sampler statistic and model variable share a name. Use "
364-
"trace.get_values or trace.get_sampler_stats.")
359+
warnings.warn(
360+
"Attribute access on a trace object is ambigous. "
361+
"Sampler statistic and model variable share a name. Use "
362+
"trace.get_values or trace.get_sampler_stats."
363+
)
365364
return self.get_values(name)
366365
if name in self.stat_names:
367366
return self.get_sampler_stats(name)
368-
raise AttributeError("'{}' object has no attribute '{}'".format(
369-
type(self).__name__, name))
367+
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))
370368

371369
def __len__(self):
372370
chain = self.chains[-1]
@@ -425,10 +423,12 @@ def add_values(self, vals, overwrite=False) -> None:
425423
l_samples = len(self) * len(self.chains)
426424
l_v = len(v)
427425
if l_v != l_samples:
428-
warnings.warn("The length of the values you are trying to "
429-
"add ({}) does not match the number ({}) of "
430-
"total samples in the trace "
431-
"(chains * iterations)".format(l_v, l_samples))
426+
warnings.warn(
427+
"The length of the values you are trying to "
428+
"add ({}) does not match the number ({}) of "
429+
"total samples in the trace "
430+
"(chains * iterations)".format(l_v, l_samples)
431+
)
432432

433433
v = np.squeeze(v.reshape(len(chains), len(self), -1))
434434

@@ -457,8 +457,7 @@ def remove_values(self, name):
457457
chain.vars.remove(va)
458458
del chain.samples[name]
459459

460-
def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
461-
squeeze=True):
460+
def get_values(self, varname, burn=0, thin=1, combine=True, chains=None, squeeze=True):
462461
"""Get values from traces.
463462
464463
Parameters
@@ -485,14 +484,12 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
485484
chains = self.chains
486485
varname = get_var_name(varname)
487486
try:
488-
results = [self._straces[chain].get_values(varname, burn, thin)
489-
for chain in chains]
487+
results = [self._straces[chain].get_values(varname, burn, thin) for chain in chains]
490488
except TypeError: # Single chain passed.
491489
results = [self._straces[chains].get_values(varname, burn, thin)]
492490
return _squeeze_cat(results, combine, squeeze)
493491

494-
def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True,
495-
chains=None, squeeze=True):
492+
def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True, chains=None, squeeze=True):
496493
"""Get sampler statistics from the trace.
497494
498495
Parameters
@@ -520,8 +517,9 @@ def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True,
520517
except TypeError:
521518
chains = [chains]
522519

523-
results = [self._straces[chain].get_sampler_stats(stat_name, None, burn, thin)
524-
for chain in chains]
520+
results = [
521+
self._straces[chain].get_sampler_stats(stat_name, None, burn, thin) for chain in chains
522+
]
525523
return _squeeze_cat(results, combine, squeeze)
526524

527525
def _slice(self, slice):
@@ -582,7 +580,9 @@ def merge_traces(mtraces: List[MultiTrace]) -> MultiTrace:
582580
base_mtrace = mtraces[0]
583581
chain_len = len(base_mtrace)
584582
# check base trace
585-
if any(len(st) != chain_len for _, st in base_mtrace._straces.items()): # pylint: disable=line-too-long
583+
if any(
584+
len(st) != chain_len for _, st in base_mtrace._straces.items()
585+
): # pylint: disable=line-too-long
586586
raise ValueError("Chains are of different lengths.")
587587
for new_mtrace in mtraces[1:]:
588588
for new_chain, strace in new_mtrace._straces.items():

pymc3/backends/hdf5.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,15 @@
1616
import h5py
1717
from contextlib import contextmanager
1818

19+
1920
@contextmanager
2021
def activator(instance):
2122
if isinstance(instance.hdf5_file, h5py.File):
2223
if instance.hdf5_file.id: # if file is open, keep open
2324
yield
2425
return
2526
# if file is closed/not referenced: open, do job, then close
26-
instance.hdf5_file = h5py.File(instance.name, 'a')
27+
instance.hdf5_file = h5py.File(instance.name, "a")
2728
yield
2829
instance.hdf5_file.close()
2930
return
@@ -43,7 +44,7 @@ class HDF5(base.BaseTrace):
4344
`model.unobserved_RVs` is used.
4445
test_point: dict
4546
use different test point that might be with changed variables shapes
46-
"""
47+
"""
4748

4849
supports_sampler_stats = True
4950

@@ -64,21 +65,21 @@ def activate_file(self):
6465
@property
6566
def samples(self):
6667
g = self.hdf5_file.require_group(str(self.chain))
67-
if 'name' not in g.attrs:
68-
g.attrs['name'] = self.chain
69-
return g.require_group('samples')
68+
if "name" not in g.attrs:
69+
g.attrs["name"] = self.chain
70+
return g.require_group("samples")
7071

7172
@property
7273
def stats(self):
7374
g = self.hdf5_file.require_group(str(self.chain))
74-
if 'name' not in g.attrs:
75-
g.attrs['name'] = self.chain
76-
return g.require_group('stats')
75+
if "name" not in g.attrs:
76+
g.attrs["name"] = self.chain
77+
return g.require_group("stats")
7778

7879
@property
7980
def chains(self):
8081
with self.activate_file:
81-
return [v.attrs['name'] for v in self.hdf5_file.values()]
82+
return [v.attrs["name"] for v in self.hdf5_file.values()]
8283

8384
@property
8485
def is_new_file(self):
@@ -98,19 +99,19 @@ def nchains(self):
9899
@property
99100
def records_stats(self):
100101
with self.activate_file:
101-
return self.hdf5_file.attrs['records_stats']
102+
return self.hdf5_file.attrs["records_stats"]
102103

103104
@records_stats.setter
104105
def records_stats(self, v):
105106
with self.activate_file:
106-
self.hdf5_file.attrs['records_stats'] = bool(v)
107+
self.hdf5_file.attrs["records_stats"] = bool(v)
107108

108109
def _resize(self, n):
109110
for v in self.samples.values():
110111
v.resize(n, axis=0)
111112
for key, group in self.stats.items():
112113
for statds in group.values():
113-
statds.resize((n, ))
114+
statds.resize((n,))
114115

115116
@property
116117
def sampler_vars(self):
@@ -137,10 +138,13 @@ def sampler_vars(self, values):
137138
if not data.keys(): # no pre-recorded stats
138139
for varname, dtype in sampler.items():
139140
if varname not in data:
140-
data.create_dataset(varname, (self.draws,), dtype=dtype, maxshape=(None,))
141+
data.create_dataset(
142+
varname, (self.draws,), dtype=dtype, maxshape=(None,)
143+
)
141144
elif data.keys() != sampler.keys():
142145
raise ValueError(
143-
f"Sampler vars can't change, names incompatible: {data.keys()} != {sampler.keys()}")
146+
f"Sampler vars can't change, names incompatible: {data.keys()} != {sampler.keys()}"
147+
)
144148
self.records_stats = True
145149

146150
def setup(self, draws, chain, sampler_vars=None):
@@ -160,16 +164,18 @@ def setup(self, draws, chain, sampler_vars=None):
160164
with self.activate_file:
161165
for varname, shape in self.var_shapes.items():
162166
if varname not in self.samples:
163-
self.samples.create_dataset(name=varname, shape=(draws, ) + shape,
164-
dtype=self.var_dtypes[varname],
165-
maxshape=(None, ) + shape)
167+
self.samples.create_dataset(
168+
name=varname,
169+
shape=(draws,) + shape,
170+
dtype=self.var_dtypes[varname],
171+
maxshape=(None,) + shape,
172+
)
166173
self.draw_idx = len(self)
167174
self.draws = self.draw_idx + draws
168175
self._set_sampler_vars(sampler_vars)
169176
self._is_base_setup = True
170177
self._resize(self.draws)
171178

172-
173179
def close(self):
174180
with self.activate_file:
175181
if self.draw_idx == self.draws:
@@ -204,8 +210,7 @@ def _slice(self, idx):
204210
start, stop, step = idx.indices(len(self))
205211
sliced = ndarray.NDArray(model=self.model, vars=self.vars)
206212
sliced.chain = self.chain
207-
sliced.samples = {v: self.samples[v][start:stop:step]
208-
for v in self.varnames}
213+
sliced.samples = {v: self.samples[v][start:stop:step] for v in self.varnames}
209214
sliced.draw_idx = (stop - start) // step
210215
return sliced
211216

0 commit comments

Comments
 (0)