16
16
import sys
17
17
18
18
from contextlib import ExitStack as does_not_raise
19
- from typing import Callable
19
+ from typing import Callable , List , Optional
20
20
21
21
import aesara
22
22
import numpy as np
@@ -421,17 +421,21 @@ class TestMoyal(BaseTestCases.BaseTestCase):
421
421
422
422
423
423
class BaseTestDistribution (SeededTest ):
424
- pymc_dist = None
424
+ pymc_dist : Optional [ Callable ] = None
425
425
pymc_dist_params = dict ()
426
- expected_dist = None
426
+ expected_dist : Optional [ Callable ] = None
427
427
expected_dist_params = dict ()
428
428
expected_rv_op_params = dict ()
429
429
tests_to_run = []
430
430
size = 15
431
431
decimal = 6
432
432
433
- def test_distribution (self ) -> None :
434
- self ._instantiate_pymc_distribution ()
433
+ sizes_to_check : Optional [List ] = None
434
+ sizes_expected : Optional [List ] = None
435
+ repeated_params_shape = 5
436
+
437
+ def test_distribution (self ):
438
+ self ._instantiate_pymc_rv ()
435
439
if self .expected_dist is not None :
436
440
self .expected_dist_outcome = self .expected_dist ()(
437
441
** self .expected_dist_params , size = self .size
@@ -446,20 +450,19 @@ def run_test(self, test_name):
446
450
"check_distribution_size" : self ._check_distribution_size ,
447
451
}[test_name ]()
448
452
449
- def _instantiate_pymc_distribution (self ):
453
+ def _instantiate_pymc_rv (self , dist_params = None ):
454
+ params = dist_params if dist_params else self .pymc_dist_params
450
455
with pm .Model ():
451
456
self .pymc_dist_output = self .pymc_dist (
452
- ** self . pymc_dist_params ,
457
+ ** params ,
453
458
size = self .size ,
454
459
rng = aesara .shared (self .get_random_state (reset = True )),
455
460
name = f"{ self .pymc_dist .rv_op .name } _test" ,
456
461
)
457
462
458
- def _check_pymc_draws_match_expected (
459
- self ,
460
- ):
463
+ def _check_pymc_draws_match_expected (self ):
461
464
# need to re-instantiate it to make sure that the order of drawings match the reference distribution one
462
- self ._instantiate_pymc_distribution ()
465
+ self ._instantiate_pymc_rv ()
463
466
assert_array_almost_equal (
464
467
self .pymc_dist_output .eval (), self .expected_dist_outcome , decimal = self .decimal
465
468
)
@@ -476,7 +479,9 @@ def _check_pymc_params_match_rv_op(self) -> None:
476
479
assert_almost_equal (expected_value , actual_variable .eval (), decimal = self .decimal )
477
480
478
481
def _check_distribution_size (self ):
479
- sizes_to_check , sizes_expected = [None , (), 1 , (1 ,), 5 , (4 , 5 ), (2 , 4 , 2 )], [
482
+ # test sizes
483
+ sizes_to_check = self .sizes_to_check or [None , (), 1 , (1 ,), 5 , (4 , 5 ), (2 , 4 , 2 )]
484
+ sizes_expected = self .sizes_expected or [
480
485
(),
481
486
(),
482
487
(1 ,),
@@ -486,16 +491,29 @@ def _check_distribution_size(self):
486
491
(2 , 4 , 2 ),
487
492
]
488
493
for size , expected in zip (sizes_to_check , sizes_expected ):
489
- pymc_dist_output_resized = change_rv_size (self .pymc_dist_output , size )
490
- actual = pymc_dist_output_resized .eval ().shape
491
- print (actual , expected )
494
+ actual = change_rv_size (self .pymc_dist_output , size ).eval ().shape
492
495
assert actual == expected
493
496
494
497
# test negative sizes raise
495
- with pytest .raises (ValueError ):
496
- change_rv_size (self .pymc_dist_output , - 2 ).eval ()
497
- with pytest .raises (ValueError ):
498
- change_rv_size (self .pymc_dist_output , (3 , - 2 )).eval ()
498
+ for size in [- 2 , (3 , - 2 )]:
499
+ with pytest .raises (ValueError ):
500
+ change_rv_size (self .pymc_dist_output , size ).eval ()
501
+
502
+ # test multi-parameters sampling for univariate distributions
503
+ if self .pymc_dist .rv_op .ndim_supp == 0 :
504
+ params = {
505
+ k : p * np .ones (self .repeated_params_shape ) for k , p in self .pymc_dist_params .items ()
506
+ }
507
+ self ._instantiate_pymc_rv (params )
508
+ sizes_to_check = [None , self .repeated_params_shape , (5 , self .repeated_params_shape )]
509
+ sizes_expected = [
510
+ (self .repeated_params_shape ,),
511
+ (self .repeated_params_shape ,),
512
+ (5 , self .repeated_params_shape ),
513
+ ]
514
+ for size , expected in zip (sizes_to_check , sizes_expected ):
515
+ actual = change_rv_size (self .pymc_dist_output , size ).eval ().shape
516
+ assert actual == expected
499
517
500
518
501
519
def seeded_scipy_distribution_builder (dist_name : str ) -> Callable :
@@ -706,7 +724,7 @@ class TestPoissonDistribution(BaseTestDistribution):
706
724
tests_to_run = ["check_pymc_params_match_rv_op" ]
707
725
708
726
709
- class TestMVNormalDistributionDistribution (BaseTestDistribution ):
727
+ class TestMvNormalDistributionDistribution (BaseTestDistribution ):
710
728
pymc_dist = pm .MvNormal
711
729
pymc_dist_params = {
712
730
"mu" : np .array ([1.0 , 2.0 ]),
@@ -716,10 +734,12 @@ class TestMVNormalDistributionDistribution(BaseTestDistribution):
716
734
"mu" : np .array ([1.0 , 2.0 ]),
717
735
"cov" : np .array ([[2.0 , 0.0 ], [0.0 , 3.5 ]]),
718
736
}
719
- tests_to_run = ["check_pymc_params_match_rv_op" ]
737
+ sizes_to_check = [None , (1 ), (2 , 3 )]
738
+ sizes_expected = [(2 ,), (1 , 2 ), (2 , 3 , 2 )]
739
+ tests_to_run = ["check_pymc_params_match_rv_op" , "check_distribution_size" ]
720
740
721
741
722
- class TestMVNormalDistributionCholDistribution (BaseTestDistribution ):
742
+ class TestMvNormalDistributionCholDistribution (BaseTestDistribution ):
723
743
pymc_dist = pm .MvNormal
724
744
pymc_dist_params = {
725
745
"mu" : np .array ([1.0 , 2.0 ]),
@@ -732,7 +752,7 @@ class TestMVNormalDistributionCholDistribution(BaseTestDistribution):
732
752
tests_to_run = ["check_pymc_params_match_rv_op" ]
733
753
734
754
735
- class TestMVNormalDistributionTauDistribution (BaseTestDistribution ):
755
+ class TestMvNormalDistributionTauDistribution (BaseTestDistribution ):
736
756
pymc_dist = pm .MvNormal
737
757
pymc_dist_params = {
738
758
"mu" : np .array ([1.0 , 2.0 ]),
@@ -756,7 +776,9 @@ class TestMultinomialDistribution(BaseTestDistribution):
756
776
pymc_dist = pm .Multinomial
757
777
pymc_dist_params = {"n" : 85 , "p" : np .array ([0.28 , 0.62 , 0.10 ])}
758
778
expected_rv_op_params = {"n" : 85 , "p" : np .array ([0.28 , 0.62 , 0.10 ])}
759
- tests_to_run = ["check_pymc_params_match_rv_op" ]
779
+ sizes_to_check = [None , (1 ), (4 ,), (3 , 2 )]
780
+ sizes_expected = [(3 ,), (1 , 3 ), (4 , 3 ), (3 , 2 , 3 )]
781
+ tests_to_run = ["check_pymc_params_match_rv_op" , "check_distribution_size" ]
760
782
761
783
762
784
class TestCategoricalDistribution (BaseTestDistribution ):
0 commit comments