Skip to content

Commit d60d753

Browse files
aseyboldtColCarroll
authored andcommitted
Fix NUTS.check_trace for several samplers (#2145)
1 parent 4822cf0 commit d60d753

File tree

2 files changed

+22
-7
lines changed

2 files changed

+22
-7
lines changed

pymc3/step_methods/hmc/nuts.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,19 +215,19 @@ def check_trace(self, strace):
215215

216216
diverging = strace.get_sampler_stats('diverging')
217217
if diverging.ndim == 2:
218-
diverging = np.any(diverging, axis=0)
218+
diverging = np.any(diverging, axis=-1)
219219

220220
tuning = strace.get_sampler_stats('tune')
221221
if tuning.ndim == 2:
222-
tuning = np.any(tuning, axis=0)
222+
tuning = np.any(tuning, axis=-1)
223223

224224
accept = strace.get_sampler_stats('mean_tree_accept')
225225
if accept.ndim == 2:
226-
accept = np.mean(accept, axis=0)
226+
accept = np.mean(accept, axis=-1)
227227

228228
depth = strace.get_sampler_stats('depth')
229229
if depth.ndim == 2:
230-
depth = np.max(depth, axis=0)
230+
depth = np.max(depth, axis=-1)
231231

232232
n_samples = n - (~tuning).sum()
233233

pymc3/tests/test_step.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import shutil
22
import tempfile
3+
import warnings
34

45
from .checks import close_to
56
from .models import simple_categorical, mv_simple, mv_simple_discrete, simple_2model, mv_prior_simple
@@ -9,7 +10,7 @@
910
Metropolis, Slice, CompoundStep, NormalProposal,
1011
MultivariateNormalProposal, HamiltonianMC,
1112
EllipticalSlice, smc)
12-
from pymc3.distributions import Binomial, Normal, Bernoulli, Categorical
13+
from pymc3.distributions import Binomial, Normal, Bernoulli, Categorical, Beta
1314

1415
from numpy.testing import assert_array_almost_equal
1516
import numpy as np
@@ -268,15 +269,17 @@ def test_mv_proposal(self):
268269
class TestCompoundStep(object):
269270
samplers = (Metropolis, Slice, HamiltonianMC, NUTS)
270271

271-
@pytest.mark.skipif(theano.config.floatX == "float32", reason="Test fails on 32 bit due to linalg issues")
272+
@pytest.mark.skipif(theano.config.floatX == "float32",
273+
reason="Test fails on 32 bit due to linalg issues")
272274
def test_non_blocked(self):
273275
"""Test that samplers correctly create non-blocked compound steps."""
274276
_, model = simple_2model()
275277
with model:
276278
for sampler in self.samplers:
277279
assert isinstance(sampler(blocked=False), CompoundStep)
278280

279-
@pytest.mark.skipif(theano.config.floatX == "float32", reason="Test fails on 32 bit due to linalg issues")
281+
@pytest.mark.skipif(theano.config.floatX == "float32",
282+
reason="Test fails on 32 bit due to linalg issues")
280283
def test_blocked(self):
281284
_, model = simple_2model()
282285
with model:
@@ -318,3 +321,15 @@ def test_binomial(self):
318321
Binomial('x', 10, 0.5)
319322
steps = assign_step_methods(model, [])
320323
assert isinstance(steps, Metropolis)
324+
325+
326+
class TestNutsCheckTrace(object):
327+
def test_multiple_samplers(self):
328+
with Model():
329+
prob = Beta('prob', alpha=5, beta=3)
330+
Binomial('outcome', n=1, p=prob)
331+
with warnings.catch_warnings(record=True) as warns:
332+
sample(5, n_init=None, tune=2)
333+
messages = [warn.message.args[0] for warn in warns]
334+
assert any("contains only 5" in msg for msg in messages)
335+
assert all('boolean index did not' not in msg for msg in messages)

0 commit comments

Comments
 (0)