1
1
from collections import namedtuple
2
2
import warnings
3
3
4
- from ..arraystep import Competence
4
+ from ..arraystep import Competence , SamplingError
5
5
from .base_hmc import BaseHMC
6
6
from pymc3 .theanof import floatX
7
7
from pymc3 .vartypes import continuous_types
8
8
9
9
import numpy as np
10
10
import numpy .random as nr
11
- from scipy import stats
11
+ from scipy import stats , linalg
12
+ import six
12
13
13
14
__all__ = ['NUTS' ]
14
15
@@ -87,7 +88,7 @@ class NUTS(BaseHMC):
87
88
88
89
def __init__ (self , vars = None , Emax = 1000 , target_accept = 0.8 ,
89
90
gamma = 0.05 , k = 0.75 , t0 = 10 , adapt_step_size = True ,
90
- max_treedepth = 10 , ** kwargs ):
91
+ max_treedepth = 10 , on_error = 'summary' , ** kwargs ):
91
92
R"""
92
93
Parameters
93
94
----------
@@ -124,6 +125,12 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
124
125
this will be interpreded as the mass or covariance matrix.
125
126
is_cov : bool, default=False
126
127
Treat the scaling as mass or covariance matrix.
128
+ on_error : {'summary', 'warn', 'raise'}, default='summary'
129
+ How to report problems during sampling.
130
+
131
+ * `summary`: Print one warning after sampling.
132
+ * `warn`: Print individual warnings as soon as they appear.
133
+ * `raise`: Raise an error on the first problem.
127
134
potential : Potential, optional
128
135
An object that represents the Hamiltonian with methods `velocity`,
129
136
`energy`, and `random` methods. It can be specified instead
@@ -156,11 +163,14 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
156
163
self .max_treedepth = max_treedepth
157
164
158
165
self .tune = True
166
+ self .report = NutsReport (on_error , max_treedepth , target_accept )
159
167
160
168
def astep (self , q0 ):
161
169
p0 = self .potential .random ()
162
170
v0 = self .compute_velocity (p0 )
163
171
start_energy = self .compute_energy (q0 , p0 )
172
+ if not np .isfinite (start_energy ):
173
+ raise ValueError ('The initial energy is inf or nan.' )
164
174
165
175
if not self .adapt_step_size :
166
176
step_size = self .step_size
@@ -170,14 +180,16 @@ def astep(self, q0):
170
180
step_size = np .exp (self .log_step_size_bar )
171
181
172
182
start = Edge (q0 , p0 , v0 , self .dlogp (q0 ), start_energy )
173
- tree = Tree (len (p0 ), self .leapfrog , start , step_size , self .Emax )
183
+ tree = _Tree (len (p0 ), self .leapfrog , start , step_size , self .Emax )
174
184
175
185
for _ in range (self .max_treedepth ):
176
186
direction = logbern (np .log (0.5 )) * 2 - 1
177
187
diverging , turning = tree .extend (direction )
178
188
q = tree .proposal .q
179
189
180
190
if diverging or turning :
191
+ if diverging :
192
+ self .report ._add_divergence (self .tune , * diverging )
181
193
break
182
194
183
195
w = 1. / (self .m + self .t0 )
@@ -208,64 +220,6 @@ def competence(var):
208
220
return Competence .IDEAL
209
221
return Competence .INCOMPATIBLE
210
222
211
- def check_trace (self , strace ):
212
- """Print warnings for obviously problematic chains."""
213
- n = len (strace )
214
- chain = strace .chain
215
-
216
- diverging = strace .get_sampler_stats ('diverging' )
217
- if diverging .ndim == 2 :
218
- diverging = np .any (diverging , axis = - 1 )
219
-
220
- tuning = strace .get_sampler_stats ('tune' )
221
- if tuning .ndim == 2 :
222
- tuning = np .any (tuning , axis = - 1 )
223
-
224
- accept = strace .get_sampler_stats ('mean_tree_accept' )
225
- if accept .ndim == 2 :
226
- accept = np .mean (accept , axis = - 1 )
227
-
228
- depth = strace .get_sampler_stats ('depth' )
229
- if depth .ndim == 2 :
230
- depth = np .max (depth , axis = - 1 )
231
-
232
- n_samples = n - (~ tuning ).sum ()
233
-
234
- if n < 1000 :
235
- warnings .warn ('Chain %s contains only %s samples.' % (chain , n ))
236
- if np .all (tuning ):
237
- warnings .warn ('Step size tuning was enabled throughout the whole '
238
- 'trace. You might want to specify the number of '
239
- 'tuning steps.' )
240
- if np .all (diverging ):
241
- warnings .warn ('Chain %s contains only diverging samples. '
242
- 'The model is probably misspecified.' % chain )
243
- return
244
- if np .any (diverging [~ tuning ]):
245
- warnings .warn ("Chain %s contains diverging samples after tuning. "
246
- "If increasing `target_accept` doesn't help, "
247
- "try to reparameterize." % chain )
248
- if n_samples > 0 :
249
- depth_samples = depth [~ tuning ]
250
- else :
251
- depth_samples = depth [n // 2 :]
252
- if np .mean (depth_samples == self .max_treedepth ) > 0.05 :
253
- warnings .warn ('Chain %s reached the maximum tree depth. Increase '
254
- 'max_treedepth, increase target_accept or '
255
- 'reparameterize.' % chain )
256
-
257
- mean_accept = np .mean (accept [~ tuning ])
258
- target_accept = self .target_accept
259
- # Try to find a reasonable interval for acceptable acceptance
260
- # probabilities. Finding this was mostry trial and error.
261
- n_bound = min (100 , n )
262
- n_good , n_bad = mean_accept * n_bound , (1 - mean_accept ) * n_bound
263
- lower , upper = stats .beta (n_good + 1 , n_bad + 1 ).interval (0.95 )
264
- if target_accept < lower or target_accept > upper :
265
- warnings .warn ('The acceptance probability in chain %s does not '
266
- 'match the target. It is %s, but should be close '
267
- 'to %s. Try to increase the number of tuning steps.'
268
- % (chain , mean_accept , target_accept ))
269
223
270
224
# A node in the NUTS tree that is at the far right or left of the tree
271
225
Edge = namedtuple ("Edge" , 'q, p, v, q_grad, energy' )
@@ -279,7 +233,7 @@ def check_trace(self, strace):
279
233
"left, right, p_sum, proposal, log_size, accept_sum, n_proposals" )
280
234
281
235
282
- class Tree (object ):
236
+ class _Tree (object ):
283
237
def __init__ (self , ndim , leapfrog , start , step_size , Emax ):
284
238
"""Binary tree from the NUTS algorithm.
285
239
@@ -352,24 +306,41 @@ def extend(self, direction):
352
306
353
307
return diverging , turning
354
308
355
- def _build_subtree (self , left , depth , epsilon ):
356
- if depth == 0 :
309
+ def _single_step (self , left , epsilon ):
310
+ """Perform a leapfrog step and handle error cases."""
311
+ try :
357
312
right = self .leapfrog (left .q , left .p , left .q_grad , epsilon )
313
+ except linalg .LinalgError as error :
314
+ error_msg = "LinAlgError during leapfrog step."
315
+ except ValueError as error :
316
+ # Raised by many scipy.linalg functions
317
+ if error .args [0 ].lower () == 'array must not contain infs or nans' :
318
+ error_msg = "Infs or nans in scipy.linalg during leapfrog step."
319
+ else :
320
+ raise
321
+ else :
358
322
right = Edge (* right )
359
323
energy_change = right .energy - self .start_energy
360
324
if np .isnan (energy_change ):
361
325
energy_change = np .inf
362
326
363
327
if np .abs (energy_change ) > np .abs (self .max_energy_change ):
364
328
self .max_energy_change = energy_change
365
- p_accept = min (1 , np .exp (- energy_change ))
366
-
367
- log_size = - energy_change
368
- diverging = energy_change > self .Emax
329
+ if np .abs (energy_change ) < self .Emax :
330
+ p_accept = min (1 , np .exp (- energy_change ))
331
+ log_size = - energy_change
332
+ proposal = Proposal (right .q , right .energy , p_accept )
333
+ tree = Subtree (right , right , right .p , proposal , log_size , p_accept , 1 )
334
+ return tree , False , False
335
+ else :
336
+ error_msg = "Bad energy after leapfrog step."
337
+ error = None
338
+ tree = Subtree (None , None , None , None , - np .inf , 0 , 1 )
339
+ return tree , (error_msg , error ), False
369
340
370
- proposal = Proposal ( right . q , right . energy , p_accept )
371
- tree = Subtree ( right , right , right . p , proposal , log_size , p_accept , 1 )
372
- return tree , diverging , False
341
+ def _build_subtree ( self , left , depth , epsilon ):
342
+ if depth == 0 :
343
+ return self . _single_step ( left , epsilon )
373
344
374
345
tree1 , diverging , turning = self ._build_subtree (left , depth - 1 , epsilon )
375
346
if diverging or turning :
@@ -408,3 +379,91 @@ def stats(self):
408
379
'tree_size' : self .n_proposals ,
409
380
'max_energy_error' : self .max_energy_change ,
410
381
}
382
+
383
+
384
+ class NutsReport (object ):
385
+ def __init__ (self , on_error , max_treedepth , target_accept ):
386
+ if on_error not in ['summary' , 'raise' , 'warn' ]:
387
+ raise ValueError ('Invalid value for on_error.' )
388
+ self ._on_error = on_error
389
+ self ._max_treedepth = max_treedepth
390
+ self ._target_accept = target_accept
391
+ self ._chain_id = None
392
+ self ._divs_tune = []
393
+ self ._divs_after_tune = []
394
+
395
+ def _add_divergence (self , tuning , msg , error = None ):
396
+ if tuning :
397
+ self ._divs_tune .append ((msg , error ))
398
+ else :
399
+ self ._divs_after_tune ((msg , error ))
400
+ if self ._on_error == 'raise' :
401
+ err = SamplingError ('Divergence after tuning: ' + msg )
402
+ six .raise_from (err , error )
403
+ elif self ._on_error == 'warn' :
404
+ warnings .warn ('Divergence detected: ' + msg )
405
+
406
+ def _check_len (self , tuning ):
407
+ n = (~ tuning ).sum ()
408
+ if n < 1000 :
409
+ warnings .warn ('Chain %s contains only %s samples.'
410
+ % (self ._chain_id , n ))
411
+ if np .all (tuning ):
412
+ warnings .warn ('Step size tuning was enabled throughout the whole '
413
+ 'trace. You might want to specify the number of '
414
+ 'tuning steps.' )
415
+ if n == len (self ._divs_after_tune ):
416
+ warnings .warn ('Chain %s contains only diverging samples. '
417
+ 'The model is probably misspecified.'
418
+ % self ._chain_id )
419
+
420
+ def _check_accept (self , accept ):
421
+ mean_accept = np .mean (accept )
422
+ target_accept = self ._target_accept
423
+ # Try to find a reasonable interval for acceptable acceptance
424
+ # probabilities. Finding this was mostry trial and error.
425
+ n_bound = min (100 , len (accept ))
426
+ n_good , n_bad = mean_accept * n_bound , (1 - mean_accept ) * n_bound
427
+ lower , upper = stats .beta (n_good + 1 , n_bad + 1 ).interval (0.95 )
428
+ if target_accept < lower or target_accept > upper :
429
+ warnings .warn ('The acceptance probability in chain %s does not '
430
+ 'match the target. It is %s, but should be close '
431
+ 'to %s. Try to increase the number of tuning steps.'
432
+ % (self ._chain_id , mean_accept , target_accept ))
433
+
434
+ def _check_depth (self , depth ):
435
+ if len (depth ) == 0 :
436
+ return
437
+ if np .mean (depth == self ._max_treedepth ) > 0.05 :
438
+ warnings .warn ('Chain %s reached the maximum tree depth. Increase '
439
+ 'max_treedepth, increase target_accept or '
440
+ 'reparameterize.' % self ._chain_id )
441
+
442
+ def _check_divergence (self ):
443
+ n_diverging = len (self ._divs_after_tune )
444
+ if n_diverging > 0 :
445
+ warnings .warn ("Chain %s contains %s diverging samples after "
446
+ "tuning. If increasing `target_accept` doesn't help "
447
+ "try to reparameterize."
448
+ % (self ._chain_id , n_diverging ))
449
+
450
+ def _finalize (self , strace ):
451
+ """Print warnings for obviously problematic chains."""
452
+ self ._chain_id = strace .chain
453
+
454
+ tuning = strace .get_sampler_stats ('tune' )
455
+ if tuning .ndim == 2 :
456
+ tuning = np .any (tuning , axis = - 1 )
457
+
458
+ accept = strace .get_sampler_stats ('mean_tree_accept' )
459
+ if accept .ndim == 2 :
460
+ accept = np .mean (accept , axis = - 1 )
461
+
462
+ depth = strace .get_sampler_stats ('depth' )
463
+ if depth .ndim == 2 :
464
+ depth = np .max (depth , axis = - 1 )
465
+
466
+ self ._check_len (tuning )
467
+ self ._check_depth (depth [~ tuning ])
468
+ self ._check_accept (accept [~ tuning ])
469
+ self ._check_divergence ()
0 commit comments