Skip to content

Commit e690356

Browse files
committed
Update tests
1 parent fabc8ad commit e690356

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

pymc/tests/test_distributions.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3420,8 +3420,7 @@ def test_no_warning_logp(self):
34203420
"sd_dist",
34213421
[
34223422
pm.Exponential.dist(1),
3423-
# For some reason this test is running when TimeSeries test is run
3424-
# pytest.mark.xfail(pm.MvNormal.dist(np.ones(3), np.eye(3)),)
3423+
pytest.mark.xfail(pm.MvNormal.dist(np.ones(3), np.eye(3)),)
34253424
],
34263425
)
34273426
def test_sd_dist_automatically_resized(self, sd_dist):

pymc/tests/test_distributions_timeseries.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from pymc.tests.test_distributions_random import BaseTestDistributionRandom
3535

3636

37+
@pytest.mark.skip("These are old tests that I may need delete")
3738
class TestGaussianRandomWalk:
3839
@pytest.mark.parametrize(
3940
"kwargs,expected",
@@ -80,22 +81,6 @@ def test_grw_logp(self):
8081

8182
np.testing.assert_almost_equal(logp_eval, logp_reference, decimal=6)
8283

83-
def test_grw_inference(self):
84-
mu, sigma, steps = 2, 1, 10000
85-
obs = np.concatenate([[0], np.random.normal(mu, sigma, size=steps)]).cumsum()
86-
87-
with pm.Model():
88-
_mu = pm.Uniform("mu", -10, 10)
89-
_sigma = pm.Uniform("sigma", 0, 10)
90-
# Workaround for bug in `at.diff` when data is constant
91-
obs_data = pm.MutableData("obs_data", obs)
92-
grw = GaussianRandomWalk("grw", _mu, _sigma, steps=steps, observed=obs_data)
93-
94-
trace = pm.sample()
95-
96-
recovered_mu = trace.posterior["mu"].mean()
97-
recovered_sigma = trace.posterior["sigma"].mean()
98-
np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2)
9984

10085
@pytest.mark.parametrize(
10186
"steps,size,expected",
@@ -156,6 +141,22 @@ def check_not_implemented(self):
156141
with pytest.raises(NotImplementedError):
157142
self.pymc_rv.eval()
158143

144+
def test_grw_inference(self):
145+
mu, sigma, steps = 2, 1, 10000
146+
obs = np.concatenate([[0], np.random.normal(mu, sigma, size=steps)]).cumsum()
147+
148+
with pm.Model():
149+
_mu = pm.Uniform("mu", -10, 10)
150+
_sigma = pm.Uniform("sigma", 0, 10)
151+
# Workaround for bug in `at.diff` when data is constant
152+
obs_data = pm.MutableData("obs_data", obs)
153+
grw = GaussianRandomWalk("grw", _mu, _sigma, steps=steps, observed=obs_data)
154+
155+
trace = pm.sample()
156+
157+
recovered_mu = trace.posterior["mu"].mean()
158+
recovered_sigma = trace.posterior["sigma"].mean()
159+
np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2)
159160

160161
@pytest.mark.xfail(reason="Timeseries not refactored")
161162
def test_AR():

0 commit comments

Comments
 (0)