Formatted Next 15 Files #4150

Merged · 5 commits · Oct 6, 2020

152 changes: 61 additions & 91 deletions pymc3/model.py
@@ -70,9 +70,9 @@ def _str_repr(self, name=None, dist=None, formatting="plain"):
else:
return super().__str__()

if name is None and hasattr(self, 'name'):
if name is None and hasattr(self, "name"):
name = self.name
if dist is None and hasattr(self, 'distribution'):
if dist is None and hasattr(self, "distribution"):
dist = self.distribution
return self.distribution._str_repr(name=name, dist=dist, formatting=formatting)
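
The hunk above shows the change that dominates this diff: quote normalization. Black rewrites single-quoted string literals to double quotes unless that would force escaping. A minimal sketch of the rule (not code from this PR):

    s = 'hello'        # black rewrites this to: s = "hello"
    t = 'say "hi"'     # left single-quoted, since switching would force \" escapes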

@@ -123,8 +123,7 @@ def incorporate_methods(source, destination, methods, wrapper=None, override=Fal
for method in methods:
if hasattr(destination, method) and not override:
raise AttributeError(
f"Cannot add method {method!r}"
+ "to destination object as it already exists. "
f"Cannot add method {method!r}" + "to destination object as it already exists. "
"To prevent this error set 'override=True'."
)
if hasattr(source, method):
@@ -172,12 +171,8 @@ def get_named_nodes_and_relations(graph):
else:
ancestors = {}
descendents = {}
descendents, ancestors = _get_named_nodes_and_relations(
graph, None, ancestors, descendents
)
leaf_dict = {
node.name: node for node, ancestor in ancestors.items() if len(ancestor) == 0
}
descendents, ancestors = _get_named_nodes_and_relations(graph, None, ancestors, descendents)
leaf_dict = {node.name: node for node, ancestor in ancestors.items() if len(ancestor) == 0}
return leaf_dict, descendents, ancestors
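
This hunk collapses calls and comprehensions that were previously wrapped onto single lines, which suggests a configured line length above black's 88-character default; the exact limit is an assumption here, but the longest collapsed lines are close to 100 characters. The pattern, shown with the leaf_dict comprehension from this hunk made self-contained:

    ancestors = {}  # illustrative stand-in for the real ancestor mapping
    # before: wrapped to satisfy a shorter line limit
    leaf_dict = {
        node.name: node for node, ancestor in ancestors.items() if len(ancestor) == 0
    }
    # after: the same comprehension fits on one line under the wider limit
    leaf_dict = {node.name: node for node, ancestor in ancestors.items() if len(ancestor) == 0}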


@@ -529,9 +524,7 @@ def tree_contains(self, item):

def __setitem__(self, key, value):
raise NotImplementedError(
"Method is removed as we are not"
" able to determine "
"appropriate logic for it"
"Method is removed as we are not able to determine appropriate logic for it"
)

# Added this because mypy didn't like having __imul__ without __mul__
@@ -620,7 +613,7 @@ def __init__(
dtype=None,
casting="no",
compute_grads=True,
**kwargs
**kwargs,
):
from .distributions import TensorType
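
The only change in this hunk is a trailing comma after **kwargs in a signature that stays split across lines. Black adds this "magic trailing comma" so that appending a parameter later touches one line instead of two; a trailing comma after **kwargs is valid Python 3 syntax. A sketch with a hypothetical function:

    def example(
        x,
        dtype=None,
        **kwargs,  # trailing comma after **kwargs parses fine in Python 3
    ):
        return x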

@@ -695,9 +688,7 @@ def __init__(

inputs = [self._vars_joined]

self._theano_function = theano.function(
inputs, outputs, givens=givens, **kwargs
)
self._theano_function = theano.function(inputs, outputs, givens=givens, **kwargs)

def set_weights(self, values):
if values.shape != (self._n_costs - 1,):
@@ -713,10 +704,7 @@ def get_extra_values(self):
if not self._extra_are_set:
raise ValueError("Extra values are not set.")

return {
var.name: self._extra_vars_shared[var.name].get_value()
for var in self._extra_vars
}
return {var.name: self._extra_vars_shared[var.name].get_value() for var in self._extra_vars}

def __call__(self, array, grad_out=None, extra_vars=None):
if extra_vars is not None:
@@ -727,8 +715,7 @@ def __call__(self, array, grad_out=None, extra_vars=None):

if array.shape != (self.size,):
raise ValueError(
"Invalid shape for array. Must be %s but is %s."
% ((self.size,), array.shape)
"Invalid shape for array. Must be {} but is {}.".format((self.size,), array.shape)
)

if grad_out is None:
@@ -758,13 +745,10 @@ def dict_to_array(self, point):
def array_to_dict(self, array):
"""Convert an array to a dictionary containing the grad_vars."""
if array.shape != (self.size,):
raise ValueError(
f"Array should have shape ({self.size},) but has {array.shape}"
)
raise ValueError(f"Array should have shape ({self.size},) but has {array.shape}")
if array.dtype != self.dtype:
raise ValueError(
"Array has invalid dtype. Should be %s but is %s"
% (self._dtype, self.dtype)
f"Array has invalid dtype. Should be {self._dtype} but is {self.dtype}"
)
point = {}
for varmap in self._ordering.vmap:
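
Several error messages in these hunks also move from %-interpolation to str.format or f-strings. Black does not rewrite string formatting on its own, so those edits were presumably made by hand; the three forms render identically:

    expected, actual = "float64", "float32"  # illustrative values
    msg_percent = "Should be %s but is %s" % (expected, actual)
    msg_format = "Should be {} but is {}".format(expected, actual)
    msg_fstring = f"Should be {expected} but is {actual}"
    assert msg_percent == msg_format == msg_fstring
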
@@ -988,17 +972,15 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
for var in grad_vars:
if var.dtype not in continuous_types:
raise ValueError(
"Can only compute the gradient of " "continuous types: %s" % var
"Can only compute the gradient of continuous types: %s" % var
)

if tempered:
with self:
free_RVs_logp = tt.sum([
tt.sum(var.logpt) for var in self.free_RVs + self.potentials
])
observed_RVs_logp = tt.sum([
tt.sum(var.logpt) for var in self.observed_RVs
])
free_RVs_logp = tt.sum(
[tt.sum(var.logpt) for var in self.free_RVs + self.potentials]
)
observed_RVs_logp = tt.sum([tt.sum(var.logpt) for var in self.observed_RVs])

costs = [free_RVs_logp, observed_RVs_logp]
else:
@@ -1038,7 +1020,7 @@ def logp_nojact(self):
@property
def varlogpt(self):
"""Theano scalar of log-probability of the unobserved random variables
(excluding deterministic)."""
(excluding deterministic)."""
with self:
factors = [var.logpt for var in self.free_RVs]
return tt.sum(factors)
@@ -1110,9 +1092,7 @@ def add_coords(self, coords):
)
if name in self.coords:
if not coords[name].equals(self.coords[name]):
raise ValueError(
"Duplicate and incompatiple coordinate: %s." % name
)
raise ValueError("Duplicate and incompatiple coordinate: %s." % name)
else:
self.coords[name] = coords[name]

@@ -1141,9 +1121,7 @@ def Var(self, name, dist, data=None, total_size=None, dims=None):
if data is None:
if getattr(dist, "transform", None) is None:
with self:
var = FreeRV(
name=name, distribution=dist, total_size=total_size, model=self
)
var = FreeRV(name=name, distribution=dist, total_size=total_size, model=self)
self.free_RVs.append(var)
else:
with self:
@@ -1218,8 +1196,7 @@ def prefix(self):
return "%s_" % self.name if self.name else ""

def name_for(self, name):
"""Checks if name has prefix and adds if needed
"""
"""Checks if name has prefix and adds if needed"""
if self.prefix:
if not name.startswith(self.prefix):
return f"{self.prefix}{name}"
@@ -1229,8 +1206,7 @@ def name_for(self, name):
return name

def name_of(self, name):
"""Checks if name has prefix and deletes if needed
"""
"""Checks if name has prefix and deletes if needed"""
if not self.prefix or not name:
return name
elif name.startswith(self.prefix):
@@ -1269,7 +1245,7 @@ def makefn(self, outs, mode=None, *args, **kwargs):
accept_inplace=True,
mode=mode,
*args,
**kwargs
**kwargs,
)

def fn(self, outs, mode=None, *args, **kwargs):
@@ -1391,10 +1367,7 @@ def check_test_point(self, test_point=None, round_vals=2):
test_point = self.test_point

return Series(
{
RV.name: np.round(RV.logp(self.test_point), round_vals)
for RV in self.basic_RVs
},
{RV.name: np.round(RV.logp(test_point), round_vals) for RV in self.basic_RVs},
name="Log-probability of test_point",
)

@@ -1403,23 +1376,31 @@ def _str_repr(self, formatting="plain", **kwargs):

if formatting == "latex":
rv_reprs = [rv.__latex__() for rv in all_rv]
rv_reprs = [rv_repr.replace(r"\sim", r"&\sim &").strip("$")
for rv_repr in rv_reprs if rv_repr is not None]
rv_reprs = [
rv_repr.replace(r"\sim", r"&\sim &").strip("$")
for rv_repr in rv_reprs
if rv_repr is not None
]
return r"""$$
\begin{{array}}{{rcl}}
{}
\end{{array}}
$$""".format(
"\\\\".join(rv_reprs))
"\\\\".join(rv_reprs)
)
else:
rv_reprs = [rv.__str__() for rv in all_rv]
rv_reprs = [rv_repr for rv_repr in rv_reprs if not 'TransformedDistribution()' in rv_repr]
rv_reprs = [
rv_repr for rv_repr in rv_reprs if "TransformedDistribution()" not in rv_repr
]
# align vars on their ~
names = [s[:s.index('~')-1] for s in rv_reprs]
distrs = [s[s.index('~')+2:] for s in rv_reprs]
names = [s[: s.index("~") - 1] for s in rv_reprs]
distrs = [s[s.index("~") + 2 :] for s in rv_reprs]
maxlen = str(max(len(x) for x in names))
rv_reprs = [('{name:>' + maxlen + '} ~ {distr}').format(name=n, distr=d)
for n, d in zip(names, distrs)]
rv_reprs = [
("{name:>" + maxlen + "} ~ {distr}").format(name=n, distr=d)
for n, d in zip(names, distrs)
]
return "\n".join(rv_reprs)

def __str__(self, **kwargs):
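
The reformatted alignment code right-justifies each variable name to the width of the longest one so that every ~ lines up in the model repr. The same technique, reduced to a self-contained sketch with made-up variable names:

    rv_reprs = ["alpha ~ Normal", "sigma_log__ ~ Lognormal", "y ~ Normal"]
    names = [s[: s.index("~") - 1] for s in rv_reprs]
    distrs = [s[s.index("~") + 2 :] for s in rv_reprs]
    maxlen = str(max(len(n) for n in names))
    aligned = [
        ("{name:>" + maxlen + "} ~ {distr}").format(name=n, distr=d)
        for n, d in zip(names, distrs)
    ]
    print("\n".join(aligned))
    # Output:
    #       alpha ~ Normal
    # sigma_log__ ~ Lognormal
    #           y ~ Normal
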
@@ -1537,8 +1518,9 @@ def Point(*args, **kwargs):
except Exception as e:
raise TypeError(f"can't turn {args} and {kwargs} into a dict. {e}")
return {
get_var_name(k): np.array(v) for k, v in d.items()
if get_var_name(k) in map(get_var_name, model.vars)
get_var_name(k): np.array(v)
for k, v in d.items()
if get_var_name(k) in map(get_var_name, model.vars)
}


@@ -1593,11 +1575,7 @@ def _get_scaling(total_size, shape, ndim):
denom = 1
coef = floatX(total_size) / floatX(denom)
elif isinstance(total_size, (list, tuple)):
if not all(
isinstance(i, int)
for i in total_size
if (i is not Ellipsis and i is not None)
):
if not all(isinstance(i, int) for i in total_size if (i is not Ellipsis and i is not None)):
raise TypeError(
"Unrecognized `total_size` type, expected "
"int or list of ints, got %r" % total_size
@@ -1625,16 +1603,13 @@ def _get_scaling(total_size, shape, ndim):
else:
shp_end = np.asarray([])
shp_begin = shape[: len(begin)]
begin_coef = [
floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None
]
begin_coef = [floatX(t) / shp_begin[i] for i, t in enumerate(begin) if t is not None]
end_coef = [floatX(t) / shp_end[i] for i, t in enumerate(end) if t is not None]
coefs = begin_coef + end_coef
coef = tt.prod(coefs)
else:
raise TypeError(
"Unrecognized `total_size` type, expected "
"int or list of ints, got %r" % total_size
"Unrecognized `total_size` type, expected int or list of ints, got %r" % total_size
)
return tt.as_tensor(floatX(coef))
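
For context, _get_scaling computes the factor that rescales a minibatch log-probability up to the full dataset: each element's logp is effectively weighted by total_size over the observed batch size along the given dimensions. A worked scalar example with illustrative numbers:

    total_size = 10_000             # rows in the full dataset
    batch_size = 100                # rows in the current minibatch
    coef = total_size / batch_size  # 100.0: each minibatch term stands in for 100 rows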

@@ -1753,9 +1728,7 @@ def as_tensor(data, name, model, distribution):
testval=testval,
parent_dist=distribution,
)
missing_values = FreeRV(
name=name + "_missing", distribution=fakedist, model=model
)
missing_values = FreeRV(name=name + "_missing", distribution=fakedist, model=model)
constant = tt.as_tensor_variable(data.filled())

dataTensor = tt.set_subtensor(constant[data.mask.nonzero()], missing_values)
@@ -1854,14 +1827,11 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
"""
self.name = name
self.data = {
name: as_tensor(data, name, model, distribution)
for name, data in data.items()
name: as_tensor(data, name, model, distribution) for name, data in data.items()
}

self.missing_values = [
datum.missing_values
for datum in self.data.values()
if datum.missing_values is not None
datum.missing_values for datum in self.data.values() if datum.missing_values is not None
]
self.logp_elemwiset = distribution.logp(**self.data)
# The logp might need scaling in minibatches.
@@ -1871,9 +1841,7 @@ def __init__(self, name, data, distribution, total_size=None, model=None):
self.total_size = total_size
self.model = model
self.distribution = distribution
self.scaling = _get_scaling(
total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim
)
self.scaling = _get_scaling(total_size, self.logp_elemwiset.shape, self.logp_elemwiset.ndim)

# Make hashable by id for draw_values
def __hash__(self):
@@ -1888,7 +1856,7 @@ def __ne__(self, other):
return not self == other


def _walk_up_rv(rv, formatting='plain'):
def _walk_up_rv(rv, formatting="plain"):
"""Walk up theano graph to get inputs for deterministic RV."""
all_rvs = []
parents = list(itertools.chain(*[j.inputs for j in rv.get_parents()]))
@@ -1903,21 +1871,23 @@ def _walk_up_rv(rv, formatting='plain'):


class DeterministicWrapper(tt.TensorVariable):
def _str_repr(self, formatting='plain'):
if formatting == 'latex':
def _str_repr(self, formatting="plain"):
if formatting == "latex":
return r"$\text{{{name}}} \sim \text{{Deterministic}}({args})$".format(
name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting)))
name=self.name, args=r",~".join(_walk_up_rv(self, formatting=formatting))
)
else:
return "{name} ~ Deterministic({args})".format(
name=self.name, args=", ".join(_walk_up_rv(self, formatting=formatting)))
name=self.name, args=", ".join(_walk_up_rv(self, formatting=formatting))
)

def _repr_latex_(self):
return self._str_repr(formatting='latex')
return self._str_repr(formatting="latex")

__latex__ = _repr_latex_

def __str__(self):
return self._str_repr(formatting='plain')
return self._str_repr(formatting="plain")


def Deterministic(name, var, model=None, dims=None):
@@ -1936,7 +1906,7 @@ def Deterministic(name, var, model=None, dims=None):
var = var.copy(model.name_for(name))
model.deterministics.append(var)
model.add_random_variable(var, dims)
var.__class__ = DeterministicWrapper # adds str and latex functionality
var.__class__ = DeterministicWrapper # adds str and latex functionality

return var

@@ -2030,7 +2000,7 @@ def as_iterargs(data):

def all_continuous(vars):
"""Check that vars not include discrete variables, excepting
ObservedRVs. """
ObservedRVs."""
vars_ = [var for var in vars if not isinstance(var, pm.model.ObservedRV)]
if any([var.dtype in pm.discrete_types for var in vars_]):
return False