Skip to content

Improve sampling coverage #4270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,10 @@ jobs:
pymc3/tests/test_distributions_timeseries.py
pymc3/tests/test_parallel_sampling.py
pymc3/tests/test_random.py
pymc3/tests/test_sampling.py
pymc3/tests/test_shared.py
pymc3/tests/test_smc.py
- |
pymc3/tests/test_examples.py
pymc3/tests/test_gp.py
pymc3/tests/test_mixture.py
pymc3/tests/test_posteriors.py
pymc3/tests/test_quadpotential.py
Expand All @@ -54,6 +52,8 @@ jobs:
pymc3/tests/test_variational_inference.py
- |
pymc3/tests/test_distributions.py
pymc3/tests/test_gp.py
pymc3/tests/test_sampling.py
runs-on: ${{ matrix.os }}
env:
TEST_SUBSET: ${{ matrix.test-subset }}
Expand Down
24 changes: 6 additions & 18 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@

"""Functions for MCMC sampling."""

from typing import Dict, List, Optional, TYPE_CHECKING, cast, Union, Any

if TYPE_CHECKING:
from typing import Tuple
from typing import Dict, List, Optional, cast, Union, Any
from typing import Iterable as TIterable
from collections.abc import Iterable
from collections import defaultdict
Expand Down Expand Up @@ -218,11 +215,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None


def _print_step_hierarchy(s, level=0):
if isinstance(s, (list, tuple)):
_log.info(">" * level + "list")
for i in s:
_print_step_hierarchy(i, level + 1)
elif isinstance(s, CompoundStep):
if isinstance(s, CompoundStep):
_log.info(">" * level + "CompoundStep")
for i in s.methods:
_print_step_hierarchy(i, level + 1)
Expand Down Expand Up @@ -458,7 +451,7 @@ def sample(

if return_inferencedata is None:
v = packaging.version.parse(pm.__version__)
if v.release[0] > 3 or v.release[1] >= 10:
if v.release[0] > 3 or v.release[1] >= 10: # type: ignore
warnings.warn(
"In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. "
"You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.",
Expand Down Expand Up @@ -585,7 +578,7 @@ def sample(
UserWarning,
)
_print_step_hierarchy(step)
trace = _sample_population(**sample_args, parallelize=cores > 1)
trace = _sample_population(parallelize=cores > 1, **sample_args)
else:
_log.info(f"Sequential sampling ({chains} chains in 1 job)")
_print_step_hierarchy(step)
Expand Down Expand Up @@ -770,11 +763,9 @@ def _sample_population(
trace : MultiTrace
Contains samples of all chains
"""
# create the generator that iterates all chains in parallel
chains = [chain + c for c in range(chains)]
sampling = _prepare_iter_population(
draws,
chains,
[chain + c for c in range(chains)],
step,
start,
parallelize,
Expand Down Expand Up @@ -1582,10 +1573,7 @@ def insert(self, k: str, v, idx: int):
ids: int
The index of the sample we are inserting into the trace.
"""
if hasattr(v, "shape"):
value_shape = tuple(v.shape) # type: Tuple[int, ...]
else:
value_shape = ()
value_shape = np.shape(v)

# initialize if necessary
if k not in self.trace_dict:
Expand Down
30 changes: 17 additions & 13 deletions pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,9 @@
# limitations under the License.

from itertools import combinations
import packaging
from typing import Tuple
import numpy as np

try:
import unittest.mock as mock # py3
except ImportError:
from unittest import mock
import unittest.mock as mock

import numpy.testing as npt
import arviz as az
Expand Down Expand Up @@ -180,13 +175,9 @@ def test_trace_report_bart(self):
assert var_imp[0] > var_imp[1:].sum()
npt.assert_almost_equal(var_imp.sum(), 1)

def test_return_inferencedata(self):
def test_return_inferencedata(self, monkeypatch):
with self.model:
kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=pm.Metropolis())
v = packaging.version.parse(pm.__version__)
if v.major > 3 or v.minor >= 10:
with pytest.warns(FutureWarning, match="pass return_inferencedata"):
result = pm.sample(**kwargs)

# trace with tuning
with pytest.warns(UserWarning, match="will be included"):
Expand All @@ -203,12 +194,25 @@ def test_return_inferencedata(self):
assert result.posterior.sizes["chain"] == 2
assert len(result._groups_warmup) > 0

# inferencedata without tuning
result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=True)
# inferencedata without tuning, with idata_kwargs
prior = pm.sample_prior_predictive()
result = pm.sample(
**kwargs,
return_inferencedata=True,
discard_tuned_samples=True,
idata_kwargs={"prior": prior},
random_seed=-1
)
assert "prior" in result
assert isinstance(result, az.InferenceData)
assert result.posterior.sizes["draw"] == 100
assert result.posterior.sizes["chain"] == 2
assert len(result._groups_warmup) == 0

# check warning for version 3.10 onwards
monkeypatch.setattr("pymc3.__version__", "3.10")
with pytest.warns(FutureWarning, match="pass return_inferencedata"):
result = pm.sample(**kwargs)
pass
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unnecessary `pass` statement.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like to put them there to mark the end of a test/function (just like an unnecessary return).

In my opinion they have a few advantages:

  • trailing comments are collapsed together with the code
  • because the `pass` in a test (or the `return` in a function) was put there on purpose, readers don't have to wonder whether somebody just stopped writing or accidentally deleted some code

Not saying you must put it back in — just offering a perspective that may be new to you :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me that looks weird and confusing.


@pytest.mark.parametrize("cores", [1, 2])
Expand Down
7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
[tool.black]
line-length = 100

[tool.coverage.report]
exclude_lines = [
"pragma: nocover",
"raise NotImplementedError",
"if TYPE_CHECKING:",
]

[tool.nbqa.mutate]
isort = 1
black = 1
Expand Down