Skip to content

Commit aac869e

Browse files
committed
Update test and timeseries more
1 parent cc5aafd commit aac869e

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

pymc/distributions/timeseries.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __new__(cls, name, mu=0.0, sigma=1.0, init=None, steps: int = 1, **kwargs):
150150
@classmethod
151151
def dist(
152152
cls, mu=0.0, sigma=1.0, init=None, steps: int = 1, size=None, **kwargs
153-
) -> TensorVariable:
153+
) -> at.TensorVariable:
154154

155155
mu = at.as_tensor_variable(floatX(mu))
156156
sigma = at.as_tensor_variable(floatX(sigma))

pymc/tests/test_distributions_timeseries.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def test_grw_inference(self):
8888

8989

9090
class TestGRWScipy(td.TestMatchesScipy):
91+
# TODO: Test LogP for different inits in its own function
9192

9293
# TODO: Find issue that says GRW wont take vector
9394
def test_grw_logp(self):
@@ -99,13 +100,14 @@ def grw_logp(value, mu, sigma):
99100
stats.norm.logpdf(stationary_series, mu, sigma).sum(),
100101
return logp
101102

102-
# TODO: Make base class static static method
103+
# TODO: Make base class a static method
103104
self.check_logp(
104105
pm.GaussianRandomWalk,
105106
td.Vector(td.R, 10),
106107
{"mu": td.R, "sigma": td.Rplus, "steps": td.Nat},
107108
grw_logp,
108109
decimal=select_by_precision(float64=6, float32=1),
110+
n_samples=1,
109111
)
110112

111113

0 commit comments

Comments
 (0)