@@ -128,6 +128,87 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
128
128
)
129
129
130
130
131
+ class _ChainProgress :
132
+ bar : Any
133
+ total : int
134
+ finished : bool
135
+ tuning : bool
136
+ draws : int
137
+ started : bool
138
+ chain_id : int
139
+ num_divs : int
140
+ step_size : float
141
+ num_steps : int
142
+
143
+ def __init__ (self , total , chain_id ):
144
+ self .bar = fastprogress .progress_bar (range (total ))
145
+ self .total = total
146
+ self .finished = False
147
+ self .tuning = True
148
+ self .draws = 0
149
+ self .started = False
150
+ self .chain_id = chain_id
151
+ self .num_divs = 0
152
+ self .step_size = 0.
153
+ self .num_steps = 0
154
+ self .bar .update (0 )
155
+
156
+ def callback (self , info ):
157
+ try :
158
+ if self .finished :
159
+ return
160
+
161
+ if info .finished_draws == info .total_draws :
162
+ self .finished = True
163
+
164
+ if not info .tuning :
165
+ self .tuning = False
166
+
167
+ self .draws = info .finished_draws
168
+ if info .started :
169
+ self .started = True
170
+
171
+ self .num_divs = info .divergences
172
+ self .step_size = info .step_size
173
+ self .num_steps = info .num_steps
174
+
175
+ if self .tuning :
176
+ state = "warmup"
177
+ else :
178
+ state = "sampling"
179
+
180
+ self .bar .comment = (
181
+ f"Chain { self .chain_id :2} { state } : "
182
+ f"trajectory { self .num_steps :3} / "
183
+ f"diverging { self .num_divs : 2} / "
184
+ f"step { self .step_size :.2g} "
185
+ )
186
+ self .bar .update (self .draws )
187
+ except Exception as e :
188
+ print (e )
189
+
190
+
191
+ class _DetailedProgress :
192
+ chains : list [_ChainProgress ]
193
+
194
+ def __init__ (self , total_draws , num_chains ):
195
+ self .chains = [_ChainProgress (total_draws , i ) for i in range (num_chains )]
196
+
197
+ def callback (self , info ):
198
+ for chain , chain_info in zip (self .chains , info ):
199
+ chain .callback (chain_info )
200
+
201
+
202
+ class _SummaryProgress :
203
+ bar : Any
204
+
205
+ def __init__ (self , total_draws , num_chains ):
206
+ pass
207
+
208
+ def callback (self , info ):
209
+ return None
210
+
211
+
131
212
class _BackgroundSampler :
132
213
_sampler : Any
133
214
_num_divs : int
@@ -144,81 +225,30 @@ def __init__(
144
225
compiled_model ,
145
226
settings ,
146
227
init_mean ,
147
- chains ,
148
228
cores ,
149
- seed ,
150
- draws ,
151
- tune ,
152
229
* ,
153
230
progress_bar = True ,
154
231
save_warmup = True ,
155
232
return_raw_trace = False ,
156
233
):
157
- self ._num_divs = 0
158
- self ._tune = settings .num_tune
159
- self ._draws = settings .num_draws
160
234
self ._settings = settings
161
- self ._chains_tuning = chains
162
- self ._chains_finished = 0
163
- self ._chains = chains
164
235
self ._compiled_model = compiled_model
165
236
self ._save_warmup = save_warmup
166
237
self ._return_raw_trace = return_raw_trace
167
- total_draws = (self ._draws + self ._tune ) * self ._chains
168
- self ._progress = fastprogress .progress_bar (
169
- range (total_draws ),
170
- total = total_draws ,
171
- display = progress_bar ,
172
- )
173
- # fastprogress seems to reset the progress bar
174
- # if we create a new iterator, but we don't want
175
- # this for multiple calls to wait.
176
- self ._bar = iter (self ._progress )
177
-
178
- self ._exit_event = Event ()
179
- self ._pause_event = Event ()
180
- self ._continue = Condition ()
181
-
182
- self ._finished_draws = 0
183
-
184
- next (self ._bar )
185
-
186
- def progress_callback (info ):
187
- if info .draw == self ._tune - 1 :
188
- self ._chains_tuning -= 1
189
- if info .draw == self ._tune + self ._draws - 1 :
190
- self ._chains_finished += 1
191
- if info .is_diverging and info .draw > self ._tune :
192
- self ._num_divs += 1
193
- if self ._chains_tuning > 0 :
194
- count = self ._chains_tuning
195
- divs = self ._num_divs
196
- self ._progress .comment = (
197
- f" Chains in warmup: { count } , Divergences: { divs } "
198
- )
199
- else :
200
- count = self ._chains - self ._chains_finished
201
- divs = self ._num_divs
202
- self ._progress .comment = (
203
- f" Sampling chains: { count } , Divergences: { divs } "
204
- )
205
- try :
206
- next (self ._bar )
207
- except StopIteration :
208
- pass
209
- self ._finished_draws += 1
238
+
239
+ total_draws = settings .num_draws + settings .num_tune
210
240
211
241
if progress_bar :
212
- callback = progress_callback
242
+ self ._progress = _DetailedProgress (total_draws , settings .num_chains )
243
+ callback = self ._progress .callback
213
244
else :
245
+ self ._progress = None
214
246
callback = None
215
247
216
248
self ._sampler = compiled_model ._make_sampler (
217
249
settings ,
218
250
init_mean ,
219
- chains ,
220
251
cores ,
221
- seed ,
222
252
callback = callback ,
223
253
)
224
254
@@ -233,12 +263,10 @@ def wait(self, *, timeout=None):
233
263
This resumes the sampler in case it had been paused.
234
264
"""
235
265
self ._sampler .wait (timeout )
236
- self ._sampler .finalize ()
237
- return self ._extract ()
238
-
239
- def _extract (self ):
240
266
results = self ._sampler .extract_results ()
267
+ return self ._extract (results )
241
268
269
+ def _extract (self , results ):
242
270
dims = {name : list (dim ) for name , dim in self ._compiled_model .dims .items ()}
243
271
dims ["mass_matrix_inv" ] = ["unconstrained_parameter" ]
244
272
dims ["gradient" ] = ["unconstrained_parameter" ]
@@ -253,7 +281,7 @@ def _extract(self):
253
281
else :
254
282
return _trace_to_arviz (
255
283
results ,
256
- self ._tune ,
284
+ self ._settings . num_tune ,
257
285
self ._compiled_model .shapes ,
258
286
dims = dims ,
259
287
coords = {
@@ -263,6 +291,11 @@ def _extract(self):
263
291
save_warmup = self ._save_warmup ,
264
292
)
265
293
294
+ def inspect (self ):
295
+ """Get a copy of the current state of the trace"""
296
+ results = self ._sampler .inspect ()
297
+ return self ._extract (results )
298
+
266
299
def pause (self ):
267
300
"""Pause the sampler."""
268
301
self ._sampler .pause ()
@@ -278,7 +311,8 @@ def is_finished(self):
278
311
def abort (self ):
279
312
"""Abort sampling and return the trace produced so far."""
280
313
self ._sampler .abort ()
281
- return self ._extract ()
314
+ results = self ._sampler .extract_results ()
315
+ return self ._extract (results )
282
316
283
317
def cancel (self ):
284
318
"""Abort sampling and discard progress."""
@@ -402,9 +436,10 @@ def sample(
402
436
trace : arviz.InferenceData
403
437
An ArviZ ``InferenceData`` object that contains the samples.
404
438
"""
405
- settings = _lib .PySamplerArgs ( )
439
+ settings = _lib .PyDiagGradNutsSettings ( seed )
406
440
settings .num_tune = tune
407
441
settings .num_draws = draws
442
+ settings .num_chains = chains
408
443
409
444
for name , val in kwargs .items ():
410
445
setattr (settings , name , val )
@@ -424,11 +459,7 @@ def sample(
424
459
compiled_model ,
425
460
settings ,
426
461
init_mean ,
427
- chains ,
428
462
cores ,
429
- seed ,
430
- draws ,
431
- tune ,
432
463
progress_bar = progress_bar ,
433
464
save_warmup = save_warmup ,
434
465
return_raw_trace = return_raw_trace ,
0 commit comments