Skip to content

Commit 57b8f7b

Browse files
pibietaOriolAbril
authored andcommitted
edited test_mixture.py and test_sampling.py
1 parent 1550b8e commit 57b8f7b

File tree

2 files changed

+15
-15
lines changed

2 files changed

+15
-15
lines changed

pymc/tests/distributions/test_mixture.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -554,10 +554,10 @@ def test_single_poisson_predictive_sampling_shape(self):
554554
assert prior["like2"].shape == (n_samples, 20)
555555
assert prior["like3"].shape == (n_samples, 20)
556556

557-
assert ppc["like0"].shape == (n_samples, 20)
558-
assert ppc["like1"].shape == (n_samples, 20)
559-
assert ppc["like2"].shape == (n_samples, 20)
560-
assert ppc["like3"].shape == (n_samples, 20)
557+
assert ppc["like0"].shape == (1, n_samples, 20)
558+
assert ppc["like1"].shape == (1, n_samples, 20)
559+
assert ppc["like2"].shape == (1, n_samples, 20)
560+
assert ppc["like3"].shape == (1, n_samples, 20)
561561

562562
def test_list_mvnormals_predictive_sampling_shape(self):
563563
N = 100 # number of data points
@@ -594,7 +594,14 @@ def test_list_mvnormals_predictive_sampling_shape(self):
594594
ppc = sample_posterior_predictive(
595595
n_samples * [self.get_inital_point(model)], return_inferencedata=False
596596
)
597-
assert ppc["x_obs"].shape == (n_samples,) + X.shape
597+
assert (
598+
ppc["x_obs"].shape
599+
== (
600+
1,
601+
n_samples,
602+
)
603+
+ X.shape
604+
)
598605
assert prior["x_obs"].shape == (n_samples,) + X.shape
599606
assert prior["mu0"].shape == (n_samples, D)
600607
assert prior["chol_cov_0"].shape == (n_samples, D * (D + 1) // 2)

pymc/tests/test_sampling.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ def test_normal_scalar(self):
658658
# test default case
659659
ppc = pm.sample_posterior_predictive(trace, var_names=["a"], return_inferencedata=False)
660660
assert "a" in ppc
661-
assert ppc["a"].shape == (nchains * ndraws,)
661+
assert ppc["a"].shape == (nchains, ndraws)
662662
# mu's standard deviation may have changed thanks to a's observed
663663
_, pval = stats.kstest(ppc["a"] - trace["mu"], stats.norm(loc=0, scale=1).cdf)
664664
assert pval > 0.001
@@ -804,7 +804,7 @@ def test_model_shared_variable(self):
804804
trace, return_inferencedata=False, var_names=["p", "obs"]
805805
)
806806

807-
expected_p = np.array([logistic.eval({coeff: val}) for val in trace["x"][:samples]])
807+
expected_p = np.array([[logistic.eval({coeff: val}) for val in trace["x"][:samples]]])
808808
assert post_pred["obs"].shape == (1, samples, 3)
809809
npt.assert_allclose(post_pred["p"], expected_p)
810810

@@ -1392,14 +1392,7 @@ def test_multivariate2(self):
13921392
)
13931393
assert sim_priors["probs"].shape == (20, 6)
13941394
assert sim_priors["obs"].shape == (20,) + mn_data.shape
1395-
assert (
1396-
sim_ppc["obs"].shape
1397-
== (
1398-
1,
1399-
20
1400-
)
1401-
+ mn_data.shape
1402-
)
1395+
assert sim_ppc["obs"].shape == (1, 20) + mn_data.shape
14031396

14041397
def test_layers(self):
14051398
with pm.Model() as model:

0 commit comments

Comments
 (0)