@@ -69,14 +69,16 @@ def check_not_implemented(self):
69
69
with pytest .raises (NotImplementedError ):
70
70
self .pymc_rv .eval ()
71
71
72
+ # TODO: Move this out of this test class, not related, makes more sense together with
73
+ # the other logp tests
72
74
def test_grw_inference (self ):
73
75
mu , sigma , steps = 2 , 1 , 10000
74
76
obs = np .concatenate ([[0 ], np .random .normal (mu , sigma , size = steps )]).cumsum ()
75
77
76
78
with pm .Model ():
77
79
_mu = pm .Uniform ("mu" , - 10 , 10 )
78
80
_sigma = pm .Uniform ("sigma" , 0 , 10 )
79
- # Workaround for bug in `at.diff` when data is constant
81
+
80
82
obs_data = pm .MutableData ("obs_data" , obs )
81
83
grw = GaussianRandomWalk ("grw" , _mu , _sigma , steps = steps , observed = obs_data )
82
84
@@ -90,19 +92,18 @@ def test_grw_inference(self):
90
92
class TestGRWScipy (td .TestMatchesScipy ):
91
93
# TODO: Test LogP for different inits in its own function
92
94
93
- # TODO: Find issue that says GRW wont take vector
95
+ # TODO: Find issue that says GRW wont take vector
94
96
def test_grw_logp (self ):
95
97
def grw_logp (value , mu , sigma ):
96
98
# Relying on fact that init will be normal
97
- # Note: This means we're not testing
99
+ # Note: This means we're not testing
98
100
stationary_series = np .diff (value )
99
- logp = (
100
- stats .norm .logpdf (value [0 ], mu , sigma )
101
- + stats .norm .logpdf (stationary_series , mu , sigma ).sum (),
102
- )
101
+ logp = stats .norm .logpdf (value [0 ], mu , sigma ) + \
102
+ stats .norm .logpdf (stationary_series , mu , sigma ).sum (),
103
103
return logp
104
-
104
+
105
105
# TODO: Make base class a static method
106
+ # TODO: Reuse this make this static so it doesnt run all other ones
106
107
self .check_logp (
107
108
pm .GaussianRandomWalk ,
108
109
td .Vector (td .R , 10 ),
0 commit comments