|
24 | 24 | import scipy.stats as st
|
25 | 25 |
|
26 | 26 | from numpy.testing import assert_almost_equal
|
27 |
| -from scipy import linalg |
28 | 27 | from scipy.special import expit
|
29 | 28 |
|
30 | 29 | import pymc3 as pm
|
31 | 30 |
|
32 | 31 | from pymc3.aesaraf import floatX, intX
|
33 | 32 | from pymc3.distributions import change_rv_size
|
| 33 | +from pymc3.distributions.multivariate import quaddist_matrix |
34 | 34 | from pymc3.distributions.shape_utils import to_tuple
|
35 | 35 | from pymc3.exceptions import ShapeError
|
36 | 36 | from pymc3.tests.helpers import SeededTest
|
|
41 | 41 | NatSmall,
|
42 | 42 | PdMatrix,
|
43 | 43 | PdMatrixChol,
|
44 |
| - PdMatrixCholUpper, |
45 | 44 | R,
|
46 | 45 | RandomPdMatrix,
|
47 | 46 | RealMatrix,
|
@@ -644,6 +643,34 @@ def test_poisson(self):
|
644 | 643 | params = [("mu", 4)]
|
645 | 644 | self._pymc_params_match_rv_ones(params, params, pm.Poisson)
|
646 | 645 |
|
| 646 | + def test_mv_distribution(self): |
| 647 | + params = [("mu", np.array([1.0, 2.0])), ("cov", np.array([[2.0, 0.0], [0.0, 3.5]]))] |
| 648 | + self._pymc_params_match_rv_ones(params, params, pm.MvNormal) |
| 649 | + |
| 650 | + def test_mv_distribution_chol(self): |
| 651 | + params = [("mu", np.array([1.0, 2.0])), ("chol", np.array([[2.0, 0.0], [0.0, 3.5]]))] |
| 652 | + expected_cov = quaddist_matrix(chol=params[1][1]) |
| 653 | + expected_params = [("mu", np.array([1.0, 2.0])), ("cov", expected_cov.eval())] |
| 654 | + self._pymc_params_match_rv_ones(params, expected_params, pm.MvNormal) |
| 655 | + |
| 656 | + def test_mv_distribution_tau(self): |
| 657 | + params = [("mu", np.array([1.0, 2.0])), ("tau", np.array([[2.0, 0.0], [0.0, 3.5]]))] |
| 658 | + expected_cov = quaddist_matrix(tau=params[1][1]) |
| 659 | + expected_params = [("mu", np.array([1.0, 2.0])), ("cov", expected_cov.eval())] |
| 660 | + self._pymc_params_match_rv_ones(params, expected_params, pm.MvNormal) |
| 661 | + |
| 662 | + def test_dirichlet(self): |
| 663 | + params = [("a", np.array([1.0, 2.0]))] |
| 664 | + self._pymc_params_match_rv_ones(params, params, pm.Dirichlet) |
| 665 | + |
| 666 | + def test_multinomial(self): |
| 667 | + params = [("n", 85), ("p", np.array([0.28, 0.62, 0.10]))] |
| 668 | + self._pymc_params_match_rv_ones(params, params, pm.Multinomial) |
| 669 | + |
| 670 | + def test_categorical(self): |
| 671 | + params = [("p", np.array([0.28, 0.62, 0.10]))] |
| 672 | + self._pymc_params_match_rv_ones(params, params, pm.Categorical) |
| 673 | + |
647 | 674 |
|
648 | 675 | class TestScalarParameterSamples(SeededTest):
|
649 | 676 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
@@ -840,66 +867,13 @@ def ref_rand(size, q, beta):
|
840 | 867 | pm.DiscreteWeibull, {"q": Unit, "beta": Rplusdunif}, ref_rand=ref_rand
|
841 | 868 | )
|
842 | 869 |
|
843 |
| - @pytest.mark.skip(reason="This test is covered by Aesara") |
844 |
| - @pytest.mark.parametrize("s", [2, 3, 4]) |
845 |
| - def test_categorical_random(self, s): |
846 |
| - def ref_rand(size, p): |
847 |
| - return nr.choice(np.arange(p.shape[0]), p=p, size=size) |
848 |
| - |
849 |
| - pymc3_random_discrete(pm.Categorical, {"p": Simplex(s)}, ref_rand=ref_rand) |
850 |
| - |
851 | 870 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
852 | 871 | def test_constant_dist(self):
|
853 | 872 | def ref_rand(size, c):
|
854 | 873 | return c * np.ones(size, dtype=int)
|
855 | 874 |
|
856 | 875 | pymc3_random_discrete(pm.Constant, {"c": I}, ref_rand=ref_rand)
|
857 | 876 |
|
858 |
| - @pytest.mark.skip(reason="This test is covered by Aesara") |
859 |
| - def test_mv_normal(self): |
860 |
| - def ref_rand(size, mu, cov): |
861 |
| - return st.multivariate_normal.rvs(mean=mu, cov=cov, size=size) |
862 |
| - |
863 |
| - def ref_rand_tau(size, mu, tau): |
864 |
| - return ref_rand(size, mu, linalg.inv(tau)) |
865 |
| - |
866 |
| - def ref_rand_chol(size, mu, chol): |
867 |
| - return ref_rand(size, mu, np.dot(chol, chol.T)) |
868 |
| - |
869 |
| - def ref_rand_uchol(size, mu, chol): |
870 |
| - return ref_rand(size, mu, np.dot(chol.T, chol)) |
871 |
| - |
872 |
| - for n in [2, 3]: |
873 |
| - pymc3_random( |
874 |
| - pm.MvNormal, |
875 |
| - {"mu": Vector(R, n), "cov": PdMatrix(n)}, |
876 |
| - size=100, |
877 |
| - valuedomain=Vector(R, n), |
878 |
| - ref_rand=ref_rand, |
879 |
| - ) |
880 |
| - pymc3_random( |
881 |
| - pm.MvNormal, |
882 |
| - {"mu": Vector(R, n), "tau": PdMatrix(n)}, |
883 |
| - size=100, |
884 |
| - valuedomain=Vector(R, n), |
885 |
| - ref_rand=ref_rand_tau, |
886 |
| - ) |
887 |
| - pymc3_random( |
888 |
| - pm.MvNormal, |
889 |
| - {"mu": Vector(R, n), "chol": PdMatrixChol(n)}, |
890 |
| - size=100, |
891 |
| - valuedomain=Vector(R, n), |
892 |
| - ref_rand=ref_rand_chol, |
893 |
| - ) |
894 |
| - pymc3_random( |
895 |
| - pm.MvNormal, |
896 |
| - {"mu": Vector(R, n), "chol": PdMatrixCholUpper(n)}, |
897 |
| - size=100, |
898 |
| - valuedomain=Vector(R, n), |
899 |
| - ref_rand=ref_rand_uchol, |
900 |
| - extra_args={"lower": False}, |
901 |
| - ) |
902 |
| - |
903 | 877 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
904 | 878 | def test_matrix_normal(self):
|
905 | 879 | def ref_rand(size, mu, rowcov, colcov):
|
@@ -1042,20 +1016,6 @@ def ref_rand(size, nu, Sigma, mu):
|
1042 | 1016 | ref_rand=ref_rand,
|
1043 | 1017 | )
|
1044 | 1018 |
|
1045 |
| - @pytest.mark.skip(reason="This test is covered by Aesara") |
1046 |
| - def test_dirichlet(self): |
1047 |
| - def ref_rand(size, a): |
1048 |
| - return st.dirichlet.rvs(a, size=size) |
1049 |
| - |
1050 |
| - for n in [2, 3]: |
1051 |
| - pymc3_random( |
1052 |
| - pm.Dirichlet, |
1053 |
| - {"a": Vector(Rplus, n)}, |
1054 |
| - valuedomain=Simplex(n), |
1055 |
| - size=100, |
1056 |
| - ref_rand=ref_rand, |
1057 |
| - ) |
1058 |
| - |
1059 | 1019 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
1060 | 1020 | def test_dirichlet_multinomial(self):
|
1061 | 1021 | def ref_rand(size, a, n):
|
@@ -1123,20 +1083,6 @@ def test_dirichlet_multinomial_dist_ShapeError(self, n, a, shape, expectation):
|
1123 | 1083 | with expectation:
|
1124 | 1084 | m.random()
|
1125 | 1085 |
|
1126 |
| - @pytest.mark.skip(reason="This test is covered by Aesara") |
1127 |
| - def test_multinomial(self): |
1128 |
| - def ref_rand(size, p, n): |
1129 |
| - return nr.multinomial(pvals=p, n=n, size=size) |
1130 |
| - |
1131 |
| - for n in [2, 3]: |
1132 |
| - pymc3_random_discrete( |
1133 |
| - pm.Multinomial, |
1134 |
| - {"p": Simplex(n), "n": Nat}, |
1135 |
| - valuedomain=Vector(Nat, n), |
1136 |
| - size=100, |
1137 |
| - ref_rand=ref_rand, |
1138 |
| - ) |
1139 |
| - |
1140 | 1086 | @pytest.mark.xfail(reason="This distribution has not been refactored for v4")
|
1141 | 1087 | def test_gumbel(self):
|
1142 | 1088 | def ref_rand(size, mu, beta):
|
|
0 commit comments