Skip to content

Remove initval from dist() API and add docstrings, type hints, tests #4913

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
# → pytest will run only these files
- |
--ignore=pymc3/tests/test_distributions_timeseries.py
--ignore=pymc3/tests/test_initvals.py
--ignore=pymc3/tests/test_mixture.py
--ignore=pymc3/tests/test_model_graph.py
--ignore=pymc3/tests/test_modelcontext.py
Expand Down Expand Up @@ -60,7 +61,9 @@ jobs:
--ignore=pymc3/tests/test_distributions_random.py
--ignore=pymc3/tests/test_idata_conversion.py

- pymc3/tests/test_distributions.py
- |
pymc3/tests/test_initvals.py
pymc3/tests/test_distributions.py

- |
pymc3/tests/test_modelcontext.py
Expand Down Expand Up @@ -153,6 +156,7 @@ jobs:
floatx: [float32, float64]
test-subset:
- |
pymc3/tests/test_initvals.py
pymc3/tests/test_distributions_random.py
pymc3/tests/test_distributions_timeseries.py
- |
Expand Down
2 changes: 1 addition & 1 deletion RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
- ⚠ PyMC3 now requires Scipy version `>= 1.4.1` (see [4857](https://github.com/pymc-devs/pymc3/pull/4857)).
- ArviZ `plots` and `stats` *wrappers* were removed. The functions are now just available by their original names (see [#4549](https://github.com/pymc-devs/pymc3/pull/4471) and `3.11.2` release notes).
- The GLM submodule has been removed, please use [Bambi](https://bambinos.github.io/bambi/) instead.
- The `Distribution` keyword argument `testval` has been deprecated in favor of `initval`.
- The `Distribution` keyword argument `testval` has been deprecated in favor of `initval`. Furthermore `initval` no longer assigns a `tag.test_value` on tensors since the initial values are now kept track of by the model object ([see #4913](https://github.com/pymc-devs/pymc3/pull/4913)).
- `pm.sample` now returns results as `InferenceData` instead of `MultiTrace` by default (see [#4744](https://github.com/pymc-devs/pymc3/pull/4744)).
- `pm.sample_prior_predictive` no longer returns transformed variable values by default. Pass them by name in `var_names` if you want to obtain these draws (see [4769](https://github.com/pymc-devs/pymc3/pull/4769)).
- ...
Expand Down
12 changes: 4 additions & 8 deletions pymc3/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,9 @@ class Flat(Continuous):
rv_op = flat

@classmethod
def dist(cls, *, size=None, initval=None, **kwargs):
if initval is None:
initval = np.full(size, floatX(0.0))
def dist(cls, *, size=None, **kwargs):
res = super().dist([], size=size, **kwargs)
res.tag.test_value = initval
res.tag.test_value = np.full(size, floatX(0.0))
return res

def logp(value):
Expand Down Expand Up @@ -425,11 +423,9 @@ class HalfFlat(PositiveContinuous):
rv_op = halfflat

@classmethod
def dist(cls, *, size=None, initval=None, **kwargs):
if initval is None:
initval = np.full(size, floatX(1.0))
def dist(cls, *, size=None, **kwargs):
res = super().dist([], size=size, **kwargs)
res.tag.test_value = initval
res.tag.test_value = np.full(size, floatX(1.0))
return res

def logp(value):
Expand Down
37 changes: 20 additions & 17 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ def __new__(
)
dims = convert_dims(dims)

# Create the RV without specifying initval, because the initval may have a shape
# that only matches after replicating with a size implied by dims (see below).
rv_out = cls.dist(*args, rng=rng, initval=None, **kwargs)
# Create the RV without dims information, because that's not something tracked at the Aesara level.
# If necessary we'll later replicate to a different size implied by already known dims.
rv_out = cls.dist(*args, rng=rng, **kwargs)
ndim_actual = rv_out.ndim
resize_shape = None

Expand All @@ -219,12 +219,14 @@ def __new__(
# A batch size was specified through `dims`, or implied by `observed`.
rv_out = change_rv_size(rv_var=rv_out, new_size=resize_shape, expand=True)

if initval is not None:
# Assigning the testval earlier causes trouble because the RV may not be created with the final shape already.
rv_out.tag.test_value = initval

rv_out = model.register_rv(
rv_out, name, observed, total_size, dims=dims, transform=transform
rv_out,
name,
observed,
total_size,
dims=dims,
transform=transform,
initval=initval,
)

# add in pretty-printing support
Expand All @@ -242,7 +244,6 @@ def dist(
*,
shape: Optional[Shape] = None,
size: Optional[Size] = None,
initval=None,
**kwargs,
) -> RandomVariable:
"""Creates a RandomVariable corresponding to the `cls` distribution.
Expand All @@ -258,25 +259,27 @@ def dist(
all the dimensions that the RV would get if no shape/size/dims were passed at all.
size : int, tuple, Variable, optional
For creating the RV like in Aesara/NumPy.
initival : optional
Test value to be attached to the output RV.
Must match its shape exactly.

Returns
-------
rv : RandomVariable
The created RV.
"""
if "testval" in kwargs:
initval = kwargs.pop("testval")
kwargs.pop("testval")
warnings.warn(
"The `testval` argument is deprecated. "
"Use `initval` to set initial values for a `Model`; "
"otherwise, set test values on Aesara parameters explicitly "
"when attempting to use Aesara's test value debugging features.",
"The `.dist(testval=...)` argument is deprecated and has no effect. "
"Initial values for sampling/optimization can be specified with `initval` in a modelcontext. "
"For using Aesara's test value features, you must assign the `.tag.test_value` yourself.",
DeprecationWarning,
stacklevel=2,
)
if "initval" in kwargs:
raise TypeError(
"Unexpected keyword argument `initval`. "
"This argument is not available for the `.dist()` API."
)

if "dims" in kwargs:
raise NotImplementedError("The use of a `.dist(dims=...)` API is not supported.")
if shape is not None and size is not None:
Expand Down
47 changes: 34 additions & 13 deletions pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import scipy.sparse as sps

from aesara.compile.sharedvalue import SharedVariable
from aesara.gradient import grad
from aesara.graph.basic import Constant, Variable, graph_inputs
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.opt import local_subtensor_rv_lift
Expand Down Expand Up @@ -446,7 +445,7 @@ def __init__(
givens.append((var, shared))

if compute_grads:
grads = grad(cost, grad_vars, disconnected_inputs="ignore")
grads = aesara.grad(cost, grad_vars, disconnected_inputs="ignore")
for grad_wrt, var in zip(grads, grad_vars):
grad_wrt.name = f"{var.name}_grad"
outputs = [cost] + grads
Expand Down Expand Up @@ -648,7 +647,7 @@ def __init__(

# The sequence of model-generated RNGs
self.rng_seq = []
self.initial_values = {}
self._initial_values = {}

if self.parent is not None:
self.named_vars = treedict(parent=self.parent.named_vars)
Expand Down Expand Up @@ -912,17 +911,28 @@ def independent_vars(self):
return inputvars(self.unobserved_RVs)

@property
def test_point(self):
def test_point(self) -> Dict[str, np.ndarray]:
"""Deprecated alias for `Model.initial_point`."""
warnings.warn(
"`Model.test_point` has been deprecated. Use `Model.initial_point` instead.",
DeprecationWarning,
)
return self.initial_point

@property
def initial_point(self):
def initial_point(self) -> Dict[str, np.ndarray]:
"""Maps names of variables to initial values."""
return Point(list(self.initial_values.items()), model=self)

@property
def initial_values(self) -> Dict[TensorVariable, np.ndarray]:
"""Maps transformed variables to initial values.

⚠ The keys are NOT the objects returned by, `pm.Normal(...)`.
For a name-based dictionary use the `initial_point` property.
"""
return self._initial_values

@property
def disc_vars(self):
"""All the discrete variables in the model"""
Expand All @@ -934,11 +944,10 @@ def cont_vars(self):
return list(typefilter(self.value_vars, continuous_types))

def set_initval(self, rv_var, initval):
initval = (
rv_var.type.filter(initval)
if initval is not None
else getattr(rv_var.tag, "test_value", None)
)
if initval is not None:
initval = rv_var.type.filter(initval)

test_value = getattr(rv_var.tag, "test_value", None)

rv_value_var = self.rvs_to_values[rv_var]
transform = getattr(rv_value_var.tag, "transform", None)
Expand Down Expand Up @@ -972,7 +981,17 @@ def initval_to_rvval(value_var, value):
initval_fn = aesara.function(
[], rv_var, mode=mode, givens=givens, on_unused_input="ignore"
)
initval = initval_fn()
try:
initval = initval_fn()
except NotImplementedError as ex:
if "Cannot sample from" in ex.args[0]:
# The RV does not have a random number generator.
# Our last chance is to take the test_value.
# Note that this is a workaround for Flat and HalfFlat
# until an initval default mechanism is implemented (#4752).
initval = test_value
else:
raise

self.initial_values[rv_value_var] = initval

Expand Down Expand Up @@ -1530,7 +1549,7 @@ def update_start_vals(self, a: Dict[str, np.ndarray], b: Dict[str, np.ndarray]):
r"""Update point `a` with `b`, without overwriting existing keys.

Values specified for transformed variables in `a` will be recomputed
conditional on the valures of `b` and stored in `b`.
conditional on the values of `b` and stored in `b`.

"""
# TODO FIXME XXX: If we're going to incrementally update transformed
Expand Down Expand Up @@ -1717,14 +1736,16 @@ def fastfn(outs, mode=None, model=None):
return model.fastfn(outs, mode)


def Point(*args, filter_model_vars=False, **kwargs):
def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
"""Build a point. Uses same args as dict() does.
Filters out variables not in the model. All keys are strings.

Parameters
----------
args, kwargs
arguments to build a dict
filter_model_vars : bool
If `True`, only model variables are included in the result.
"""
model = modelcontext(kwargs.pop("model", None))
args = list(args)
Expand Down
49 changes: 49 additions & 0 deletions pymc3/tests/test_initvals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Copyright 2020 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

import pymc3 as pm


def transform_fwd(rv, expected_untransformed):
return rv.tag.value_var.tag.transform.forward(rv, expected_untransformed).eval()


class TestInitvalAssignment:
def test_dist_warnings_and_errors(self):
with pytest.warns(DeprecationWarning, match="argument is deprecated and has no effect"):
rv = pm.Exponential.dist(lam=1, testval=0.5)
assert not hasattr(rv.tag, "test_value")

with pytest.raises(TypeError, match="Unexpected keyword argument `initval`."):
pm.Normal.dist(1, 2, initval=None)
pass

def test_new_warnings(self):
with pm.Model() as pmodel:
with pytest.warns(DeprecationWarning, match="`testval` argument is deprecated"):
rv = pm.Uniform("u", 0, 1, testval=0.75)
assert pmodel.initial_values[rv.tag.value_var] == transform_fwd(rv, 0.75)
assert not hasattr(rv.tag, "test_value")
pass


class TestSpecialDistributions:
def test_automatically_assigned_test_values(self):
# ...because they don't have random number generators.
rv = pm.Flat.dist()
assert hasattr(rv.tag, "test_value")
rv = pm.HalfFlat.dist()
assert hasattr(rv.tag, "test_value")
pass
11 changes: 10 additions & 1 deletion pymc3/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import pymc3 as pm

from pymc3.distributions.transforms import Transform
from pymc3.util import hash_key, hashable, locally_cachedmethod
from pymc3.util import UNSET, hash_key, hashable, locally_cachedmethod


class TestTransformName:
Expand Down Expand Up @@ -127,3 +127,12 @@ def some_method(self, x):

tc = TestClass()
assert tc.some_method(b1) != tc.some_method(b2)


def test_unset_repr(capsys):
def fn(a=UNSET):
return

help(fn)
captured = capsys.readouterr()
assert "a=UNSET" in captured.out
17 changes: 14 additions & 3 deletions pymc3/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,18 @@

from cachetools import LRUCache, cachedmethod

UNSET = object()

class _UnsetType:
"""Type for the `UNSET` object to make it look nice in `help(...)` outputs."""

def __str__(self):
return "UNSET"

def __repr__(self):
return str(self)


UNSET = _UnsetType()


def withparent(meth):
Expand Down Expand Up @@ -191,9 +202,9 @@ def get_default_varnames(var_iterator, include_transformed):
return [var for var in var_iterator if not is_transformed_name(get_var_name(var))]


def get_var_name(var):
def get_var_name(var) -> str:
"""Get an appropriate, plain variable name for a variable."""
return getattr(var, "name", str(var))
return str(getattr(var, "name", var))


def get_transformed(z):
Expand Down