Skip to content

Commit cf662c9

Browse files
authored
proper handling of chain_idx in sample() (#4495)
* don't add chain offset in loop already taking care of this * don't default to chain 0 when computing sampler stats; use provided chain_idx * adding test script for chain_idx in sample() * marking test as xfail for now * take care of chain indices in sample_posterior_predictive * update test to include sample_posterior_predictive * use reproducible order of ppc samples wrt multiple chains
1 parent ba186e3 commit cf662c9

File tree

2 files changed

+26
-4
lines changed

2 files changed

+26
-4
lines changed

pymc3/sampling.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def sample(
600600
# count the number of tune/draw iterations that happened
601601
# ideally via the "tune" statistic, but not all samplers record it!
602602
if "tune" in trace.stat_names:
603-
stat = trace.get_sampler_stats("tune", chains=0)
603+
stat = trace.get_sampler_stats("tune", chains=chain_idx)
604604
# when CompoundStep is used, the stat is 2 dimensional!
605605
if len(stat.shape) == 2:
606606
stat = stat[:, 0]
@@ -1453,9 +1453,9 @@ def _mp_sample(
14531453
# dict does not contain all parameters
14541454
update_start_vals(start[idx - chain], model.test_point, model)
14551455
if step.generates_stats and strace.supports_sampler_stats:
1456-
strace.setup(draws + tune, idx + chain, step.stats_dtypes)
1456+
strace.setup(draws + tune, idx, step.stats_dtypes)
14571457
else:
1458-
strace.setup(draws + tune, idx + chain)
1458+
strace.setup(draws + tune, idx)
14591459
traces.append(strace)
14601460

14611461
sampler = ps.ParallelSampler(
@@ -1716,12 +1716,19 @@ def sample_posterior_predictive(
17161716

17171717
ppc_trace_t = _DefaultTrace(samples)
17181718
try:
1719+
if hasattr(_trace, "_straces"):
1720+
# trace dict is unordered, but we want to return ppc samples in
1721+
# a predictable ordering, so sort the chain indices
1722+
chain_idx_mapping = sorted(_trace._straces.keys())
17191723
for idx in indices:
17201724
if nchain > 1:
17211725
# the trace object will either be a MultiTrace (and have _straces)...
17221726
if hasattr(_trace, "_straces"):
17231727
chain_idx, point_idx = np.divmod(idx, len_trace)
1724-
param = cast(MultiTrace, _trace)._straces[chain_idx % nchain].point(point_idx)
1728+
chain_idx = chain_idx % nchain
1729+
# chain indices might not always start at 0, convert to proper index
1730+
chain_idx = chain_idx_mapping[chain_idx]
1731+
param = cast(MultiTrace, _trace)._straces[chain_idx].point(point_idx)
17251732
# ... or a PointList
17261733
else:
17271734
param = cast(PointList, _trace)[idx % (len_trace * nchain)]

pymc3/tests/test_sampling.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,21 @@ def test_partial_trace_sample():
321321
trace = pm.sample(trace=[a])
322322

323323

324+
@pytest.mark.xfail
325+
def test_chain_idx():
326+
# see https://github.com/pymc-devs/pymc3/issues/4469
327+
with pm.Model():
328+
mu = pm.Normal("mu")
329+
x = pm.Normal("x", mu=mu, sigma=1, observed=np.asarray(3))
330+
# note draws-tune must be >100 AND we need an observed RV for this to properly
331+
# trigger convergence checks, which is one particular case in which this failed
332+
# before
333+
trace = pm.sample(draws=150, tune=10, chain_idx=1)
334+
335+
ppc = pm.sample_posterior_predictive(trace)
336+
ppc = pm.sample_posterior_predictive(trace, keep_size=True)
337+
338+
324339
@pytest.mark.parametrize(
325340
"n_points, tune, expected_length, expected_n_traces",
326341
[

0 commit comments

Comments
 (0)