Skip to content

Commit 9f8d459

Browse files
Use untransformed samples and xfail Arviz tests in BaseSampler
1 parent 1c68482 commit 9f8d459

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

pymc3/tests/sampler_fixtures.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import arviz as az
1717
import numpy as np
1818
import numpy.testing as npt
19+
import pytest
1920

2021
from scipy import stats
2122

@@ -140,17 +141,26 @@ def setup_class(cls):
140141
cls.model = cls.make_model()
141142
with cls.model:
142143
cls.step = cls.make_step()
143-
cls.trace = pm.sample(cls.n_samples, tune=cls.tune, step=cls.step, cores=cls.chains)
144+
cls.trace = pm.sample(
145+
cls.n_samples,
146+
tune=cls.tune,
147+
step=cls.step,
148+
cores=cls.chains,
149+
return_inferencedata=False,
150+
compute_convergence_checks=False,
151+
)
144152
cls.samples = {}
145153
for var in cls.model.unobserved_RVs:
146-
cls.samples[get_var_name(var)] = cls.trace.get_values(var.tag.value_var, burn=cls.burn)
154+
cls.samples[get_var_name(var)] = cls.trace.get_values(var, burn=cls.burn)
147155

156+
@pytest.mark.xfail(reason="Arviz not refactored for v4")
148157
def test_neff(self):
149158
if hasattr(self, "min_n_eff"):
150159
n_eff = az.ess(self.trace[self.burn :])
151160
for var in n_eff:
152161
npt.assert_array_less(self.min_n_eff, n_eff[var])
153162

163+
@pytest.mark.xfail(reason="Arviz not refactored for v4")
154164
def test_Rhat(self):
155165
rhat = az.rhat(self.trace[self.burn :])
156166
for var in rhat:

pymc3/tests/test_posteriors.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ class TestNUTSBetaBinomial(sf.NutsFixture, sf.BetaBinomialFixture):
7676
min_n_eff = 400
7777

7878

79+
@pytest.mark.xfail(reason="StudentT not refactored for v4")
7980
class TestNUTSStudentT(sf.NutsFixture, sf.StudentTFixture):
8081
n_samples = 10000
8182
tune = 1000
@@ -97,6 +98,7 @@ class TestNUTSNormalLong(sf.NutsFixture, sf.NormalFixture):
9798
atol = 0.001
9899

99100

101+
@pytest.mark.xfail(reason="StudentT not refactored for v4")
100102
class TestNUTSLKJCholeskyCov(sf.NutsFixture, sf.LKJCholeskyCovFixture):
101103
n_samples = 2000
102104
tune = 1000

0 commit comments

Comments
 (0)