Skip to content

Commit 4b6bb80

Browse files
committed
Add tests for MvGaussianRandomWalk with RV params
Added tests for both chol and cov to be random variables.
1 parent f3ec9ca commit 4b6bb80

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,3 +1734,39 @@ def test_with_np_arrays(self, sample_shape, dist_shape, mu_shape, param):
17341734
)
17351735
output_shape = to_tuple(sample_shape) + dist_shape
17361736
assert dist.random(size=sample_shape).shape == output_shape
1737+
1738+
@pytest.mark.xfail
1739+
@pytest.mark.parametrize(
1740+
["sample_shape", "dist_shape", "mu_shape"],
1741+
generate_shapes(include_params=False),
1742+
ids=str,
1743+
)
1744+
def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
1745+
with pm.Model() as model:
1746+
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
1747+
sd_dist = pm.Exponential.dist(1.0, shape=3)
1748+
chol, corr, stds = pm.LKJCholeskyCov(
1749+
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
1750+
)
1751+
mv = pm.MvGaussianRandomWalk("mv", mu, chol=chol, shape=dist_shape)
1752+
prior = pm.sample_prior_predictive(samples=sample_shape)
1753+
1754+
assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape
1755+
1756+
@pytest.mark.xfail
1757+
@pytest.mark.parametrize(
1758+
["sample_shape", "dist_shape", "mu_shape"],
1759+
generate_shapes(include_params=False),
1760+
ids=str,
1761+
)
1762+
def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
1763+
with pm.Model() as model:
1764+
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
1765+
sd_dist = pm.Exponential.dist(1.0, shape=3)
1766+
chol, corr, stds = pm.LKJCholeskyCov(
1767+
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
1768+
)
1769+
mv = pm.MvGaussianRandomWalk("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape)
1770+
prior = pm.sample_prior_predictive(samples=sample_shape)
1771+
1772+
assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape

0 commit comments

Comments
 (0)