Skip to content

Commit 2c494c9

Browse files
committed
Fix rebase issue
1 parent 111f9c5 commit 2c494c9

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -310,12 +310,6 @@ class TestGaussianRandomWalk(BaseTestCases.BaseTestCase):
310310
default_shape = (1,)
311311

312312

313-
class TestMvGaussianRandomWalk(BaseTestCases.BaseTestCase):
314-
distribution = pm.MvGaussianRandomWalk
315-
params = {"mu": np.array([1.0, 0.0]), "cov": np.array([[1.0, 0.0], [0.0, 2.0]])}
316-
default_shape = (10, 2)
317-
318-
319313
class TestNormal(BaseTestCases.BaseTestCase):
320314
distribution = pm.Normal
321315
params = {"mu": 0.0, "tau": 1.0}
@@ -1726,3 +1720,17 @@ def test_matrix_normal_random_with_random_variables():
17261720
prior = pm.sample_prior_predictive(2)
17271721

17281722
assert prior["mu"].shape == (2, D, K)
1723+
1724+
1725+
class TestMvGaussianRandomWalk(SeededTest):
1726+
@pytest.mark.parametrize(
1727+
["sample_shape", "dist_shape", "mu_shape", "param"],
1728+
generate_shapes(include_params=True),
1729+
ids=str,
1730+
)
1731+
def test_with_np_arrays(self, sample_shape, dist_shape, mu_shape, param):
1732+
dist = pm.MvGaussianRandomWalk.dist(
1733+
mu=np.ones(mu_shape), **{param: np.eye(3)}, shape=dist_shape
1734+
)
1735+
output_shape = to_tuple(sample_shape) + dist_shape
1736+
assert dist.random(size=sample_shape).shape == output_shape

0 commit comments

Comments
 (0)