Skip to content

Commit c67dd8b

Browse files
committed
Move inference test out of unrelated class
1 parent 1b615c8 commit c67dd8b

File tree

1 file changed

+13
-16
lines changed

1 file changed

+13
-16
lines changed

pymc/tests/test_distributions_timeseries.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from pymc.tests.test_distributions_random import BaseTestDistributionRandom
3232

3333

34-
class TestGaussianRandomWalk(BaseTestDistributionRandom):
34+
class TestGaussianRandomWalkRandom(BaseTestDistributionRandom):
3535
# Override default size for test class
3636
size = None
3737

@@ -54,26 +54,23 @@ def check_rv_inferred_size(self):
5454
expected_symbolic = tuple(pymc_rv.shape.eval())
5555
assert expected_symbolic == expected
5656

57-
def check_not_implemented(self):
58-
with pytest.raises(NotImplementedError):
59-
self.pymc_rv.eval()
6057

61-
def test_grw_inference(self):
62-
mu, sigma, steps = 2, 1, 10000
63-
obs = np.concatenate([[0], np.random.normal(mu, sigma, size=steps)]).cumsum()
58+
def test_gaussianrandomwalk_inference():
59+
mu, sigma, steps = 2, 1, 1000
60+
obs = np.concatenate([[0], np.random.normal(mu, sigma, size=steps)]).cumsum()
6461

65-
with pm.Model():
66-
_mu = pm.Uniform("mu", -10, 10)
67-
_sigma = pm.Uniform("sigma", 0, 10)
62+
with pm.Model():
63+
_mu = pm.Uniform("mu", -10, 10)
64+
_sigma = pm.Uniform("sigma", 0, 10)
6865

69-
obs_data = pm.MutableData("obs_data", obs)
70-
grw = GaussianRandomWalk("grw", _mu, _sigma, steps=steps, observed=obs_data)
66+
obs_data = pm.MutableData("obs_data", obs)
67+
grw = GaussianRandomWalk("grw", _mu, _sigma, steps=steps, observed=obs_data)
7168

72-
trace = pm.sample(chains=1)
69+
trace = pm.sample(chains=1)
7370

74-
recovered_mu = trace.posterior["mu"].mean()
75-
recovered_sigma = trace.posterior["sigma"].mean()
76-
np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2)
71+
recovered_mu = trace.posterior["mu"].mean()
72+
recovered_sigma = trace.posterior["sigma"].mean()
73+
np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2)
7774

7875

7976
@pytest.mark.xfail(reason="Timeseries not refactored")

0 commit comments

Comments
 (0)