Skip to content

Commit 5b2b7f0

Browse files
authored
Merge pull request #1440 from pymc-devs/plot_posterior_vector
Generalized plot_posterior to plot vector-valued variables
2 parents 90c7286 + 092faf2 commit 5b2b7f0

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

pymc3/plots.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -746,16 +746,28 @@ def set_key_if_doesnt_exist(d, key, value):
746746
if rope is not None:
747747
display_rope(rope)
748748

749-
def create_axes_grid(figsize, varnames):
750-
n = np.ceil(len(varnames) / 2.0).astype(int)
749+
def create_axes_grid(figsize, traces):
750+
n = np.ceil(len(traces) / 2.0).astype(int)
751751
if figsize is None:
752752
figsize = (12, n * 2.5)
753753
fig, ax = plt.subplots(n, 2, figsize=figsize)
754754
ax = ax.reshape(2 * n)
755-
if len(varnames) % 2 == 1:
755+
if len(traces) % 2 == 1:
756756
ax[-1].set_axis_off()
757757
ax = ax[:-1]
758758
return ax, fig
759+
760+
def get_trace_dict(tr, varnames):
761+
traces = {}
762+
for v in varnames:
763+
vals = tr.get_values(v, combine=True, squeeze=True)
764+
if vals.ndim>1:
765+
vals_flat = vals.reshape(vals.shape[0], -1).T
766+
for i,vi in enumerate(vals_flat):
767+
traces['_'.join([v,str(i)])] = vi
768+
else:
769+
traces[v] = vals
770+
return traces
759771

760772
if isinstance(trace, np.ndarray):
761773
if figsize is None:
@@ -770,12 +782,13 @@ def create_axes_grid(figsize, varnames):
770782
else:
771783
varnames = [name for name in trace.varnames if not name.endswith('_')]
772784

785+
trace_dict = get_trace_dict(trace, varnames)
786+
773787
if ax is None:
774-
ax, fig = create_axes_grid(figsize, varnames)
788+
ax, fig = create_axes_grid(figsize, trace_dict)
775789

776-
for a, v in zip(ax, varnames):
777-
tr_values = transform(trace.get_values(
778-
v, combine=True, squeeze=True))
790+
for a, v in zip(ax, trace_dict):
791+
tr_values = transform(trace_dict[v])
779792
plot_posterior_op(tr_values, ax=a)
780793
a.set_title(v)
781794

pymc3/tests/test_plots.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ def test_plots_multidimensional():
3737
trace = sample(3000, step, start)
3838

3939
traceplot(trace)
40-
# forestplot(trace)
41-
# autocorrplot(trace)
40+
plot_posterior(trace)
4241

4342

4443
def test_multichain_plots():

0 commit comments

Comments
 (0)