Skip to content

Commit 6fe51e5

Browse files
fonnesbecktwiecki
authored andcommitted
Added marker and line options for forestplot (#2220)
* Added marker and line options for forestplot * Added missing argument to plot * Replaced explicit plot options with plot_kwargs * Pushed plot_kwargs into _plot_tree
1 parent b81a9f7 commit 6fe51e5

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

pymc3/plots/forestplot.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _make_rhat_plot(trace, ax, title, labels, varnames, include_transformed):
8080
return ax
8181

8282

83-
def _plot_tree(ax, y, ntiles, show_quartiles):
83+
def _plot_tree(ax, y, ntiles, show_quartiles, **plot_kwargs):
8484
"""Helper to plot errorbars for the forestplot.
8585
8686
Parameters
@@ -101,22 +101,32 @@ def _plot_tree(ax, y, ntiles, show_quartiles):
101101
"""
102102
if show_quartiles:
103103
# Plot median
104-
ax.plot(ntiles[2], y, 'bo', markersize=4)
104+
ax.plot(ntiles[2], y, color=plot_kwargs.get('color', 'blue'),
105+
marker=plot_kwargs.get('marker', 'o'),
106+
markersize=plot_kwargs.get('markersize', 4))
105107
# Plot quartile interval
106-
ax.errorbar(x=(ntiles[1], ntiles[3]), y=(y, y), linewidth=2, color='b')
108+
ax.errorbar(x=(ntiles[1], ntiles[3]), y=(y, y),
109+
linewidth=plot_kwargs.get('linewidth', 2),
110+
color=plot_kwargs.get('color', 'blue'))
107111

108112
else:
109113
# Plot median
110-
ax.plot(ntiles[1], y, 'bo', markersize=4)
114+
ax.plot(ntiles[1], y, marker=plot_kwargs.get('marker', 'o'),
115+
color=plot_kwargs.get('color', 'blue'),
116+
markersize=plot_kwargs.get('markersize', 4))
111117

112118
# Plot outer interval
113-
ax.errorbar(x=(ntiles[0], ntiles[-1]), y=(y, y), linewidth=1, color='b')
119+
ax.errorbar(x=(ntiles[0], ntiles[-1]), y=(y, y),
120+
linewidth=int(plot_kwargs.get('linewidth', 2)/2),
121+
color=plot_kwargs.get('color', 'blue'))
122+
114123
return ax
115124

116125

117126
def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.05, quartiles=True,
118127
rhat=True, main=None, xtitle=None, xlim=None, ylabels=None,
119-
chain_spacing=0.05, vline=0, gs=None, plot_transformed=False):
128+
chain_spacing=0.05, vline=0, gs=None, plot_transformed=False,
129+
**plot_kwargs):
120130
"""
121131
Forest plot (model summary plot).
122132
@@ -160,6 +170,9 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
160170
plot_transformed : bool
161171
Flag for plotting automatically transformed variables in addition to
162172
original variables (defaults to False).
173+
plot_kwargs : dict
174+
Optional arguments for plot elements. Currently accepts 'fontsize',
175+
'linewidth', 'color', 'marker', and 'markersize'.
163176
164177
Returns
165178
-------
@@ -242,10 +255,12 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
242255
if k > 1:
243256
for q in np.transpose(quants).squeeze():
244257
# Multiple y values
245-
interval_plot = _plot_tree(interval_plot, y, q, quartiles)
258+
interval_plot = _plot_tree(interval_plot, y, q, quartiles,
259+
**plot_kwargs)
246260
y -= 1
247261
else:
248-
interval_plot = _plot_tree(interval_plot, y, quants, quartiles)
262+
interval_plot = _plot_tree(interval_plot, y, quants, quartiles,
263+
**plot_kwargs)
249264

250265
# Increment index
251266
var += k
@@ -264,7 +279,7 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
264279

265280
# Add variable labels
266281
interval_plot.set_yticks([-l for l in range(len(labels))])
267-
interval_plot.set_yticklabels(labels)
282+
interval_plot.set_yticklabels(labels, fontsize=plot_kwargs.get('fontsize', None))
268283

269284
# Add title
270285
plot_title = ""
@@ -273,7 +288,7 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
273288
elif main:
274289
plot_title = main
275290
if plot_title:
276-
interval_plot.set_title(plot_title)
291+
interval_plot.set_title(plot_title, fontsize=plot_kwargs.get('fontsize', None))
277292

278293
# Add x-axis label
279294
if xtitle is not None:
@@ -293,7 +308,7 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
293308
spine.set_color('none') # don't draw spine
294309

295310
# Reference line
296-
interval_plot.axvline(vline, color='k', linestyle='--')
311+
interval_plot.axvline(vline, color='k', linestyle=':')
297312

298313
# Genenerate Gelman-Rubin plot
299314
if plot_rhat:

0 commit comments

Comments
 (0)