12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
+ import aesara .tensor as at
15
16
import numpy as np
16
17
import pytest
17
18
19
+ from scipy import stats
20
+
18
21
import pymc as pm
19
22
20
23
from pymc .aesaraf import floatX
@@ -58,28 +61,25 @@ def test_grw_rv_op_shape(self, kwargs, expected):
58
61
assert grw .shape == expected
59
62
60
63
def test_grw_logp (self ):
61
- vals = [0 , 1 , 2 ]
64
+ # `at.diff` is currently broken with constants
65
+ test_vals = [0 , 1 , 2 ]
66
+ vals = at .vector ("vals" )
62
67
mu = 1
63
68
sigma = 1
64
- init = pm .Normal .dist (mu , sigma )
69
+ init = pm .Normal .dist (0 , sigma )
65
70
66
71
with pm .Model ():
67
72
grw = GaussianRandomWalk ("grw" , mu , sigma , init = init , steps = 2 )
68
73
69
74
logp = pm .logp (grw , vals )
75
+ logp_eval = logp .eval ({vals : test_vals })
70
76
71
- with pytest . raises ( TypeError ) as err :
72
- logp_vals = logp . eval ( )
73
-
74
- assert "Cannot convert Type TensorType(float" . lower () in str ( err ). lower ( )
77
+ logp_reference = (
78
+ stats . norm ( 0 , sigma ). logpdf ( test_vals [ 0 ] )
79
+ + stats . norm ( mu , sigma ). logpdf ( np . diff ( test_vals )). sum ()
80
+ )
75
81
76
- # logp_reference = []
77
- #
78
- # for x_minus_one_val, x_val in zip(vals, vals[1:]):
79
- # logp_point = stats.norm(x_minus_one_val + mu + init, sigma).logpdf(x_val)
80
- # logp_reference.append(logp_point)
81
- #
82
- # np.testing.assert_almost_equal(logp_vals, logp_reference)
82
+ np .testing .assert_almost_equal (logp_eval , logp_reference )
83
83
84
84
def test_grw_inference (self ):
85
85
mu , sigma , steps = 2 , 1 , 10000
@@ -88,16 +88,15 @@ def test_grw_inference(self):
88
88
with pm .Model ():
89
89
_mu = pm .Uniform ("mu" , - 10 , 10 )
90
90
_sigma = pm .Uniform ("sigma" , 0 , 10 )
91
- grw = GaussianRandomWalk ("grw" , _mu , _sigma , steps = steps , observed = obs )
92
-
93
- with pytest .raises (TypeError ) as err :
94
- trace = pm .sample ()
91
+ # Workaround for bug in `at.diff` when data is constant
92
+ obs_data = pm .MutableData ("obs_data" , obs )
93
+ grw = GaussianRandomWalk ("grw" , _mu , _sigma , steps = steps , observed = obs_data )
95
94
96
- assert "cannot convert type tensortype(float" . lower () in str ( err ). lower ()
95
+ trace = pm . sample ()
97
96
98
- # recovered_mu = trace.posterior["mu"].mean()
99
- # recovered_sigma = trace.posterior["sigma"].mean()
100
- # np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2)
97
+ recovered_mu = trace .posterior ["mu" ].mean ()
98
+ recovered_sigma = trace .posterior ["sigma" ].mean ()
99
+ np .testing .assert_allclose ([mu , sigma ], [recovered_mu , recovered_sigma ], atol = 0.2 )
101
100
102
101
@pytest .mark .parametrize (
103
102
"steps,size,expected" ,
0 commit comments