Skip to content

Commit 55b4a0f

Browse files
fix few names
1 parent bf68a3a commit 55b4a0f

File tree

1 file changed

+29
-43
lines changed

1 file changed

+29
-43
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 29 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,6 @@ class TestVonMises(BaseTestCases.BaseTestCase):
337337
params = {"mu": 0.0, "kappa": 1.0}
338338

339339

340-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
341-
class TestGumbel(BaseTestCases.BaseTestCase):
342-
distribution = pm.Gumbel
343-
params = {"mu": 0.0, "beta": 1.0}
344-
345-
346340
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
347341
class TestLogistic(BaseTestCases.BaseTestCase):
348342
distribution = pm.Logistic
@@ -417,8 +411,8 @@ class TestMoyal(BaseTestCases.BaseTestCase):
417411
class BaseTestDistribution(SeededTest):
418412
pymc_dist: Optional[Callable] = None
419413
pymc_dist_params = dict()
420-
expected_dist: Optional[Callable] = None
421-
expected_dist_params = dict()
414+
reference_dist: Optional[Callable] = None
415+
reference_dist_params = dict()
422416
expected_rv_op_params = dict()
423417
tests_to_run = []
424418
size = 15
@@ -430,40 +424,40 @@ class BaseTestDistribution(SeededTest):
430424

431425
def test_distribution(self):
432426
self._instantiate_pymc_rv()
433-
if self.expected_dist is not None:
434-
self.expected_dist_outcome = self.expected_dist()(
435-
**self.expected_dist_params, size=self.size
427+
if self.reference_dist is not None:
428+
self.reference_dist_draws = self.reference_dist()(
429+
**self.reference_dist_params, size=self.size
436430
)
437431
for test_name in self.tests_to_run:
438432
self.run_test(test_name)
439433

440434
def run_test(self, test_name):
441435
{
442-
"check_pymc_dist_matches_expected": self._check_pymc_draws_match_expected,
436+
"check_pymc_dist_matches_reference": self._check_pymc_draws_match_reference,
443437
"check_pymc_params_match_rv_op": self._check_pymc_params_match_rv_op,
444-
"check_distribution_size": self._check_distribution_size,
438+
"check_rv_size": self._check_rv_size,
445439
}[test_name]()
446440

447441
def _instantiate_pymc_rv(self, dist_params=None):
448442
params = dist_params if dist_params else self.pymc_dist_params
449443
with pm.Model():
450-
self.pymc_dist_output = self.pymc_dist(
444+
self.pymc_rv = self.pymc_dist(
451445
**params,
452446
size=self.size,
453447
rng=aesara.shared(self.get_random_state(reset=True)),
454448
name=f"{self.pymc_dist.rv_op.name}_test",
455449
)
456450

457-
def _check_pymc_draws_match_expected(self):
451+
def _check_pymc_draws_match_reference(self):
458452
# need to re-instantiate it to make sure that the order of drawings match the reference distribution one
459453
self._instantiate_pymc_rv()
460454
assert_array_almost_equal(
461-
self.pymc_dist_output.eval(), self.expected_dist_outcome, decimal=self.decimal
455+
self.pymc_rv.eval(), self.reference_dist_draws, decimal=self.decimal
462456
)
463457

464458
def _check_pymc_params_match_rv_op(self) -> None:
465459
try:
466-
aesera_dist_inputs = self.pymc_dist_output.get_parents()[0].inputs[3:]
460+
aesera_dist_inputs = self.pymc_rv.get_parents()[0].inputs[3:]
467461
except:
468462
raise Exception("Parent Apply node missing from output")
469463
assert len(self.expected_rv_op_params) == len(aesera_dist_inputs)
@@ -472,26 +466,18 @@ def _check_pymc_params_match_rv_op(self) -> None:
472466
):
473467
assert_almost_equal(expected_value, actual_variable.eval(), decimal=self.decimal)
474468

475-
def _check_distribution_size(self):
469+
def _check_rv_size(self):
476470
# test sizes
477471
sizes_to_check = self.sizes_to_check or [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
478-
sizes_expected = self.sizes_expected or [
479-
(),
480-
(),
481-
(1,),
482-
(1,),
483-
(5,),
484-
(4, 5),
485-
(2, 4, 2),
486-
]
472+
sizes_expected = self.sizes_expected or [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)]
487473
for size, expected in zip(sizes_to_check, sizes_expected):
488-
actual = change_rv_size(self.pymc_dist_output, size).eval().shape
474+
actual = change_rv_size(self.pymc_rv, size).eval().shape
489475
assert actual == expected
490476

491477
# test negative sizes raise
492478
for size in [-2, (3, -2)]:
493479
with pytest.raises(ValueError):
494-
change_rv_size(self.pymc_dist_output, size).eval()
480+
change_rv_size(self.pymc_rv, size).eval()
495481

496482
# test multi-parameters sampling for univariate distributions
497483
if self.pymc_dist.rv_op.ndim_supp == 0:
@@ -506,7 +492,7 @@ def _check_distribution_size(self):
506492
(5, self.repeated_params_shape),
507493
]
508494
for size, expected in zip(sizes_to_check, sizes_expected):
509-
actual = change_rv_size(self.pymc_dist_output, size).eval().shape
495+
actual = change_rv_size(self.pymc_rv, size).eval().shape
510496
assert actual == expected
511497

