@@ -337,12 +337,6 @@ class TestVonMises(BaseTestCases.BaseTestCase):
337
337
params = {"mu" : 0.0 , "kappa" : 1.0 }
338
338
339
339
340
- @pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
341
- class TestGumbel (BaseTestCases .BaseTestCase ):
342
- distribution = pm .Gumbel
343
- params = {"mu" : 0.0 , "beta" : 1.0 }
344
-
345
-
346
340
@pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
347
341
class TestLogistic (BaseTestCases .BaseTestCase ):
348
342
distribution = pm .Logistic
@@ -417,8 +411,8 @@ class TestMoyal(BaseTestCases.BaseTestCase):
417
411
class BaseTestDistribution (SeededTest ):
418
412
pymc_dist : Optional [Callable ] = None
419
413
pymc_dist_params = dict ()
420
- expected_dist : Optional [Callable ] = None
421
- expected_dist_params = dict ()
414
+ reference_dist : Optional [Callable ] = None
415
+ reference_dist_params = dict ()
422
416
expected_rv_op_params = dict ()
423
417
tests_to_run = []
424
418
size = 15
@@ -430,40 +424,40 @@ class BaseTestDistribution(SeededTest):
430
424
431
425
def test_distribution (self ):
432
426
self ._instantiate_pymc_rv ()
433
- if self .expected_dist is not None :
434
- self .expected_dist_outcome = self .expected_dist ()(
435
- ** self .expected_dist_params , size = self .size
427
+ if self .reference_dist is not None :
428
+ self .reference_dist_draws = self .reference_dist ()(
429
+ ** self .reference_dist_params , size = self .size
436
430
)
437
431
for test_name in self .tests_to_run :
438
432
self .run_test (test_name )
439
433
440
434
def run_test (self , test_name ):
441
435
{
442
- "check_pymc_dist_matches_expected " : self ._check_pymc_draws_match_expected ,
436
+ "check_pymc_dist_matches_reference " : self ._check_pymc_draws_match_reference ,
443
437
"check_pymc_params_match_rv_op" : self ._check_pymc_params_match_rv_op ,
444
- "check_distribution_size " : self ._check_distribution_size ,
438
+ "check_rv_size " : self ._check_rv_size ,
445
439
}[test_name ]()
446
440
447
441
def _instantiate_pymc_rv (self , dist_params = None ):
448
442
params = dist_params if dist_params else self .pymc_dist_params
449
443
with pm .Model ():
450
- self .pymc_dist_output = self .pymc_dist (
444
+ self .pymc_rv = self .pymc_dist (
451
445
** params ,
452
446
size = self .size ,
453
447
rng = aesara .shared (self .get_random_state (reset = True )),
454
448
name = f"{ self .pymc_dist .rv_op .name } _test" ,
455
449
)
456
450
457
- def _check_pymc_draws_match_expected (self ):
451
+ def _check_pymc_draws_match_reference (self ):
458
452
# need to re-instantiate it to make sure that the order of drawings match the reference distribution one
459
453
self ._instantiate_pymc_rv ()
460
454
assert_array_almost_equal (
461
- self .pymc_dist_output .eval (), self .expected_dist_outcome , decimal = self .decimal
455
+ self .pymc_rv .eval (), self .reference_dist_draws , decimal = self .decimal
462
456
)
463
457
464
458
def _check_pymc_params_match_rv_op (self ) -> None :
465
459
try :
466
- aesera_dist_inputs = self .pymc_dist_output .get_parents ()[0 ].inputs [3 :]
460
+ aesera_dist_inputs = self .pymc_rv .get_parents ()[0 ].inputs [3 :]
467
461
except :
468
462
raise Exception ("Parent Apply node missing from output" )
469
463
assert len (self .expected_rv_op_params ) == len (aesera_dist_inputs )
@@ -472,26 +466,18 @@ def _check_pymc_params_match_rv_op(self) -> None:
472
466
):
473
467
assert_almost_equal (expected_value , actual_variable .eval (), decimal = self .decimal )
474
468
475
- def _check_distribution_size (self ):
469
+ def _check_rv_size (self ):
476
470
# test sizes
477
471
sizes_to_check = self .sizes_to_check or [None , (), 1 , (1 ,), 5 , (4 , 5 ), (2 , 4 , 2 )]
478
- sizes_expected = self .sizes_expected or [
479
- (),
480
- (),
481
- (1 ,),
482
- (1 ,),
483
- (5 ,),
484
- (4 , 5 ),
485
- (2 , 4 , 2 ),
486
- ]
472
+ sizes_expected = self .sizes_expected or [(), (), (1 ,), (1 ,), (5 ,), (4 , 5 ), (2 , 4 , 2 )]
487
473
for size , expected in zip (sizes_to_check , sizes_expected ):
488
- actual = change_rv_size (self .pymc_dist_output , size ).eval ().shape
474
+ actual = change_rv_size (self .pymc_rv , size ).eval ().shape
489
475
assert actual == expected
490
476
491
477
# test negative sizes raise
492
478
for size in [- 2 , (3 , - 2 )]:
493
479
with pytest .raises (ValueError ):
494
- change_rv_size (self .pymc_dist_output , size ).eval ()
480
+ change_rv_size (self .pymc_rv , size ).eval ()
495
481
496
482
# test multi-parameters sampling for univariate distributions
497
483
if self .pymc_dist .rv_op .ndim_supp == 0 :
@@ -506,7 +492,7 @@ def _check_distribution_size(self):
506
492
(5 , self .repeated_params_shape ),
507
493
]
508
494
for size , expected in zip (sizes_to_check , sizes_expected ):
509
- actual = change_rv_size (self .pymc_dist_output , size ).eval ().shape
495
+ actual = change_rv_size (self .pymc_rv , size ).eval ().shape
510
496
assert actual == expected
511
497
512
498
@@ -526,27 +512,27 @@ class TestGumbelDistribution(BaseTestDistribution):
526
512
pymc_dist = pm .Gumbel
527
513
pymc_dist_params = {"mu" : 1.5 , "beta" : 3.0 }
528
514
expected_rv_op_params = {"mu" : 1.5 , "beta" : 3.0 }
529
- expected_dist_params = {"loc" : 1.5 , "scale" : 3.0 }
515
+ reference_dist_params = {"loc" : 1.5 , "scale" : 3.0 }
530
516
size = 15
531
- expected_dist = seeded_scipy_distribution_builder ("gumbel_r" )
517
+ reference_dist = seeded_scipy_distribution_builder ("gumbel_r" )
532
518
tests_to_run = [
533
519
"check_pymc_params_match_rv_op" ,
534
- "check_distribution_size " ,
535
- "check_pymc_dist_matches_expected " ,
520
+ "check_rv_size " ,
521
+ "check_pymc_dist_matches_reference " ,
536
522
]
537
523
538
524
539
525
class TestNormalDistribution (BaseTestDistribution ):
540
526
pymc_dist = pm .Normal
541
527
pymc_dist_params = {"mu" : 5.0 , "sigma" : 10.0 }
542
528
expected_rv_op_params = {"mu" : 5.0 , "sigma" : 10.0 }
543
- expected_dist_params = {"loc" : 5.0 , "scale" : 10.0 }
529
+ reference_dist_params = {"loc" : 5.0 , "scale" : 10.0 }
544
530
size = 15
545
- expected_dist = seeded_numpy_distribution_builder ("normal" )
531
+ reference_dist = seeded_numpy_distribution_builder ("normal" )
546
532
tests_to_run = [
547
533
"check_pymc_params_match_rv_op" ,
548
- "check_distribution_size " ,
549
- "check_pymc_dist_matches_expected " ,
534
+ "check_rv_size " ,
535
+ "check_pymc_dist_matches_reference " ,
550
536
]
551
537
552
538
@@ -718,7 +704,7 @@ class TestPoissonDistribution(BaseTestDistribution):
718
704
tests_to_run = ["check_pymc_params_match_rv_op" ]
719
705
720
706
721
- class TestMvNormalDistributionDistribution (BaseTestDistribution ):
707
+ class TestMvNormalDistribution (BaseTestDistribution ):
722
708
pymc_dist = pm .MvNormal
723
709
pymc_dist_params = {
724
710
"mu" : np .array ([1.0 , 2.0 ]),
@@ -730,10 +716,10 @@ class TestMvNormalDistributionDistribution(BaseTestDistribution):
730
716
}
731
717
sizes_to_check = [None , (1 ), (2 , 3 )]
732
718
sizes_expected = [(2 ,), (1 , 2 ), (2 , 3 , 2 )]
733
- tests_to_run = ["check_pymc_params_match_rv_op" , "check_distribution_size " ]
719
+ tests_to_run = ["check_pymc_params_match_rv_op" , "check_rv_size " ]
734
720
735
721
736
- class TestMvNormalDistributionCholDistribution (BaseTestDistribution ):
722
+ class TestMvNormalDistributionChol (BaseTestDistribution ):
737
723
pymc_dist = pm .MvNormal
738
724
pymc_dist_params = {
739
725
"mu" : np .array ([1.0 , 2.0 ]),
@@ -746,7 +732,7 @@ class TestMvNormalDistributionCholDistribution(BaseTestDistribution):
746
732
tests_to_run = ["check_pymc_params_match_rv_op" ]
747
733
748
734
749
- class TestMvNormalDistributionTauDistribution (BaseTestDistribution ):
735
+ class TestMvNormalDistributionTau (BaseTestDistribution ):
750
736
pymc_dist = pm .MvNormal
751
737
pymc_dist_params = {
752
738
"mu" : np .array ([1.0 , 2.0 ]),
@@ -772,7 +758,7 @@ class TestMultinomialDistribution(BaseTestDistribution):
772
758
expected_rv_op_params = {"n" : 85 , "p" : np .array ([0.28 , 0.62 , 0.10 ])}
773
759
sizes_to_check = [None , (1 ), (4 ,), (3 , 2 )]
774
760
sizes_expected = [(3 ,), (1 , 3 ), (4 , 3 ), (3 , 2 , 3 )]
775
- tests_to_run = ["check_pymc_params_match_rv_op" , "check_distribution_size " ]
761
+ tests_to_run = ["check_pymc_params_match_rv_op" , "check_rv_size " ]
776
762
777
763
778
764
class TestCategoricalDistribution (BaseTestDistribution ):
0 commit comments