|
1 | 1 | import shutil
|
2 | 2 | import tempfile
|
| 3 | +import warnings |
3 | 4 |
|
4 | 5 | from .checks import close_to
|
5 | 6 | from .models import simple_categorical, mv_simple, mv_simple_discrete, simple_2model, mv_prior_simple
|
|
9 | 10 | Metropolis, Slice, CompoundStep, NormalProposal,
|
10 | 11 | MultivariateNormalProposal, HamiltonianMC,
|
11 | 12 | EllipticalSlice, smc)
|
12 |
| -from pymc3.distributions import Binomial, Normal, Bernoulli, Categorical |
| 13 | +from pymc3.distributions import Binomial, Normal, Bernoulli, Categorical, Beta |
13 | 14 |
|
14 | 15 | from numpy.testing import assert_array_almost_equal
|
15 | 16 | import numpy as np
|
@@ -268,15 +269,17 @@ def test_mv_proposal(self):
|
268 | 269 | class TestCompoundStep(object):
|
269 | 270 | samplers = (Metropolis, Slice, HamiltonianMC, NUTS)
|
270 | 271 |
|
271 |
| - @pytest.mark.skipif(theano.config.floatX == "float32", reason="Test fails on 32 bit due to linalg issues") |
| 272 | + @pytest.mark.skipif(theano.config.floatX == "float32", |
| 273 | + reason="Test fails on 32 bit due to linalg issues") |
272 | 274 | def test_non_blocked(self):
|
273 | 275 | """Test that samplers correctly create non-blocked compound steps."""
|
274 | 276 | _, model = simple_2model()
|
275 | 277 | with model:
|
276 | 278 | for sampler in self.samplers:
|
277 | 279 | assert isinstance(sampler(blocked=False), CompoundStep)
|
278 | 280 |
|
279 |
| - @pytest.mark.skipif(theano.config.floatX == "float32", reason="Test fails on 32 bit due to linalg issues") |
| 281 | + @pytest.mark.skipif(theano.config.floatX == "float32", |
| 282 | + reason="Test fails on 32 bit due to linalg issues") |
280 | 283 | def test_blocked(self):
|
281 | 284 | _, model = simple_2model()
|
282 | 285 | with model:
|
@@ -318,3 +321,15 @@ def test_binomial(self):
|
318 | 321 | Binomial('x', 10, 0.5)
|
319 | 322 | steps = assign_step_methods(model, [])
|
320 | 323 | assert isinstance(steps, Metropolis)
|
| 324 | + |
| 325 | + |
| 326 | +class TestNutsCheckTrace(object): |
| 327 | + def test_multiple_samplers(self): |
| 328 | + with Model(): |
| 329 | + prob = Beta('prob', alpha=5, beta=3) |
| 330 | + Binomial('outcome', n=1, p=prob) |
| 331 | + with warnings.catch_warnings(record=True) as warns: |
| 332 | + sample(5, n_init=None, tune=2) |
| 333 | + messages = [warn.message.args[0] for warn in warns] |
| 334 | + assert any("contains only 5" in msg for msg in messages) |
| 335 | + assert all('boolean index did not' not in msg for msg in messages) |
0 commit comments