31
31
from pymc .tests .test_distributions_random import BaseTestDistributionRandom
32
32
33
33
34
- class TestGaussianRandomWalk (BaseTestDistributionRandom ):
34
+ class TestGaussianRandomWalkRandom (BaseTestDistributionRandom ):
35
35
# Override default size for test class
36
36
size = None
37
37
@@ -54,26 +54,23 @@ def check_rv_inferred_size(self):
54
54
expected_symbolic = tuple (pymc_rv .shape .eval ())
55
55
assert expected_symbolic == expected
56
56
57
- def check_not_implemented (self ):
58
- with pytest .raises (NotImplementedError ):
59
- self .pymc_rv .eval ()
60
57
61
- def test_grw_inference ( self ):
62
- mu , sigma , steps = 2 , 1 , 10000
63
- obs = np .concatenate ([[0 ], np .random .normal (mu , sigma , size = steps )]).cumsum ()
58
+ def test_gaussianrandomwalk_inference ( ):
59
+ mu , sigma , steps = 2 , 1 , 1000
60
+ obs = np .concatenate ([[0 ], np .random .normal (mu , sigma , size = steps )]).cumsum ()
64
61
65
- with pm .Model ():
66
- _mu = pm .Uniform ("mu" , - 10 , 10 )
67
- _sigma = pm .Uniform ("sigma" , 0 , 10 )
62
+ with pm .Model ():
63
+ _mu = pm .Uniform ("mu" , - 10 , 10 )
64
+ _sigma = pm .Uniform ("sigma" , 0 , 10 )
68
65
69
- obs_data = pm .MutableData ("obs_data" , obs )
70
- grw = GaussianRandomWalk ("grw" , _mu , _sigma , steps = steps , observed = obs_data )
66
+ obs_data = pm .MutableData ("obs_data" , obs )
67
+ grw = GaussianRandomWalk ("grw" , _mu , _sigma , steps = steps , observed = obs_data )
71
68
72
- trace = pm .sample (chains = 1 )
69
+ trace = pm .sample (chains = 1 )
73
70
74
- recovered_mu = trace .posterior ["mu" ].mean ()
75
- recovered_sigma = trace .posterior ["sigma" ].mean ()
76
- np .testing .assert_allclose ([mu , sigma ], [recovered_mu , recovered_sigma ], atol = 0.2 )
71
+ recovered_mu = trace .posterior ["mu" ].mean ()
72
+ recovered_sigma = trace .posterior ["sigma" ].mean ()
73
+ np .testing .assert_allclose ([mu , sigma ], [recovered_mu , recovered_sigma ], atol = 0.2 )
77
74
78
75
79
76
@pytest .mark .xfail (reason = "Timeseries not refactored" )
0 commit comments