Skip to content

Commit 6852a26

Browse files
committed
Fix logp test and move it back to test_distributions.py
1 parent 014b601 commit 6852a26

File tree

3 files changed

+21
-26
lines changed

3 files changed

+21
-26
lines changed

pymc/distributions/timeseries.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pymc.aesaraf import change_rv_size, floatX, intX
2424
from pymc.distributions import distribution, logprob, multivariate
2525
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
26+
from pymc.distributions.dist_math import check_parameters
2627
from pymc.distributions.shape_utils import to_tuple
2728

2829
__all__ = [
@@ -214,7 +215,11 @@ def logp(
214215
stationary_series = value[..., 1:] - value[..., :-1]
215216
series_logp = logprob.logp(Normal.dist(mu, sigma), stationary_series)
216217

217-
return init_logp + series_logp.sum(axis=-1)
218+
return check_parameters(
219+
init_logp + series_logp.sum(axis=-1),
220+
steps > 0,
221+
msg="steps > 0",
222+
)
218223

219224

220225
class AR1(distribution.Continuous):

pymc/tests/test_distributions.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2619,6 +2619,21 @@ def test_interpolated_transform(self, transform):
26192619
assert not np.isfinite(m.compile_logp()({"x": -1.0}))
26202620
assert not np.isfinite(m.compile_logp()({"x": 11.0}))
26212621

2622+
def test_grw(self):
2623+
def ref_logp(value, mu, sigma, steps):
2624+
# Relying on fact that init will be normal by default
2625+
return (
2626+
scipy.stats.norm.logpdf(value[0], mu, sigma)
2627+
+ scipy.stats.norm.logpdf(np.diff(value), mu, sigma).sum()
2628+
)
2629+
2630+
self.check_logp(
2631+
pm.GaussianRandomWalk,
2632+
Vector(R, 4),
2633+
{"mu": R, "sigma": Rplus, "steps": Nat},
2634+
ref_logp,
2635+
decimal=select_by_precision(float64=6, float32=1),
2636+
)
26222637

26232638
class TestBound:
26242639
"""Tests for pm.Bound distribution"""

pymc/tests/test_distributions_timeseries.py

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -89,31 +89,6 @@ def test_grw_inference(self):
8989
np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2)
9090

9191

92-
class TestGRWScipy(td.TestMatchesScipy):
93-
# TODO: Test LogP for different inits in its own function
94-
95-
# TODO: Find issue that says GRW wont take vector
96-
def test_grw_logp(self):
97-
def grw_logp(value, mu, sigma):
98-
# Relying on fact that init will be normal
99-
# Note: This means we're not testing
100-
stationary_series = np.diff(value)
101-
logp = stats.norm.logpdf(value[0], mu, sigma) + \
102-
stats.norm.logpdf(stationary_series, mu, sigma).sum(),
103-
return logp
104-
105-
# TODO: Make base class a static method
106-
# TODO: Reuse this make this static so it doesnt run all other ones
107-
self.check_logp(
108-
pm.GaussianRandomWalk,
109-
td.Vector(td.R, 10),
110-
{"mu": td.R, "sigma": td.Rplus, "steps": td.Nat},
111-
grw_logp,
112-
decimal=select_by_precision(float64=6, float32=1),
113-
n_samples=1,
114-
)
115-
116-
11792
@pytest.mark.xfail(reason="Timeseries not refactored")
11893
def test_AR():
11994
# AR1

0 commit comments

Comments
 (0)