512498

@@ -526,27 +512,27 @@ class TestGumbelDistribution(BaseTestDistribution):
526512
pymc_dist = pm.Gumbel
527513
pymc_dist_params = {"mu": 1.5, "beta": 3.0}
528514
expected_rv_op_params = {"mu": 1.5, "beta": 3.0}
529-
expected_dist_params = {"loc": 1.5, "scale": 3.0}
515+
reference_dist_params = {"loc": 1.5, "scale": 3.0}
530516
size = 15
531-
expected_dist = seeded_scipy_distribution_builder("gumbel_r")
517+
reference_dist = seeded_scipy_distribution_builder("gumbel_r")
532518
tests_to_run = [
533519
"check_pymc_params_match_rv_op",
534-
"check_distribution_size",
535-
"check_pymc_dist_matches_expected",
520+
"check_rv_size",
521+
"check_pymc_dist_matches_reference",
536522
]
537523

538524

539525
class TestNormalDistribution(BaseTestDistribution):
540526
pymc_dist = pm.Normal
541527
pymc_dist_params = {"mu": 5.0, "sigma": 10.0}
542528
expected_rv_op_params = {"mu": 5.0, "sigma": 10.0}
543-
expected_dist_params = {"loc": 5.0, "scale": 10.0}
529+
reference_dist_params = {"loc": 5.0, "scale": 10.0}
544530
size = 15
545-
expected_dist = seeded_numpy_distribution_builder("normal")
531+
reference_dist = seeded_numpy_distribution_builder("normal")
546532
tests_to_run = [
547533
"check_pymc_params_match_rv_op",
548-
"check_distribution_size",
549-
"check_pymc_dist_matches_expected",
534+
"check_rv_size",
535+
"check_pymc_dist_matches_reference",
550536
]
551537

552538

@@ -718,7 +704,7 @@ class TestPoissonDistribution(BaseTestDistribution):
718704
tests_to_run = ["check_pymc_params_match_rv_op"]
719705

720706

721-
class TestMvNormalDistributionDistribution(BaseTestDistribution):
707+
class TestMvNormalDistribution(BaseTestDistribution):
722708
pymc_dist = pm.MvNormal
723709
pymc_dist_params = {
724710
"mu": np.array([1.0, 2.0]),
@@ -730,10 +716,10 @@ class TestMvNormalDistributionDistribution(BaseTestDistribution):
730716
}
731717
sizes_to_check = [None, (1), (2, 3)]
732718
sizes_expected = [(2,), (1, 2), (2, 3, 2)]
733-
tests_to_run = ["check_pymc_params_match_rv_op", "check_distribution_size"]
719+
tests_to_run = ["check_pymc_params_match_rv_op", "check_rv_size"]
734720

735721

736-
class TestMvNormalDistributionCholDistribution(BaseTestDistribution):
722+
class TestMvNormalDistributionChol(BaseTestDistribution):
737723
pymc_dist = pm.MvNormal
738724
pymc_dist_params = {
739725
"mu": np.array([1.0, 2.0]),
@@ -746,7 +732,7 @@ class TestMvNormalDistributionCholDistribution(BaseTestDistribution):
746732
tests_to_run = ["check_pymc_params_match_rv_op"]
747733

748734

749-
class TestMvNormalDistributionTauDistribution(BaseTestDistribution):
735+
class TestMvNormalDistributionTau(BaseTestDistribution):
750736
pymc_dist = pm.MvNormal
751737
pymc_dist_params = {
752738
"mu": np.array([1.0, 2.0]),
@@ -772,7 +758,7 @@ class TestMultinomialDistribution(BaseTestDistribution):
772758
expected_rv_op_params = {"n": 85, "p": np.array([0.28, 0.62, 0.10])}
773759
sizes_to_check = [None, (1), (4,), (3, 2)]
774760
sizes_expected = [(3,), (1, 3), (4, 3), (3, 2, 3)]
775-
tests_to_run = ["check_pymc_params_match_rv_op", "check_distribution_size"]
761+
tests_to_run = ["check_pymc_params_match_rv_op", "check_rv_size"]
776762

777763

778764
class TestCategoricalDistribution(BaseTestDistribution):

0 commit comments

Comments
 (0)