Skip to content

Commit 91b91c2

Browse files
Re-enable Arviz parts of sampler fixtures
1 parent 8915d62 commit 91b91c2

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pymc3/tests/sampler_fixtures.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@
1616
import arviz as az
1717
import numpy as np
1818
import numpy.testing as npt
19-
import pytest
2019

2120
from scipy import stats
2221

2322
import pymc3 as pm
2423

24+
from pymc3.backends.arviz import to_inference_data
2525
from pymc3.tests.helpers import SeededTest
2626
from pymc3.util import get_var_name
2727

@@ -153,16 +153,16 @@ def setup_class(cls):
153153
for var in cls.model.unobserved_RVs:
154154
cls.samples[get_var_name(var)] = cls.trace.get_values(var, burn=cls.burn)
155155

156-
@pytest.mark.xfail(reason="Arviz not refactored for v4")
157156
def test_neff(self):
158157
if hasattr(self, "min_n_eff"):
159-
n_eff = az.ess(self.trace[self.burn :])
158+
idata = to_inference_data(self.trace[self.burn :])
159+
n_eff = az.ess(idata)
160160
for var in n_eff:
161161
npt.assert_array_less(self.min_n_eff, n_eff[var])
162162

163-
@pytest.mark.xfail(reason="Arviz not refactored for v4")
164163
def test_Rhat(self):
165-
rhat = az.rhat(self.trace[self.burn :])
164+
idata = to_inference_data(self.trace[self.burn :])
165+
rhat = az.rhat(idata)
166166
for var in rhat:
167167
npt.assert_allclose(rhat[var], 1, rtol=0.01)
168168

0 commit comments

Comments
 (0)