Skip to content

Commit c343010

Browse files
committed
try fixing some tests
1 parent 57b8f7b commit c343010

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pymc/tests/test_sampling.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,6 @@ def test_normal_scalar(self):
637637
trace = pm.sample(
638638
draws=ndraws,
639639
chains=nchains,
640-
return_inferencedata=False,
641640
)
642641

643642
with model:
@@ -656,11 +655,13 @@ def test_normal_scalar(self):
656655
assert ppc["a"].shape == (nchains, ndraws)
657656

658657
# test default case
659-
ppc = pm.sample_posterior_predictive(trace, var_names=["a"], return_inferencedata=False)
658+
ppc = pm.sample_posterior_predictive(trace, var_names=["a"])
660659
assert "a" in ppc
661660
assert ppc["a"].shape == (nchains, ndraws)
662661
# mu's standard deviation may have changed thanks to a's observed
663-
_, pval = stats.kstest(ppc["a"] - trace["mu"], stats.norm(loc=0, scale=1).cdf)
662+
_, pval = stats.kstest(
663+
(ppc["a"] - trace.posterior["mu"]).values.flatten(), stats.norm(loc=0, scale=1).cdf
664+
)
664665
assert pval > 0.001
665666

666667
def test_normal_scalar_idata(self):
@@ -754,7 +755,7 @@ def test_sum_normal(self):
754755
1000,
755756
)
756757
scale = np.sqrt(1 + 0.2**2)
757-
_, pval = stats.kstest(ppc["b"], stats.norm(scale=scale).cdf)
758+
_, pval = stats.kstest(ppc["b"].flatten(), stats.norm(scale=scale).cdf)
758759
assert pval > 0.001
759760

760761
def test_model_not_drawable_prior(self):

0 commit comments

Comments
 (0)