Skip to content

Commit 1c88e55

Browse files
Add tests for multivariate and for univariate multi-parameters
1 parent a817a7e commit 1c88e55

File tree

1 file changed

+46
-24
lines changed

1 file changed

+46
-24
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import sys
1717

1818
from contextlib import ExitStack as does_not_raise
19-
from typing import Callable
19+
from typing import Callable, List, Optional
2020

2121
import aesara
2222
import numpy as np
@@ -421,17 +421,21 @@ class TestMoyal(BaseTestCases.BaseTestCase):
421421

422422

423423
class BaseTestDistribution(SeededTest):
424-
pymc_dist = None
424+
pymc_dist: Optional[Callable] = None
425425
pymc_dist_params = dict()
426-
expected_dist = None
426+
expected_dist: Optional[Callable] = None
427427
expected_dist_params = dict()
428428
expected_rv_op_params = dict()
429429
tests_to_run = []
430430
size = 15
431431
decimal = 6
432432

433-
def test_distribution(self) -> None:
434-
self._instantiate_pymc_distribution()
433+
sizes_to_check: Optional[List] = None
434+
sizes_expected: Optional[List] = None
435+
repeated_params_shape = 5
436+
437+
def test_distribution(self):
438+
self._instantiate_pymc_rv()
435439
if self.expected_dist is not None:
436440
self.expected_dist_outcome = self.expected_dist()(
437441
**self.expected_dist_params, size=self.size
@@ -446,20 +450,19 @@ def run_test(self, test_name):
446450
"check_distribution_size": self._check_distribution_size,
447451
}[test_name]()
448452

449-
def _instantiate_pymc_distribution(self):
453+
def _instantiate_pymc_rv(self, dist_params=None):
454+
params = dist_params if dist_params else self.pymc_dist_params
450455
with pm.Model():
451456
self.pymc_dist_output = self.pymc_dist(
452-
**self.pymc_dist_params,
457+
**params,
453458
size=self.size,
454459
rng=aesara.shared(self.get_random_state(reset=True)),
455460
name=f"{self.pymc_dist.rv_op.name}_test",
456461
)
457462

458-
def _check_pymc_draws_match_expected(
459-
self,
460-
):
463+
def _check_pymc_draws_match_expected(self):
461464
# need to re-instantiate it to make sure that the order of drawings match the reference distribution one
462-
self._instantiate_pymc_distribution()
465+
self._instantiate_pymc_rv()
463466
assert_array_almost_equal(
464467
self.pymc_dist_output.eval(), self.expected_dist_outcome, decimal=self.decimal
465468
)
@@ -476,7 +479,9 @@ def _check_pymc_params_match_rv_op(self) -> None:
476479
assert_almost_equal(expected_value, actual_variable.eval(), decimal=self.decimal)
477480

478481
def _check_distribution_size(self):
479-
sizes_to_check, sizes_expected = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)], [
482+
# test sizes
483+
sizes_to_check = self.sizes_to_check or [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
484+
sizes_expected = self.sizes_expected or [
480485
(),
481486
(),
482487
(1,),
@@ -486,16 +491,29 @@ def _check_distribution_size(self):
486491
(2, 4, 2),
487492
]
488493
for size, expected in zip(sizes_to_check, sizes_expected):
489-
pymc_dist_output_resized = change_rv_size(self.pymc_dist_output, size)
490-
actual = pymc_dist_output_resized.eval().shape
491-
print(actual, expected)
494+
actual = change_rv_size(self.pymc_dist_output, size).eval().shape
492495
assert actual == expected
493496

494497
# test negative sizes raise
495-
with pytest.raises(ValueError):
496-
change_rv_size(self.pymc_dist_output, -2).eval()
497-
with pytest.raises(ValueError):
498-
change_rv_size(self.pymc_dist_output, (3, -2)).eval()
498+
for size in [-2, (3, -2)]:
499+
with pytest.raises(ValueError):
500+
change_rv_size(self.pymc_dist_output, size).eval()
501+
502+
# test multi-parameters sampling for univariate distributions
503+
if self.pymc_dist.rv_op.ndim_supp == 0:
504+
params = {
505+
k: p * np.ones(self.repeated_params_shape) for k, p in self.pymc_dist_params.items()
506+
}
507+
self._instantiate_pymc_rv(params)
508+
sizes_to_check = [None, self.repeated_params_shape, (5, self.repeated_params_shape)]
509+
sizes_expected = [
510+
(self.repeated_params_shape,),
511+
(self.repeated_params_shape,),
512+
(5, self.repeated_params_shape),
513+
]
514+
for size, expected in zip(sizes_to_check, sizes_expected):
515+
actual = change_rv_size(self.pymc_dist_output, size).eval().shape
516+
assert actual == expected
499517

500518

501519
def seeded_scipy_distribution_builder(dist_name: str) -> Callable:
@@ -706,7 +724,7 @@ class TestPoissonDistribution(BaseTestDistribution):
706724
tests_to_run = ["check_pymc_params_match_rv_op"]
707725

708726

709-
class TestMVNormalDistributionDistribution(BaseTestDistribution):
727+
class TestMvNormalDistributionDistribution(BaseTestDistribution):
710728
pymc_dist = pm.MvNormal
711729
pymc_dist_params = {
712730
"mu": np.array([1.0, 2.0]),
@@ -716,10 +734,12 @@ class TestMVNormalDistributionDistribution(BaseTestDistribution):
716734
"mu": np.array([1.0, 2.0]),
717735
"cov": np.array([[2.0, 0.0], [0.0, 3.5]]),
718736
}
719-
tests_to_run = ["check_pymc_params_match_rv_op"]
737+
sizes_to_check = [None, (1), (2, 3)]
738+
sizes_expected = [(2,), (1, 2), (2, 3, 2)]
739+
tests_to_run = ["check_pymc_params_match_rv_op", "check_distribution_size"]
720740

721741

722-
class TestMVNormalDistributionCholDistribution(BaseTestDistribution):
742+
class TestMvNormalDistributionCholDistribution(BaseTestDistribution):
723743
pymc_dist = pm.MvNormal
724744
pymc_dist_params = {
725745
"mu": np.array([1.0, 2.0]),
@@ -732,7 +752,7 @@ class TestMVNormalDistributionCholDistribution(BaseTestDistribution):
732752
tests_to_run = ["check_pymc_params_match_rv_op"]
733753

734754

735-
class TestMVNormalDistributionTauDistribution(BaseTestDistribution):
755+
class TestMvNormalDistributionTauDistribution(BaseTestDistribution):
736756
pymc_dist = pm.MvNormal
737757
pymc_dist_params = {
738758
"mu": np.array([1.0, 2.0]),
@@ -756,7 +776,9 @@ class TestMultinomialDistribution(BaseTestDistribution):
756776
pymc_dist = pm.Multinomial
757777
pymc_dist_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])}
758778
expected_rv_op_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])}
759-
tests_to_run = ["check_pymc_params_match_rv_op"]
779+
sizes_to_check = [None, (1), (4,), (3, 2)]
780+
sizes_expected = [(3,), (1, 3), (4, 3), (3, 2, 3)]
781+
tests_to_run = ["check_pymc_params_match_rv_op", "check_distribution_size"]
760782

761783

762784
class TestCategoricalDistribution(BaseTestDistribution):

0 commit comments

Comments
 (0)