@@ -637,7 +637,6 @@ def test_normal_scalar(self):
637
637
trace = pm .sample (
638
638
draws = ndraws ,
639
639
chains = nchains ,
640
- return_inferencedata = False ,
641
640
)
642
641
643
642
with model :
@@ -656,11 +655,13 @@ def test_normal_scalar(self):
656
655
assert ppc ["a" ].shape == (nchains , ndraws )
657
656
658
657
# test default case
659
- ppc = pm .sample_posterior_predictive (trace , var_names = ["a" ], return_inferencedata = False )
658
+ ppc = pm .sample_posterior_predictive (trace , var_names = ["a" ])
660
659
assert "a" in ppc
661
660
assert ppc ["a" ].shape == (nchains , ndraws )
662
661
# mu's standard deviation may have changed thanks to a's observed
663
- _ , pval = stats .kstest (ppc ["a" ] - trace ["mu" ], stats .norm (loc = 0 , scale = 1 ).cdf )
662
+ _ , pval = stats .kstest (
663
+ (ppc ["a" ] - trace .posterior ["mu" ]).values .flatten (), stats .norm (loc = 0 , scale = 1 ).cdf
664
+ )
664
665
assert pval > 0.001
665
666
666
667
def test_normal_scalar_idata (self ):
@@ -754,7 +755,7 @@ def test_sum_normal(self):
754
755
1000 ,
755
756
)
756
757
scale = np .sqrt (1 + 0.2 ** 2 )
757
- _ , pval = stats .kstest (ppc ["b" ], stats .norm (scale = scale ).cdf )
758
+ _ , pval = stats .kstest (ppc ["b" ]. flatten () , stats .norm (scale = scale ).cdf )
758
759
assert pval > 0.001
759
760
760
761
def test_model_not_drawable_prior (self ):
0 commit comments