Skip to content

Commit c8a3c95

Browse files
committed
try fixing tests
1 parent c343010 commit c8a3c95

File tree

2 files changed

+5
-18
lines changed

2 files changed

+5
-18
lines changed

pymc/tests/test_model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,8 +1263,9 @@ def test_interval_missing_observations():
12631263

12641264
# Make sure that the observed values are newly generated samples and that
12651265
# the observed and deterministic matche
1266-
pp_trace = pm.sample_posterior_predictive(
1267-
trace, return_inferencedata=False, keep_size=False
1266+
pp_idata = pm.sample_posterior_predictive(trace)
1267+
pp_trace = pp_idata.posterior_predictive.stack(sample=["chain", "draw"]).transpose(
1268+
"sample", ...
12681269
)
12691270
assert np.all(np.var(pp_trace["theta1"], 0) > 0.0)
12701271
assert np.all(np.var(pp_trace["theta2"], 0) > 0.0)

pymc/tests/test_sampling.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -505,21 +505,6 @@ def test_partial_trace_sample():
505505
assert "b" not in idata.posterior
506506

507507

508-
def test_chain_idx():
509-
# see https://github.com/pymc-devs/pymc/issues/4469
510-
with pm.Model():
511-
mu = pm.Normal("mu")
512-
x = pm.Normal("x", mu=mu, sigma=1, observed=np.asarray(3))
513-
# note draws-tune must be >100 AND we need an observed RV for this to properly
514-
# trigger convergence checks, which is one particular case in which this failed
515-
# before
516-
idata = pm.sample(draws=150, tune=10, chain_idx=1)
517-
518-
ppc = pm.sample_posterior_predictive(idata)
519-
# TODO FIXME: Assert something.
520-
ppc = pm.sample_posterior_predictive(idata)
521-
522-
523508
@pytest.mark.parametrize(
524509
"n_points, tune, expected_length, expected_n_traces",
525510
[
@@ -655,7 +640,8 @@ def test_normal_scalar(self):
655640
assert ppc["a"].shape == (nchains, ndraws)
656641

657642
# test default case
658-
ppc = pm.sample_posterior_predictive(trace, var_names=["a"])
643+
idata_ppc = pm.sample_posterior_predictive(trace, var_names=["a"])
644+
ppc = idata_ppc.posterior_predictive
659645
assert "a" in ppc
660646
assert ppc["a"].shape == (nchains, ndraws)
661647
# mu's standard deviation may have changed thanks to a's observed

0 commit comments

Comments
 (0)