Skip to content

Commit 65f69d6

Browse files
committed
add docs
1 parent 9130d81 commit 65f69d6

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

pymc/sampling.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,6 +1791,10 @@ def sample_posterior_predictive(
17911791
generally be the model used to generate the ``trace``, but it doesn't need to be.
17921792
var_names : Iterable[str]
17931793
Names of variables for which to compute the posterior predictive samples.
1794+
sample_dims : list of str, optional
1795+
Dimensions over which to loop and generate posterior predictive samples.
1796+
When `sample_dims` is ``None`` (default) both "chain" and "draw" are considered sample
1797+
dimensions. Only taken into account when `trace` is InferenceData or Dataset.
17941798
random_seed : int, RandomState or Generator, optional
17951799
Seed for the random number generator.
17961800
progressbar : bool
@@ -1827,6 +1831,14 @@ def sample_posterior_predictive(
18271831
thinned_idata = idata.sel(draw=slice(None, None, 5))
18281832
with model:
18291833
idata.extend(pymc.sample_posterior_predictive(thinned_idata))
1834+
1835+
Generate 5 posterior predictive samples per posterior sample.
1836+
1837+
.. code:: python
1838+
1839+
expanded_data = idata.posterior.expand_dims(pred_id=5)
1840+
with model:
1841+
idata.extend(pymc.sample_posterior_predictive(expanded_data))
18301842
"""
18311843

18321844
_trace: Union[MultiTrace, PointList]

pymc/tests/test_sampling.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,14 +1625,14 @@ def test_sample_dims(self, point_list_arg_bug_fixture):
16251625
pp = pm.sample_posterior_predictive(post, var_names=["d"], sample_dims=["sample"])
16261626
assert "sample" in pp.posterior_predictive
16271627
assert len(pp.posterior_predictive["sample"]) == len(post["sample"])
1628-
post = post.expand_dims(pp_dim=5)
1628+
post = post.expand_dims(pred_id=5)
16291629
pp = pm.sample_posterior_predictive(
1630-
post, var_names=["d"], sample_dims=["sample", "pp_dim"]
1630+
post, var_names=["d"], sample_dims=["sample", "pred_id"]
16311631
)
16321632
assert "sample" in pp.posterior_predictive
1633-
assert "pp_dim" in pp.posterior_predictive
1633+
assert "pred_id" in pp.posterior_predictive
16341634
assert len(pp.posterior_predictive["sample"]) == len(post["sample"])
1635-
assert len(pp.posterior_predictive["pp_dim"]) == 5
1635+
assert len(pp.posterior_predictive["pred_id"]) == 5
16361636

16371637

16381638
class TestDraw(SeededTest):

0 commit comments

Comments
 (0)