@@ -103,7 +103,8 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS,
103
103
def sample (draws = 500 , step = None , init = 'auto' , n_init = 200000 , start = None ,
104
104
trace = None , chain = 0 , njobs = 1 , tune = 500 , nuts_kwargs = None ,
105
105
step_kwargs = None , progressbar = True , model = None , random_seed = - 1 ,
106
- live_plot = False , discard_tuned_samples = True , ** kwargs ):
106
+ live_plot = False , discard_tuned_samples = True , live_plot_kwargs = None ,
107
+ ** kwargs ):
107
108
"""Draw samples from the posterior using the given step methods.
108
109
109
110
Multiple step methods are supported via compound step methods.
@@ -185,6 +186,8 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
185
186
A list is accepted if more if `njobs` is greater than one.
186
187
live_plot : bool
187
188
Flag for live plotting the trace while sampling
189
+ live_plot_kwargs : dict
190
+ Options for traceplot. Example: live_plot_kwargs={'varnames': ['x']}
188
191
discard_tuned_samples : bool
189
192
Whether to discard posterior samples of the tune interval.
190
193
@@ -254,6 +257,7 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
254
257
'model' : model ,
255
258
'random_seed' : random_seed ,
256
259
'live_plot' : live_plot ,
260
+ 'live_plot_kwargs' : live_plot_kwargs ,
257
261
}
258
262
259
263
sample_args .update (kwargs )
@@ -271,7 +275,7 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
271
275
272
276
def _sample (draws , step = None , start = None , trace = None , chain = 0 , tune = None ,
273
277
progressbar = True , model = None , random_seed = - 1 , live_plot = False ,
274
- ** kwargs ):
278
+ live_plot_kwargs = None , ** kwargs ):
275
279
skip_first = kwargs .get ('skip_first' , 0 )
276
280
refresh_every = kwargs .get ('refresh_every' , 100 )
277
281
@@ -283,12 +287,14 @@ def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
283
287
strace = None
284
288
for it , strace in enumerate (sampling ):
285
289
if live_plot :
290
+ if live_plot_kwargs is None :
291
+ live_plot_kwargs = {}
286
292
if it >= skip_first :
287
293
trace = MultiTrace ([strace ])
288
294
if it == skip_first :
289
- ax = traceplot (trace , live_plot = False , ** kwargs )
295
+ ax = traceplot (trace , live_plot = False , ** live_plot_kwargs )
290
296
elif (it - skip_first ) % refresh_every == 0 or it == draws - 1 :
291
- traceplot (trace , ax = ax , live_plot = True , ** kwargs )
297
+ traceplot (trace , ax = ax , live_plot = True , ** live_plot_kwargs )
292
298
except KeyboardInterrupt :
293
299
pass
294
300
finally :
0 commit comments