Skip to content

Commit 2171790

Browse files
committed
Add temporary workaround for bug it at.diff in tests
1 parent d23458e commit 2171790

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

pymc/tests/test_distributions_timeseries.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import aesara.tensor as at
1516
import numpy as np
1617
import pytest
1718

19+
from scipy import stats
20+
1821
import pymc as pm
1922

2023
from pymc.aesaraf import floatX
@@ -58,28 +61,25 @@ def test_grw_rv_op_shape(self, kwargs, expected):
5861
assert grw.shape == expected
5962

6063
def test_grw_logp(self):
61-
vals = [0, 1, 2]
64+
# `at.diff` is currently broken with constants
65+
test_vals = [0, 1, 2]
66+
vals = at.vector("vals")
6267
mu = 1
6368
sigma = 1
64-
init = pm.Normal.dist(mu, sigma)
69+
init = pm.Normal.dist(0, sigma)
6570

6671
with pm.Model():
6772
grw = GaussianRandomWalk("grw", mu, sigma, init=init, steps=2)
6873

6974
logp = pm.logp(grw, vals)
75+
logp_eval = logp.eval({vals: test_vals})
7076

71-
with pytest.raises(TypeError) as err:
72-
logp_vals = logp.eval()
73-
74-
assert "Cannot convert Type TensorType(float".lower() in str(err).lower()
77+
logp_reference = (
78+
stats.norm(0, sigma).logpdf(test_vals[0])
79+
+ stats.norm(mu, sigma).logpdf(np.diff(test_vals)).sum()
80+
)
7581

76-
# logp_reference = []
77-
#
78-
# for x_minus_one_val, x_val in zip(vals, vals[1:]):
79-
# logp_point = stats.norm(x_minus_one_val + mu + init, sigma).logpdf(x_val)
80-
# logp_reference.append(logp_point)
81-
#
82-
# np.testing.assert_almost_equal(logp_vals, logp_reference)
82+
np.testing.assert_almost_equal(logp_eval, logp_reference)
8383

8484
def test_grw_inference(self):
8585
mu, sigma, steps = 2, 1, 10000
@@ -88,16 +88,15 @@ def test_grw_inference(self):
8888
with pm.Model():
8989
_mu = pm.Uniform("mu", -10, 10)
9090
_sigma = pm.Uniform("sigma", 0, 10)
91-
grw = GaussianRandomWalk("grw", _mu, _sigma, steps=steps, observed=obs)
92-
93-
with pytest.raises(TypeError) as err:
94-
trace = pm.sample()
91+
# Workaround for bug in `at.diff` when data is constant
92+
obs_data = pm.MutableData("obs_data", obs)
93+
grw = GaussianRandomWalk("grw", _mu, _sigma, steps=steps, observed=obs_data)
9594

96-
assert "cannot convert type tensortype(float".lower() in str(err).lower()
95+
trace = pm.sample()
9796

98-
# recovered_mu = trace.posterior["mu"].mean()
99-
# recovered_sigma = trace.posterior["sigma"].mean()
100-
# np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2)
97+
recovered_mu = trace.posterior["mu"].mean()
98+
recovered_sigma = trace.posterior["sigma"].mean()
99+
np.testing.assert_allclose([mu, sigma], [recovered_mu, recovered_sigma], atol=0.2)
101100

102101
@pytest.mark.parametrize(
103102
"steps,size,expected",

0 commit comments

Comments
 (0)