|
34 | 34 | from pymc.tests.test_distributions_random import BaseTestDistributionRandom
|
35 | 35 |
|
36 | 36 |
|
| 37 | +@pytest.mark.skip("These are old tests that I may need delete") |
37 | 38 | class TestGaussianRandomWalk:
|
38 | 39 | @pytest.mark.parametrize(
|
39 | 40 | "kwargs,expected",
|
@@ -80,22 +81,6 @@ def test_grw_logp(self):
|
80 | 81 |
|
81 | 82 | np.testing.assert_almost_equal(logp_eval, logp_reference, decimal=6)
|
82 | 83 |
|
83 |
| - def test_grw_inference(self): |
84 |
| - mu, sigma, steps = 2, 1, 10000 |
85 |
| - obs = np.concatenate([[0], np.random.normal(mu, sigma, size=steps)]).cumsum() |
86 |
| - |
87 |
| - with pm.Model(): |
88 |
| - _mu = pm.Uniform("mu", -10, 10) |
89 |
| - _sigma = pm.Uniform("sigma", 0, 10) |
90 |
| - # Workaround for bug in `at.diff` when data is constant |
91 |
| - obs_data = pm.MutableData("obs_data", obs) |
92 |
| - grw = GaussianRandomWalk("grw", _mu, _sigma, steps=steps, observed=obs_data) |
93 |
| - |
94 |
| - trace = pm.sample() |
95 |
| - |
96 |
| - recovered_mu = trace.posterior["mu"].mean() |
97 |
| - recovered_sigma = trace.posterior["sigma"].mean() |
98 |
| - np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2) |
99 | 84 |
|
100 | 85 | @pytest.mark.parametrize(
|
101 | 86 | "steps,size,expected",
|
@@ -156,6 +141,22 @@ def check_not_implemented(self):
|
156 | 141 | with pytest.raises(NotImplementedError):
|
157 | 142 | self.pymc_rv.eval()
|
158 | 143 |
|
| 144 | + def test_grw_inference(self): |
| 145 | + mu, sigma, steps = 2, 1, 10000 |
| 146 | + obs = np.concatenate([[0], np.random.normal(mu, sigma, size=steps)]).cumsum() |
| 147 | + |
| 148 | + with pm.Model(): |
| 149 | + _mu = pm.Uniform("mu", -10, 10) |
| 150 | + _sigma = pm.Uniform("sigma", 0, 10) |
| 151 | + # Workaround for bug in `at.diff` when data is constant |
| 152 | + obs_data = pm.MutableData("obs_data", obs) |
| 153 | + grw = GaussianRandomWalk("grw", _mu, _sigma, steps=steps, observed=obs_data) |
| 154 | + |
| 155 | + trace = pm.sample() |
| 156 | + |
| 157 | + recovered_mu = trace.posterior["mu"].mean() |
| 158 | + recovered_sigma = trace.posterior["sigma"].mean() |
| 159 | + np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2) |
159 | 160 |
|
160 | 161 | @pytest.mark.xfail(reason="Timeseries not refactored")
|
161 | 162 | def test_AR():
|
|
0 commit comments