Skip to content

Commit a817a7e

Browse files
Add size tests to new rv testing framework
1 parent b1c40ef commit a817a7e

File tree

2 files changed

+47
-11
lines changed

2 files changed

+47
-11
lines changed

pymc3/tests/helpers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def setup_method(self):
4141
def teardown_method(self):
4242
set_at_rng(self.old_at_rng)
4343

44-
def get_random_state(self):
45-
if self.random_state is None:
44+
def get_random_state(self, reset=False):
45+
if self.random_state is None or reset:
4646
self.random_state = nr.RandomState(self.random_seed)
4747
return self.random_state
4848

pymc3/tests/test_distributions_random.py

Lines changed: 45 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -431,13 +431,7 @@ class BaseTestDistribution(SeededTest):
431431
decimal = 6
432432

433433
def test_distribution(self) -> None:
434-
with pm.Model():
435-
self.pymc_dist_output = self.pymc_dist(
436-
**self.pymc_dist_params,
437-
size=self.size,
438-
rng=aesara.shared(self.get_random_state()),
439-
name=f"{self.pymc_dist.rv_op.name}_test",
440-
)
434+
self._instantiate_pymc_distribution()
441435
if self.expected_dist is not None:
442436
self.expected_dist_outcome = self.expected_dist()(
443437
**self.expected_dist_params, size=self.size
@@ -449,11 +443,23 @@ def run_test(self, test_name):
449443
{
450444
"check_pymc_dist_matches_expected": self._check_pymc_draws_match_expected,
451445
"check_pymc_params_match_rv_op": self._check_pymc_params_match_rv_op,
446+
"check_distribution_size": self._check_distribution_size,
452447
}[test_name]()
453448

449+
def _instantiate_pymc_distribution(self):
450+
with pm.Model():
451+
self.pymc_dist_output = self.pymc_dist(
452+
**self.pymc_dist_params,
453+
size=self.size,
454+
rng=aesara.shared(self.get_random_state(reset=True)),
455+
name=f"{self.pymc_dist.rv_op.name}_test",
456+
)
457+
454458
def _check_pymc_draws_match_expected(
455459
self,
456460
):
461+
# need to re-instantiate it to make sure that the order of drawings match the reference distribution one
462+
self._instantiate_pymc_distribution()
457463
assert_array_almost_equal(
458464
self.pymc_dist_output.eval(), self.expected_dist_outcome, decimal=self.decimal
459465
)
@@ -469,6 +475,28 @@ def _check_pymc_params_match_rv_op(self) -> None:
469475
):
470476
assert_almost_equal(expected_value, actual_variable.eval(), decimal=self.decimal)
471477

478+
def _check_distribution_size(self):
479+
sizes_to_check, sizes_expected = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)], [
480+
(),
481+
(),
482+
(1,),
483+
(1,),
484+
(5,),
485+
(4, 5),
486+
(2, 4, 2),
487+
]
488+
for size, expected in zip(sizes_to_check, sizes_expected):
489+
pymc_dist_output_resized = change_rv_size(self.pymc_dist_output, size)
490+
actual = pymc_dist_output_resized.eval().shape
491+
print(actual, expected)
492+
assert actual == expected
493+
494+
# test negative sizes raise
495+
with pytest.raises(ValueError):
496+
change_rv_size(self.pymc_dist_output, -2).eval()
497+
with pytest.raises(ValueError):
498+
change_rv_size(self.pymc_dist_output, (3, -2)).eval()
499+
472500

473501
def seeded_scipy_distribution_builder(dist_name: str) -> Callable:
474502
return lambda self: functools.partial(
@@ -489,7 +517,11 @@ class TestGumbelDistribution(BaseTestDistribution):
489517
expected_dist_params = {"loc": 1.5, "scale": 3.0}
490518
size = 15
491519
expected_dist = seeded_scipy_distribution_builder("gumbel_r")
492-
tests_to_run = ["check_pymc_params_match_rv_op", "check_pymc_dist_matches_expected"]
520+
tests_to_run = [
521+
"check_pymc_params_match_rv_op",
522+
"check_distribution_size",
523+
"check_pymc_dist_matches_expected",
524+
]
493525

494526

495527
class TestNormalDistribution(BaseTestDistribution):
@@ -499,7 +531,11 @@ class TestNormalDistribution(BaseTestDistribution):
499531
expected_dist_params = {"loc": 5.0, "scale": 10.0}
500532
size = 15
501533
expected_dist = seeded_numpy_distribution_builder("normal")
502-
tests_to_run = ["check_pymc_params_match_rv_op", "check_pymc_dist_matches_expected"]
534+
tests_to_run = [
535+
"check_pymc_params_match_rv_op",
536+
"check_distribution_size",
537+
"check_pymc_dist_matches_expected",
538+
]
503539

504540

505541
class TestNormalTauDistribution(BaseTestDistribution):

0 commit comments

Comments
 (0)