Skip to content

Commit 6e70f89

Browse files
committed
Update GRW Test
1 parent 9efa6af commit 6e70f89

File tree

1 file changed

+13
-2
lines changed

1 file changed

+13
-2
lines changed

pymc/tests/test_distributions_timeseries.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,23 @@ def test_grw_inference(self):
8888

8989

9090
class TestGRWScipy(td.TestMatchesScipy):
91-
def test_grw(self):
91+
92+
# TODO: Find issue that says GRW wont take vector
93+
def test_grw_logp(self):
94+
def grw_logp(value, mu, sigma):
95+
# Relying on fact that init will be normal
96+
# Note: This means we're not testing
97+
stationary_series = np.diff(value)
98+
logp = stats.norm.logpdf(value[0], mu, sigma) + \
99+
stats.norm.logpdf(stationary_series, mu, sigma).sum(),
100+
return logp
101+
102+
# TODO: Make base class static static method
92103
self.check_logp(
93104
pm.GaussianRandomWalk,
94105
td.Vector(td.R, 10),
95106
{"mu": td.R, "sigma": td.Rplus, "steps": td.Nat},
96-
lambda value, mu, sigma: stats.norm.logpdf(value, mu, sigma).cumsum().sum(),
107+
grw_logp,
97108
decimal=select_by_precision(float64=6, float32=1),
98109
)
99110

0 commit comments

Comments
 (0)