Skip to content

Commit a312231

Browse files
Change tests for refactored distributions
More details on commit id 0773620b6f599423315035b97ef082ad32d98fd4 comment.
1 parent 199451d commit a312231

File tree

1 file changed

+29
-83
lines changed

1 file changed

+29
-83
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 29 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
import scipy.stats as st
2525

2626
from numpy.testing import assert_almost_equal
27-
from scipy import linalg
2827
from scipy.special import expit
2928

3029
import pymc3 as pm
3130

3231
from pymc3.aesaraf import floatX, intX
3332
from pymc3.distributions import change_rv_size
33+
from pymc3.distributions.multivariate import quaddist_matrix
3434
from pymc3.distributions.shape_utils import to_tuple
3535
from pymc3.exceptions import ShapeError
3636
from pymc3.tests.helpers import SeededTest
@@ -41,7 +41,6 @@
4141
NatSmall,
4242
PdMatrix,
4343
PdMatrixChol,
44-
PdMatrixCholUpper,
4544
R,
4645
RandomPdMatrix,
4746
RealMatrix,
@@ -644,6 +643,34 @@ def test_poisson(self):
644643
params = [("mu", 4)]
645644
self._pymc_params_match_rv_ones(params, params, pm.Poisson)
646645

646+
def test_mv_distribution(self):
647+
params = [("mu", np.array([1.0, 2.0])), ("cov", np.array([[2.0, 0.0], [0.0, 3.5]]))]
648+
self._pymc_params_match_rv_ones(params, params, pm.MvNormal)
649+
650+
def test_mv_distribution_chol(self):
651+
params = [("mu", np.array([1.0, 2.0])), ("chol", np.array([[2.0, 0.0], [0.0, 3.5]]))]
652+
expected_cov = quaddist_matrix(chol=params[1][1])
653+
expected_params = [("mu", np.array([1.0, 2.0])), ("cov", expected_cov.eval())]
654+
self._pymc_params_match_rv_ones(params, expected_params, pm.MvNormal)
655+
656+
def test_mv_distribution_tau(self):
657+
params = [("mu", np.array([1.0, 2.0])), ("tau", np.array([[2.0, 0.0], [0.0, 3.5]]))]
658+
expected_cov = quaddist_matrix(tau=params[1][1])
659+
expected_params = [("mu", np.array([1.0, 2.0])), ("cov", expected_cov.eval())]
660+
self._pymc_params_match_rv_ones(params, expected_params, pm.MvNormal)
661+
662+
def test_dirichlet(self):
663+
params = [("a", np.array([1.0, 2.0]))]
664+
self._pymc_params_match_rv_ones(params, params, pm.Dirichlet)
665+
666+
def test_multinomial(self):
667+
params = [("n", 85), ("p", np.array([0.28, 0.62, 0.10]))]
668+
self._pymc_params_match_rv_ones(params, params, pm.Multinomial)
669+
670+
def test_categorical(self):
671+
params = [("p", np.array([0.28, 0.62, 0.10]))]
672+
self._pymc_params_match_rv_ones(params, params, pm.Categorical)
673+
647674

648675
class TestScalarParameterSamples(SeededTest):
649676
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
@@ -840,66 +867,13 @@ def ref_rand(size, q, beta):
840867
pm.DiscreteWeibull, {"q": Unit, "beta": Rplusdunif}, ref_rand=ref_rand
841868
)
842869

843-
@pytest.mark.skip(reason="This test is covered by Aesara")
844-
@pytest.mark.parametrize("s", [2, 3, 4])
845-
def test_categorical_random(self, s):
846-
def ref_rand(size, p):
847-
return nr.choice(np.arange(p.shape[0]), p=p, size=size)
848-
849-
pymc3_random_discrete(pm.Categorical, {"p": Simplex(s)}, ref_rand=ref_rand)
850-
851870
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
852871
def test_constant_dist(self):
853872
def ref_rand(size, c):
854873
return c * np.ones(size, dtype=int)
855874

856875
pymc3_random_discrete(pm.Constant, {"c": I}, ref_rand=ref_rand)
857876

858-
@pytest.mark.skip(reason="This test is covered by Aesara")
859-
def test_mv_normal(self):
860-
def ref_rand(size, mu, cov):
861-
return st.multivariate_normal.rvs(mean=mu, cov=cov, size=size)
862-
863-
def ref_rand_tau(size, mu, tau):
864-
return ref_rand(size, mu, linalg.inv(tau))
865-
866-
def ref_rand_chol(size, mu, chol):
867-
return ref_rand(size, mu, np.dot(chol, chol.T))
868-
869-
def ref_rand_uchol(size, mu, chol):
870-
return ref_rand(size, mu, np.dot(chol.T, chol))
871-
872-
for n in [2, 3]:
873-
pymc3_random(
874-
pm.MvNormal,
875-
{"mu": Vector(R, n), "cov": PdMatrix(n)},
876-
size=100,
877-
valuedomain=Vector(R, n),
878-
ref_rand=ref_rand,
879-
)
880-
pymc3_random(
881-
pm.MvNormal,
882-
{"mu": Vector(R, n), "tau": PdMatrix(n)},
883-
size=100,
884-
valuedomain=Vector(R, n),
885-
ref_rand=ref_rand_tau,
886-
)
887-
pymc3_random(
888-
pm.MvNormal,
889-
{"mu": Vector(R, n), "chol": PdMatrixChol(n)},
890-
size=100,
891-
valuedomain=Vector(R, n),
892-
ref_rand=ref_rand_chol,
893-
)
894-
pymc3_random(
895-
pm.MvNormal,
896-
{"mu": Vector(R, n), "chol": PdMatrixCholUpper(n)},
897-
size=100,
898-
valuedomain=Vector(R, n),
899-
ref_rand=ref_rand_uchol,
900-
extra_args={"lower": False},
901-
)
902-
903877
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
904878
def test_matrix_normal(self):
905879
def ref_rand(size, mu, rowcov, colcov):
@@ -1042,20 +1016,6 @@ def ref_rand(size, nu, Sigma, mu):
10421016
ref_rand=ref_rand,
10431017
)
10441018

1045-
@pytest.mark.skip(reason="This test is covered by Aesara")
1046-
def test_dirichlet(self):
1047-
def ref_rand(size, a):
1048-
return st.dirichlet.rvs(a, size=size)
1049-
1050-
for n in [2, 3]:
1051-
pymc3_random(
1052-
pm.Dirichlet,
1053-
{"a": Vector(Rplus, n)},
1054-
valuedomain=Simplex(n),
1055-
size=100,
1056-
ref_rand=ref_rand,
1057-
)
1058-
10591019
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
10601020
def test_dirichlet_multinomial(self):
10611021
def ref_rand(size, a, n):
@@ -1123,20 +1083,6 @@ def test_dirichlet_multinomial_dist_ShapeError(self, n, a, shape, expectation):
11231083
with expectation:
11241084
m.random()
11251085

1126-
@pytest.mark.skip(reason="This test is covered by Aesara")
1127-
def test_multinomial(self):
1128-
def ref_rand(size, p, n):
1129-
return nr.multinomial(pvals=p, n=n, size=size)
1130-
1131-
for n in [2, 3]:
1132-
pymc3_random_discrete(
1133-
pm.Multinomial,
1134-
{"p": Simplex(n), "n": Nat},
1135-
valuedomain=Vector(Nat, n),
1136-
size=100,
1137-
ref_rand=ref_rand,
1138-
)
1139-
11401086
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
11411087
def test_gumbel(self):
11421088
def ref_rand(size, mu, beta):

0 commit comments

Comments
 (0)