
Commit 6b765e4

Formatted Next 15 Files (#4150)
* Formatted Next 15 Files
* Run command: pre-commit run
* removed quotes
* Update test_examples.py: used # fmt: off before and # fmt: on around the array (see the sketch below)
* Update test_examples.py: changes updated
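For context on the pragma mentioned in the bullets above: Black, run here through pre-commit, leaves any region bracketed by # fmt: off and # fmt: on untouched, which is how a hand-aligned array in test_examples.py can keep its layout while the rest of the file is reformatted. A minimal sketch of the idea — the values below are purely illustrative, not the actual array from test_examples.py:

# fmt: off
expected = [  # hypothetical array, hand-aligned on purpose
    0.001,  0.010,  0.100,
    1.000, 10.000, 99.999,
]
# fmt: on

Without the pragmas, Black would reflow the list into a layout of its own choosing; with them, the alignment survives "pre-commit run" and any later formatting passes.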
1 parent 7429b4b · commit 6b765e4

15 files changed: +435 -461 lines changed

pymc3/model.py

Lines changed: 61 additions & 91 deletions
@@ -70,9 +70,9 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
         else:
             return super().__str__()

-        if name is None and hasattr(self, 'name'):
+        if name is None and hasattr(self, "name"):
             name = self.name
-        if dist is None and hasattr(self, 'distribution'):
+        if dist is None and hasattr(self, "distribution"):
             dist = self.distribution
         return self.distribution._str_repr(name=name, dist=dist, formatting=formatting)


@@ -123,8 +123,7 @@ def incorporate_methods(source, destination, methods, wrapper=None, override=Fal
     for method in methods:
         if hasattr(destination, method) and not override:
             raise AttributeError(
-                f"Cannot add method {method!r}"
-                + "to destination object as it already exists. "
+                f"Cannot add method {method!r}" + "to destination object as it already exists. "
                 "To prevent this error set 'override=True'."
             )
         if hasattr(source, method):
@@ -172,12 +171,8 @@ def get_named_nodes_and_relations(graph):
     else:
         ancestors = {}
         descendents = {}
-    descendents, ancestors = _get_named_nodes_and_relations(
-        graph, None, ancestors, descendents
-    )
-    leaf_dict = {
-        node.name: node for node, ancestor in ancestors.items() if len(ancestor) == 0
-    }
+    descendents, ancestors = _get_named_nodes_and_relations(graph, None, ancestors, descendents)
+    leaf_dict = {node.name: node for node, ancestor in ancestors.items() if len(ancestor) == 0}
     return leaf_dict, descendents, ancestors

@@ -529,9 +524,7 @@ def tree_contains(self, item):

     def __setitem__(self, key, value):
         raise NotImplementedError(
-            "Method is removed as we are not"
-            " able to determine "
-            "appropriate logic for it"
+            "Method is removed as we are not able to determine appropriate logic for it"
         )

     # Added this because mypy didn't like having __imul__ without __mul__
@@ -620,7 +613,7 @@ def __init__(
         dtype=None,
         casting="no",
         compute_grads=True,
-        **kwargs
+        **kwargs,
     ):
         from .distributions import TensorType

@@ -695,9 +688,7 @@ def __init__(

         inputs = [self._vars_joined]

-        self._theano_function = theano.function(
-            inputs, outputs, givens=givens, **kwargs
-        )
+        self._theano_function = theano.function(inputs, outputs, givens=givens, **kwargs)

     def set_weights(self, values):
         if values.shape != (self._n_costs - 1,):
@@ -713,10 +704,7 @@ def get_extra_values(self):
         if not self._extra_are_set:
             raise ValueError("Extra values are not set.")

-        return {
-            var.name: self._extra_vars_shared[var.name].get_value()
-            for var in self._extra_vars
-        }
+        return {var.name: self._extra_vars_shared[var.name].get_value() for var in self._extra_vars}

     def __call__(self, array, grad_out=None, extra_vars=None):
         if extra_vars is not None:
@@ -727,8 +715,7 @@ def __call__(self, array, grad_out=None, extra_vars=None):

         if array.shape != (self.size,):
             raise ValueError(
-                "Invalid shape for array. Must be %s but is %s."
-                % ((self.size,), array.shape)
+                "Invalid shape for array. Must be {} but is {}.".format((self.size,), array.shape)
             )

         if grad_out is None:
@@ -758,13 +745,10 @@ def dict_to_array(self, point):
     def array_to_dict(self, array):
         """Convert an array to a dictionary containing the grad_vars."""
         if array.shape != (self.size,):
-            raise ValueError(
-                f"Array should have shape ({self.size},) but has {array.shape}"
-            )
+            raise ValueError(f"Array should have shape ({self.size},) but has {array.shape}")
         if array.dtype != self.dtype:
             raise ValueError(
-                "Array has invalid dtype. Should be %s but is %s"
-                % (self._dtype, self.dtype)
+                f"Array has invalid dtype. Should be {self._dtype} but is {self.dtype}"
             )
         point = {}
         for varmap in self._ordering.vmap:
@@ -988,17 +972,15 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
         for var in grad_vars:
             if var.dtype not in continuous_types:
                 raise ValueError(
-                    "Can only compute the gradient of " "continuous types: %s" % var
+                    "Can only compute the gradient of continuous types: %s" % var
                 )

         if tempered:
             with self:
-                free_RVs_logp = tt.sum([
-                    tt.sum(var.logpt) for var in self.free_RVs + self.potentials
-                ])
-                observed_RVs_logp = tt.sum([
-                    tt.sum(var.logpt) for var in self.observed_RVs
-                ])
+                free_RVs_logp = tt.sum(
+                    [tt.sum(var.logpt) for var in self.free_RVs + self.potentials]
+                )
+                observed_RVs_logp = tt.sum([tt.sum(var.logpt) for var in self.observed_RVs])

             costs = [free_RVs_logp, observed_RVs_logp]
         else:
@@ -1038,7 +1020,7 @@ def logp_nojact(self):
     @property
     def varlogpt(self):
         """Theano scalar of log-probability of the unobserved random variables
-           (excluding deterministic)."""
+        (excluding deterministic)."""
         with self:
             factors = [var.logpt for var in self.free_RVs]
             return tt.sum(factors)
@@ -1110,9 +1092,7 @@ def add_coords(self, coords):
             )
         if name in self.coords:
             if not coords[name].equals(self.coords[name]):
-                raise ValueError(
-                    "Duplicate and incompatiple coordinate: %s." % name
-                )
+                raise ValueError("Duplicate and incompatiple coordinate: %s." % name)
         else:
             self.coords[name] = coords[name]

@@ -1141,9 +1121,7 @@ def Var(self, name, dist, data=None, total_size=None, dims=None):
         if data is None:
             if getattr(dist, "transform", None) is None:
                 with self:
-                    var = FreeRV(
-                        name=name, distribution=dist, total_size=total_size, model=self
-                    )
+                    var = FreeRV(name=name, distribution=dist, total_size=total_size, model=self)
                 self.free_RVs.append(var)
             else:
                 with self:
@@ -1218,8 +1196,7 @@ def prefix(self):
         return "%s_" % self.name if self.name else ""

     def name_for(self, name):
-        """Checks if name has prefix and adds if needed
-        """
+        """Checks if name has prefix and adds if needed"""
         if self.prefix:
             if not name.startswith(self.prefix):
                 return f"{self.prefix}{name}"
@@ -1229,8 +1206,7 @@ def name_for(self, name):
         return name

     def name_of(self, name):
-        """Checks if name has prefix and deletes if needed
-        """
+        """Checks if name has prefix and deletes if needed"""
         if not self.prefix or not name:
             return name
         elif name.startswith(self.prefix):
@@ -1269,7 +1245,7 @@ def makefn(self, outs, mode=None, *args, **kwargs):
                 accept_inplace=True,
                 mode=mode,
                 *args,
-                **kwargs
+                **kwargs,
             )

     def fn(self, outs, mode=None, *args, **kwargs):
@@ -1391,10 +1367,7 @@ def check_test_point(self, test_point=None, round_vals=2):
             test_point = self.test_point

         return Series(
-            {
-                RV.name: np.round(RV.logp(self.test_point), round_vals)
-                for RV in self.basic_RVs
-            },
+            {RV.name: np.round(RV.logp(self.test_point), round_vals) for RV in self.basic_RVs},
             name="Log-probability of test_point",
         )

@@ -1403,23 +1376,31 @@ def _str_repr(self, formatting="plain", **kwargs):

         if formatting == "latex":
             rv_reprs = [rv.__latex__() for rv in all_rv]
-            rv_reprs = [rv_repr.replace(r"\sim", r"&\sim &").strip("$")
-                        for rv_repr in rv_reprs if rv_repr is not None]
+            rv_reprs = [
+                rv_repr.replace(r"\sim", r"&\sim &").strip("$")
+                for rv_repr in rv_reprs
+                if rv_repr is not None
+            ]
             return r"""$$
             \begin{{array}}{{rcl}}
             {}
             \end{{array}}
             $$""".format(
-                "\\\\".join(rv_reprs))
+                "\\\\".join(rv_reprs)
+            )
         else:
             rv_reprs = [rv.__str__() for rv in all_rv]
-            rv_reprs = [rv_repr for rv_repr in rv_reprs if not 'TransformedDistribution()' in rv_repr]
+            rv_reprs = [
+                rv_repr for rv_repr in rv_reprs if not "TransformedDistribution()" in rv_repr
+            ]
             # align vars on their ~
-            names = [s[:s.index('~')-1] for s in rv_reprs]
-            distrs = [s[s.index('~')+2:] for s in rv_reprs]
+            names = [s[: s.index("~") - 1] for s in rv_reprs]
+            distrs = [s[s.index("~") + 2 :] for s in rv_reprs]
             maxlen = str(max(len(x) for x in names))
-            rv_reprs = [('{name:>' + maxlen + '} ~ {distr}').format(name=n, distr=d)
-                        for n, d in zip(names, distrs)]
+            rv_reprs = [
+                ("{name:>" + maxlen + "} ~ {distr}").format(name=n, distr=d)
+                for n, d in zip(names, distrs)
+            ]
             return "\n".join(rv_reprs)

     def __str__(self, **kwargs):
@@ -1537,8 +1518,9 @@ def Point(*args, **kwargs):
     except Exception as e:
         raise TypeError(f"can't turn {args} and {kwargs} into a dict. {e}")
     return {
-        get_var_name(k): np.array(v) for k, v in d.items()
-        if get_var_name(k) in map(get_var_name, model.vars)
+        get_var_name(k): np.array(v)
+        for k, v in d.items()
+        if get_var_name(k) in map(get_var_name, model.vars)
     }

@@ -1593,11 +1575,7 @@ def _get_scaling(total_size, shape, ndim):
         denom = 1
         coef = floatX(total_size) / floatX(denom)
     elif isinstance(total_size, (list, tuple)):
-        if not all(
-            isinstance(i, int)
-            for i in total_size
-            if (i is not Ellipsis and i is not None)
-        ):
+        if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
             raise TypeError(
                 "Unrecognized `total_size` type, expected "
                 "int or list of ints, got %r" % total_size
@@ -1625,16 +1603,13 @@ def _get_scaling(total_size, shape, ndim):
         else:
             shp_end = np.asarray([])
         shp_begin = shape[: len(begin)]
-        begin_coef = [
-            floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None
-        ]
+        begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
         end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
         coefs = begin_coef + end_coef
         coef = tt.prod(coefs)
     else:
         raise TypeError(
-            "Unrecognized `total_size` type, expected "
-            "int or list of ints, got %r" % total_size
+            "Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size
         )
     return tt.as_tensor(floatX(coef))

@@ -1753,9 +1728,7 @@ def as_tensor(data, name, model, distribution):
             testval=testval,
             parent_dist=distribution,
         )
-        missing_values = FreeRV(
-            name=name + "_missing", distribution=fakedist, model=model
-        )
+        missing_values = FreeRV(name=name + "_missing", distribution=fakedist, model=model)
         constant = tt.as_tensor_variable(data.filled())

         dataTensor = tt.set_subtensor(constant[data.mask.nonzero()], missing_values)
@@ -1854,14 +1827,11 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
         """
         self.name = name
         self.data = {
-            name: as_tensor(data, name, model, distribution)
-            for name, data in data.items()
+            name: as_tensor(data, name, model, distribution) for name, data in data.items()
         }

         self.missing_values = [
-            datum.missing_values
-            for datum in self.data.values()
-            if datum.missing_values is not None
+            datum.missing_values for datum in self.data.values() if datum.missing_values is not None
         ]
         self.logp_elemwiset = distribution.logp(**self.data)
         # The logp might need scaling in minibatches.
@@ -1871,9 +1841,7 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
         self.total_size = total_size
         self.model = model
         self.distribution = distribution
-        self.scaling = _get_scaling(
-            total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim
-        )
+        self.scaling = _get_scaling(total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim)

     # Make hashable by id for draw_values
     def __hash__(self):
@@ -1888,7 +1856,7 @@ def __ne__(self, other):
         return not self == other


-def _walk_up_rv(rv, formatting='plain'):
+def _walk_up_rv(rv, formatting="plain"):
     """Walk up theano graph to get inputs for deterministic RV."""
     all_rvs = []
     parents = list(itertools.chain(*[j.inputs for j in rv.get_parents()]))
@@ -1903,21 +1871,23 @@ def _walk_up_rv(rv, formatting='plain'):


 class DeterministicWrapper(tt.TensorVariable):
-    def _str_repr(self, formatting='plain'):
-        if formatting == 'latex':
+    def _str_repr(self, formatting="plain"):
+        if formatting == "latex":
             return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
-                name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting)))
+                name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting))
+            )
         else:
             return "{name} ~ Deterministic({args})".format(
-                name=self.name, args=", ".join(_walk_up_rv(self, formatting=formatting)))
+                name=self.name, args=", ".join(_walk_up_rv(self, formatting=formatting))
+            )

     def _repr_latex_(self):
-        return self._str_repr(formatting='latex')
+        return self._str_repr(formatting="latex")

     __latex__ = _repr_latex_

     def __str__(self):
-        return self._str_repr(formatting='plain')
+        return self._str_repr(formatting="plain")


 def Deterministic(name, var, model=None, dims=None):
@@ -1936,7 +1906,7 @@ def Deterministic(name, var, model=None, dims=None):
     var = var.copy(model.name_for(name))
     model.deterministics.append(var)
     model.add_random_variable(var, dims)
-    var.__class__ = DeterministicWrapper # adds str and latex functionality
+    var.__class__ = DeterministicWrapper  # adds str and latex functionality

     return var

@@ -2030,7 +2000,7 @@ def as_iterargs(data):

 def all_continuous(vars):
     """Check that vars not include discrete variables, excepting
-    ObservedRVs. """
+    ObservedRVs."""
     vars_ = [var for var in vars if not isinstance(var, pm.model.ObservedRV)]
     if any([var.dtype in pm.discrete_types for var in vars_]):
         return False
