@@ -80,7 +80,7 @@ def _make_rhat_plot(trace, ax, title, labels, varnames, include_transformed):
80
80
return ax
81
81
82
82
83
- def _plot_tree (ax , y , ntiles , show_quartiles ):
83
+ def _plot_tree (ax , y , ntiles , show_quartiles , ** plot_kwargs ):
84
84
"""Helper to plot errorbars for the forestplot.
85
85
86
86
Parameters
@@ -101,22 +101,32 @@ def _plot_tree(ax, y, ntiles, show_quartiles):
101
101
"""
102
102
if show_quartiles :
103
103
# Plot median
104
- ax .plot (ntiles [2 ], y , 'bo' , markersize = 4 )
104
+ ax .plot (ntiles [2 ], y , color = plot_kwargs .get ('color' , 'blue' ),
105
+ marker = plot_kwargs .get ('marker' , 'o' ),
106
+ markersize = plot_kwargs .get ('markersize' , 4 ))
105
107
# Plot quartile interval
106
- ax .errorbar (x = (ntiles [1 ], ntiles [3 ]), y = (y , y ), linewidth = 2 , color = 'b' )
108
+ ax .errorbar (x = (ntiles [1 ], ntiles [3 ]), y = (y , y ),
109
+ linewidth = plot_kwargs .get ('linewidth' , 2 ),
110
+ color = plot_kwargs .get ('color' , 'blue' ))
107
111
108
112
else :
109
113
# Plot median
110
- ax .plot (ntiles [1 ], y , 'bo' , markersize = 4 )
114
+ ax .plot (ntiles [1 ], y , marker = plot_kwargs .get ('marker' , 'o' ),
115
+ color = plot_kwargs .get ('color' , 'blue' ),
116
+ markersize = plot_kwargs .get ('markersize' , 4 ))
111
117
112
118
# Plot outer interval
113
- ax .errorbar (x = (ntiles [0 ], ntiles [- 1 ]), y = (y , y ), linewidth = 1 , color = 'b' )
119
+ ax .errorbar (x = (ntiles [0 ], ntiles [- 1 ]), y = (y , y ),
120
+ linewidth = int (plot_kwargs .get ('linewidth' , 2 )/ 2 ),
121
+ color = plot_kwargs .get ('color' , 'blue' ))
122
+
114
123
return ax
115
124
116
125
117
126
def forestplot (trace_obj , varnames = None , transform = identity_transform , alpha = 0.05 , quartiles = True ,
118
127
rhat = True , main = None , xtitle = None , xlim = None , ylabels = None ,
119
- chain_spacing = 0.05 , vline = 0 , gs = None , plot_transformed = False ):
128
+ chain_spacing = 0.05 , vline = 0 , gs = None , plot_transformed = False ,
129
+ ** plot_kwargs ):
120
130
"""
121
131
Forest plot (model summary plot).
122
132
@@ -160,6 +170,9 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
160
170
plot_transformed : bool
161
171
Flag for plotting automatically transformed variables in addition to
162
172
original variables (defaults to False).
173
+ plot_kwargs : dict
174
+ Optional arguments for plot elements. Currently accepts 'fontsize',
175
+ 'linewidth', 'color', 'marker', and 'markersize'.
163
176
164
177
Returns
165
178
-------
@@ -242,10 +255,12 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
242
255
if k > 1 :
243
256
for q in np .transpose (quants ).squeeze ():
244
257
# Multiple y values
245
- interval_plot = _plot_tree (interval_plot , y , q , quartiles )
258
+ interval_plot = _plot_tree (interval_plot , y , q , quartiles ,
259
+ ** plot_kwargs )
246
260
y -= 1
247
261
else :
248
- interval_plot = _plot_tree (interval_plot , y , quants , quartiles )
262
+ interval_plot = _plot_tree (interval_plot , y , quants , quartiles ,
263
+ ** plot_kwargs )
249
264
250
265
# Increment index
251
266
var += k
@@ -264,7 +279,7 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
264
279
265
280
# Add variable labels
266
281
interval_plot .set_yticks ([- l for l in range (len (labels ))])
267
- interval_plot .set_yticklabels (labels )
282
+ interval_plot .set_yticklabels (labels , fontsize = plot_kwargs . get ( 'fontsize' , None ) )
268
283
269
284
# Add title
270
285
plot_title = ""
@@ -273,7 +288,7 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
273
288
elif main :
274
289
plot_title = main
275
290
if plot_title :
276
- interval_plot .set_title (plot_title )
291
+ interval_plot .set_title (plot_title , fontsize = plot_kwargs . get ( 'fontsize' , None ) )
277
292
278
293
# Add x-axis label
279
294
if xtitle is not None :
@@ -293,7 +308,7 @@ def forestplot(trace_obj, varnames=None, transform=identity_transform, alpha=0.0
293
308
spine .set_color ('none' ) # don't draw spine
294
309
295
310
# Reference line
296
- interval_plot .axvline (vline , color = 'k' , linestyle = '-- ' )
311
+ interval_plot .axvline (vline , color = 'k' , linestyle = ': ' )
297
312
298
313
# Genenerate Gelman-Rubin plot
299
314
if plot_rhat :
0 commit comments