Skip to content

Commit 3d83c6d

Browse files
authored
Change sample() to use live_plot_kwargs instead of **kwargs. (#2235)
* Change sample() to use live_plot_kwargs instead of **kwargs. * Pass kwargs of sample to _sample. * Typo wargs -> kwargs. * Re-add kwargs to _sample().
1 parent 69e223b commit 3d83c6d

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

pymc3/sampling.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS,
103103
def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
104104
trace=None, chain=0, njobs=1, tune=500, nuts_kwargs=None,
105105
step_kwargs=None, progressbar=True, model=None, random_seed=-1,
106-
live_plot=False, discard_tuned_samples=True, **kwargs):
106+
live_plot=False, discard_tuned_samples=True, live_plot_kwargs=None,
107+
**kwargs):
107108
"""Draw samples from the posterior using the given step methods.
108109
109110
Multiple step methods are supported via compound step methods.
@@ -185,6 +186,8 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
185186
A list is accepted if more if `njobs` is greater than one.
186187
live_plot : bool
187188
Flag for live plotting the trace while sampling
189+
live_plot_kwargs : dict
190+
Options for traceplot. Example: live_plot_kwargs={'varnames': ['x']}
188191
discard_tuned_samples : bool
189192
Whether to discard posterior samples of the tune interval.
190193
@@ -254,6 +257,7 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
254257
'model': model,
255258
'random_seed': random_seed,
256259
'live_plot': live_plot,
260+
'live_plot_kwargs': live_plot_kwargs,
257261
}
258262

259263
sample_args.update(kwargs)
@@ -271,7 +275,7 @@ def sample(draws=500, step=None, init='auto', n_init=200000, start=None,
271275

272276
def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
273277
progressbar=True, model=None, random_seed=-1, live_plot=False,
274-
**kwargs):
278+
live_plot_kwargs=None, **kwargs):
275279
skip_first = kwargs.get('skip_first', 0)
276280
refresh_every = kwargs.get('refresh_every', 100)
277281

@@ -283,12 +287,14 @@ def _sample(draws, step=None, start=None, trace=None, chain=0, tune=None,
283287
strace = None
284288
for it, strace in enumerate(sampling):
285289
if live_plot:
290+
if live_plot_kwargs is None:
291+
live_plot_kwargs = {}
286292
if it >= skip_first:
287293
trace = MultiTrace([strace])
288294
if it == skip_first:
289-
ax = traceplot(trace, live_plot=False, **kwargs)
295+
ax = traceplot(trace, live_plot=False, **live_plot_kwargs)
290296
elif (it - skip_first) % refresh_every == 0 or it == draws - 1:
291-
traceplot(trace, ax=ax, live_plot=True, **kwargs)
297+
traceplot(trace, ax=ax, live_plot=True, **live_plot_kwargs)
292298
except KeyboardInterrupt:
293299
pass
294300
finally:

0 commit comments

Comments
 (0)