Skip to content

Commit bf6cce0

Browse files
Enable MvNormal tests in test_distributions
1 parent e65aad2 commit bf6cce0

File tree

1 file changed

+14
-10
lines changed

1 file changed

+14
-10
lines changed

pymc3/tests/test_distributions.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,11 @@ def build_model(distfam, valuedomain, vardomains, extra_args=None):
230230
v_at.name = v
231231
param_vars[v] = v_at
232232
param_vars.update(extra_args)
233-
distfam("value", **param_vars, transform=None)
233+
distfam(
234+
"value",
235+
**param_vars,
236+
transform=None,
237+
)
234238
return m, param_vars
235239

236240

@@ -1473,6 +1477,9 @@ def test_beta_binomial(self):
14731477
{"alpha": Rplus, "beta": Rplus, "n": NatSmall},
14741478
lambda value, alpha, beta, n: sp.betabinom.logcdf(value, a=alpha, b=beta, n=n),
14751479
)
1480+
1481+
@pytest.mark.xfail(reason="Distribution not refactored yet")
1482+
def test_beta_binomial_selfconsistency(self):
14761483
self.check_selfconsistency_discrete_logcdf(
14771484
BetaBinomial,
14781485
Nat,
@@ -1611,14 +1618,14 @@ def test_zeroinflatedbinomial(self):
16111618
n_samples=10,
16121619
)
16131620

1614-
@pytest.mark.xfail(reason="Distribution not refactored yet")
16151621
@pytest.mark.parametrize("n", [1, 2, 3])
16161622
def test_mvnormal(self, n):
16171623
self.check_logp(
16181624
MvNormal,
16191625
RealMatrix(5, n),
16201626
{"mu": Vector(R, n), "tau": PdMatrix(n)},
16211627
normal_logpdf_tau,
1628+
extra_args={"size": 5},
16221629
)
16231630
self.check_logp(
16241631
MvNormal,
@@ -1631,6 +1638,7 @@ def test_mvnormal(self, n):
16311638
RealMatrix(5, n),
16321639
{"mu": Vector(R, n), "cov": PdMatrix(n)},
16331640
normal_logpdf_cov,
1641+
extra_args={"size": 5},
16341642
)
16351643
self.check_logp(
16361644
MvNormal,
@@ -1644,6 +1652,7 @@ def test_mvnormal(self, n):
16441652
{"mu": Vector(R, n), "chol": PdMatrixChol(n)},
16451653
normal_logpdf_chol,
16461654
decimal=select_by_precision(float64=6, float32=-1),
1655+
extra_args={"size": 5},
16471656
)
16481657
self.check_logp(
16491658
MvNormal,
@@ -1652,23 +1661,19 @@ def test_mvnormal(self, n):
16521661
normal_logpdf_chol,
16531662
decimal=select_by_precision(float64=6, float32=0),
16541663
)
1655-
1656-
def MvNormalUpper(*args, **kwargs):
1657-
return MvNormal(lower=False, *args, **kwargs)
1658-
16591664
self.check_logp(
1660-
MvNormalUpper,
1665+
MvNormal,
16611666
Vector(R, n),
16621667
{"mu": Vector(R, n), "chol": PdMatrixCholUpper(n)},
16631668
normal_logpdf_chol_upper,
16641669
decimal=select_by_precision(float64=6, float32=0),
1670+
extra_args={"lower": False},
16651671
)
16661672

16671673
@pytest.mark.xfail(
16681674
condition=(aesara.config.floatX == "float32"),
16691675
reason="Fails on float32 due to inf issues",
16701676
)
1671-
@pytest.mark.xfail(reason="Distribution not refactored yet")
16721677
def test_mvnormal_indef(self):
16731678
cov_val = np.array([[1, 0.5], [0.5, -2]])
16741679
cov = aet.matrix("cov")
@@ -1683,14 +1688,13 @@ def test_mvnormal_indef(self):
16831688
f_dlogp = aesara.function([cov, x], dlogp)
16841689
assert not np.all(np.isfinite(f_dlogp(cov_val, np.ones(2))))
16851690

1686-
logp = logp(MvNormal.dist(mu=mu, tau=cov), x)
1691+
logp = logpt(MvNormal.dist(mu=mu, tau=cov), x)
16871692
f_logp = aesara.function([cov, x], logp)
16881693
assert f_logp(cov_val, np.ones(2)) == -np.inf
16891694
dlogp = aet.grad(logp, cov)
16901695
f_dlogp = aesara.function([cov, x], dlogp)
16911696
assert not np.all(np.isfinite(f_dlogp(cov_val, np.ones(2))))
16921697

1693-
@pytest.mark.xfail(reason="Distribution not refactored yet")
16941698
def test_mvnormal_init_fail(self):
16951699
with Model():
16961700
with pytest.raises(ValueError):

0 commit comments

Comments
 (0)