File tree Expand file tree Collapse file tree 2 files changed +16
-4
lines changed Expand file tree Collapse file tree 2 files changed +16
-4
lines changed Original file line number Diff line number Diff line change @@ -1791,6 +1791,10 @@ def sample_posterior_predictive(
1791
1791
generally be the model used to generate the ``trace``, but it doesn't need to be.
1792
1792
var_names : Iterable[str]
1793
1793
Names of variables for which to compute the posterior predictive samples.
1794
+ sample_dims : list of str, optional
1795
+ Dimensions over which to loop and generate posterior predictive samples.
1796
+ When `sample_dims` is ``None`` (default) both "chain" and "draw" are considered sample
1797
+ dimensions. Only taken into account when `trace` is InferenceData or Dataset.
1794
1798
random_seed : int, RandomState or Generator, optional
1795
1799
Seed for the random number generator.
1796
1800
progressbar : bool
@@ -1827,6 +1831,14 @@ def sample_posterior_predictive(
1827
1831
thinned_idata = idata.sel(draw=slice(None, None, 5))
1828
1832
with model:
1829
1833
idata.extend(pymc.sample_posterior_predictive(thinned_idata))
1834
+
1835
+ Generate 5 posterior predictive samples per posterior sample.
1836
+
1837
+ .. code:: python
1838
+
1839
+ expanded_data = idata.posterior.expand_dims(pred_id=5)
1840
+ with model:
1841
+ idata.extend(pymc.sample_posterior_predictive(expanded_data))
1830
1842
"""
1831
1843
1832
1844
_trace : Union [MultiTrace , PointList ]
Original file line number Diff line number Diff line change @@ -1625,14 +1625,14 @@ def test_sample_dims(self, point_list_arg_bug_fixture):
1625
1625
pp = pm .sample_posterior_predictive (post , var_names = ["d" ], sample_dims = ["sample" ])
1626
1626
assert "sample" in pp .posterior_predictive
1627
1627
assert len (pp .posterior_predictive ["sample" ]) == len (post ["sample" ])
1628
- post = post .expand_dims (pp_dim = 5 )
1628
+ post = post .expand_dims (pred_id = 5 )
1629
1629
pp = pm .sample_posterior_predictive (
1630
- post , var_names = ["d" ], sample_dims = ["sample" , "pp_dim " ]
1630
+ post , var_names = ["d" ], sample_dims = ["sample" , "pred_id " ]
1631
1631
)
1632
1632
assert "sample" in pp .posterior_predictive
1633
- assert "pp_dim " in pp .posterior_predictive
1633
+ assert "pred_id " in pp .posterior_predictive
1634
1634
assert len (pp .posterior_predictive ["sample" ]) == len (post ["sample" ])
1635
- assert len (pp .posterior_predictive ["pp_dim " ]) == 5
1635
+ assert len (pp .posterior_predictive ["pred_id " ]) == 5
1636
1636
1637
1637
1638
1638
class TestDraw (SeededTest ):
You can’t perform that action at this time.
0 commit comments