@@ -431,13 +431,7 @@ class BaseTestDistribution(SeededTest):
431
431
decimal = 6
432
432
433
433
def test_distribution (self ) -> None :
434
- with pm .Model ():
435
- self .pymc_dist_output = self .pymc_dist (
436
- ** self .pymc_dist_params ,
437
- size = self .size ,
438
- rng = aesara .shared (self .get_random_state ()),
439
- name = f"{ self .pymc_dist .rv_op .name } _test" ,
440
- )
434
+ self ._instantiate_pymc_distribution ()
441
435
if self .expected_dist is not None :
442
436
self .expected_dist_outcome = self .expected_dist ()(
443
437
** self .expected_dist_params , size = self .size
@@ -449,11 +443,23 @@ def run_test(self, test_name):
449
443
{
450
444
"check_pymc_dist_matches_expected" : self ._check_pymc_draws_match_expected ,
451
445
"check_pymc_params_match_rv_op" : self ._check_pymc_params_match_rv_op ,
446
+ "check_distribution_size" : self ._check_distribution_size ,
452
447
}[test_name ]()
453
448
449
+ def _instantiate_pymc_distribution (self ):
450
+ with pm .Model ():
451
+ self .pymc_dist_output = self .pymc_dist (
452
+ ** self .pymc_dist_params ,
453
+ size = self .size ,
454
+ rng = aesara .shared (self .get_random_state (reset = True )),
455
+ name = f"{ self .pymc_dist .rv_op .name } _test" ,
456
+ )
457
+
454
458
def _check_pymc_draws_match_expected (
455
459
self ,
456
460
):
461
+ # need to re-instantiate it to make sure that the order of drawings match the reference distribution one
462
+ self ._instantiate_pymc_distribution ()
457
463
assert_array_almost_equal (
458
464
self .pymc_dist_output .eval (), self .expected_dist_outcome , decimal = self .decimal
459
465
)
@@ -469,6 +475,28 @@ def _check_pymc_params_match_rv_op(self) -> None:
469
475
):
470
476
assert_almost_equal (expected_value , actual_variable .eval (), decimal = self .decimal )
471
477
478
+ def _check_distribution_size (self ):
479
+ sizes_to_check , sizes_expected = [None , (), 1 , (1 ,), 5 , (4 , 5 ), (2 , 4 , 2 )], [
480
+ (),
481
+ (),
482
+ (1 ,),
483
+ (1 ,),
484
+ (5 ,),
485
+ (4 , 5 ),
486
+ (2 , 4 , 2 ),
487
+ ]
488
+ for size , expected in zip (sizes_to_check , sizes_expected ):
489
+ pymc_dist_output_resized = change_rv_size (self .pymc_dist_output , size )
490
+ actual = pymc_dist_output_resized .eval ().shape
491
+ print (actual , expected )
492
+ assert actual == expected
493
+
494
+ # test negative sizes raise
495
+ with pytest .raises (ValueError ):
496
+ change_rv_size (self .pymc_dist_output , - 2 ).eval ()
497
+ with pytest .raises (ValueError ):
498
+ change_rv_size (self .pymc_dist_output , (3 , - 2 )).eval ()
499
+
472
500
473
501
def seeded_scipy_distribution_builder (dist_name : str ) -> Callable :
474
502
return lambda self : functools .partial (
@@ -489,7 +517,11 @@ class TestGumbelDistribution(BaseTestDistribution):
489
517
expected_dist_params = {"loc" : 1.5 , "scale" : 3.0 }
490
518
size = 15
491
519
expected_dist = seeded_scipy_distribution_builder ("gumbel_r" )
492
- tests_to_run = ["check_pymc_params_match_rv_op" , "check_pymc_dist_matches_expected" ]
520
+ tests_to_run = [
521
+ "check_pymc_params_match_rv_op" ,
522
+ "check_distribution_size" ,
523
+ "check_pymc_dist_matches_expected" ,
524
+ ]
493
525
494
526
495
527
class TestNormalDistribution (BaseTestDistribution ):
@@ -499,7 +531,11 @@ class TestNormalDistribution(BaseTestDistribution):
499
531
expected_dist_params = {"loc" : 5.0 , "scale" : 10.0 }
500
532
size = 15
501
533
expected_dist = seeded_numpy_distribution_builder ("normal" )
502
- tests_to_run = ["check_pymc_params_match_rv_op" , "check_pymc_dist_matches_expected" ]
534
+ tests_to_run = [
535
+ "check_pymc_params_match_rv_op" ,
536
+ "check_distribution_size" ,
537
+ "check_pymc_dist_matches_expected" ,
538
+ ]
503
539
504
540
505
541
class TestNormalTauDistribution (BaseTestDistribution ):
0 commit comments