2
2
import warnings
3
3
4
4
from ..arraystep import Competence
5
+ from pymc3 .exceptions import SamplingError
5
6
from .base_hmc import BaseHMC
6
7
from pymc3 .theanof import floatX
7
8
from pymc3 .vartypes import continuous_types
8
9
9
10
import numpy as np
10
11
import numpy .random as nr
11
- from scipy import stats
12
+ from scipy import stats , linalg
13
+ import six
12
14
13
15
__all__ = ['NUTS' ]
14
16
@@ -87,7 +89,7 @@ class NUTS(BaseHMC):
87
89
88
90
def __init__ (self , vars = None , Emax = 1000 , target_accept = 0.8 ,
89
91
gamma = 0.05 , k = 0.75 , t0 = 10 , adapt_step_size = True ,
90
- max_treedepth = 10 , ** kwargs ):
92
+ max_treedepth = 10 , on_error = 'summary' , ** kwargs ):
91
93
R"""
92
94
Parameters
93
95
----------
@@ -124,6 +126,12 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
124
126
this will be interpreded as the mass or covariance matrix.
125
127
is_cov : bool, default=False
126
128
Treat the scaling as mass or covariance matrix.
129
+ on_error : {'summary', 'warn', 'raise'}, default='summary'
130
+ How to report problems during sampling.
131
+
132
+ * `summary`: Print one warning after sampling.
133
+ * `warn`: Print individual warnings as soon as they appear.
134
+ * `raise`: Raise an error on the first problem.
127
135
potential : Potential, optional
128
136
An object that represents the Hamiltonian with methods `velocity`,
129
137
`energy`, and `random` methods. It can be specified instead
@@ -156,11 +164,15 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
156
164
self .max_treedepth = max_treedepth
157
165
158
166
self .tune = True
167
+ self .report = NutsReport (on_error , max_treedepth , target_accept )
159
168
160
169
def astep (self , q0 ):
161
170
p0 = self .potential .random ()
162
171
v0 = self .compute_velocity (p0 )
163
172
start_energy = self .compute_energy (q0 , p0 )
173
+ if not np .isfinite (start_energy ):
174
+ raise ValueError ('Bad initial energy: %s. The model '
175
+ 'might be misspecified.' % start_energy )
164
176
165
177
if not self .adapt_step_size :
166
178
step_size = self .step_size
@@ -170,14 +182,16 @@ def astep(self, q0):
170
182
step_size = np .exp (self .log_step_size_bar )
171
183
172
184
start = Edge (q0 , p0 , v0 , self .dlogp (q0 ), start_energy )
173
- tree = Tree (len (p0 ), self .leapfrog , start , step_size , self .Emax )
185
+ tree = _Tree (len (p0 ), self .leapfrog , start , step_size , self .Emax )
174
186
175
187
for _ in range (self .max_treedepth ):
176
188
direction = logbern (np .log (0.5 )) * 2 - 1
177
189
diverging , turning = tree .extend (direction )
178
190
q = tree .proposal .q
179
191
180
192
if diverging or turning :
193
+ if diverging :
194
+ self .report ._add_divergence (self .tune , * diverging )
181
195
break
182
196
183
197
w = 1. / (self .m + self .t0 )
@@ -208,64 +222,6 @@ def competence(var):
208
222
return Competence .IDEAL
209
223
return Competence .INCOMPATIBLE
210
224
211
- def check_trace (self , strace ):
212
- """Print warnings for obviously problematic chains."""
213
- n = len (strace )
214
- chain = strace .chain
215
-
216
- diverging = strace .get_sampler_stats ('diverging' )
217
- if diverging .ndim == 2 :
218
- diverging = np .any (diverging , axis = - 1 )
219
-
220
- tuning = strace .get_sampler_stats ('tune' )
221
- if tuning .ndim == 2 :
222
- tuning = np .any (tuning , axis = - 1 )
223
-
224
- accept = strace .get_sampler_stats ('mean_tree_accept' )
225
- if accept .ndim == 2 :
226
- accept = np .mean (accept , axis = - 1 )
227
-
228
- depth = strace .get_sampler_stats ('depth' )
229
- if depth .ndim == 2 :
230
- depth = np .max (depth , axis = - 1 )
231
-
232
- n_samples = n - (~ tuning ).sum ()
233
-
234
- if n < 1000 :
235
- warnings .warn ('Chain %s contains only %s samples.' % (chain , n ))
236
- if np .all (tuning ):
237
- warnings .warn ('Step size tuning was enabled throughout the whole '
238
- 'trace. You might want to specify the number of '
239
- 'tuning steps.' )
240
- if np .all (diverging ):
241
- warnings .warn ('Chain %s contains only diverging samples. '
242
- 'The model is probably misspecified.' % chain )
243
- return
244
- if np .any (diverging [~ tuning ]):
245
- warnings .warn ("Chain %s contains diverging samples after tuning. "
246
- "If increasing `target_accept` doesn't help, "
247
- "try to reparameterize." % chain )
248
- if n_samples > 0 :
249
- depth_samples = depth [~ tuning ]
250
- else :
251
- depth_samples = depth [n // 2 :]
252
- if np .mean (depth_samples == self .max_treedepth ) > 0.05 :
253
- warnings .warn ('Chain %s reached the maximum tree depth. Increase '
254
- 'max_treedepth, increase target_accept or '
255
- 'reparameterize.' % chain )
256
-
257
- mean_accept = np .mean (accept [~ tuning ])
258
- target_accept = self .target_accept
259
- # Try to find a reasonable interval for acceptable acceptance
260
- # probabilities. Finding this was mostry trial and error.
261
- n_bound = min (100 , n )
262
- n_good , n_bad = mean_accept * n_bound , (1 - mean_accept ) * n_bound
263
- lower , upper = stats .beta (n_good + 1 , n_bad + 1 ).interval (0.95 )
264
- if target_accept < lower or target_accept > upper :
265
- warnings .warn ('The acceptance probability in chain %s does not '
266
- 'match the target. It is %s, but should be close '
267
- 'to %s. Try to increase the number of tuning steps.'
268
- % (chain , mean_accept , target_accept ))
269
225
270
226
# A node in the NUTS tree that is at the far right or left of the tree
271
227
Edge = namedtuple ("Edge" , 'q, p, v, q_grad, energy' )
@@ -279,7 +235,7 @@ def check_trace(self, strace):
279
235
"left, right, p_sum, proposal, log_size, accept_sum, n_proposals" )
280
236
281
237
282
- class Tree (object ):
238
+ class _Tree (object ):
283
239
def __init__ (self , ndim , leapfrog , start , step_size , Emax ):
284
240
"""Binary tree from the NUTS algorithm.
285
241
@@ -352,24 +308,45 @@ def extend(self, direction):
352
308
353
309
return diverging , turning
354
310
355
- def _build_subtree (self , left , depth , epsilon ):
356
- if depth == 0 :
311
+ def _single_step (self , left , epsilon ):
312
+ """Perform a leapfrog step and handle error cases."""
313
+ try :
357
314
right = self .leapfrog (left .q , left .p , left .q_grad , epsilon )
315
+ except linalg .LinAlgError as err :
316
+ error_msg = "LinAlgError during leapfrog step."
317
+ error = err
318
+ except ValueError as err :
319
+ # Raised by many scipy.linalg functions
320
+ scipy_msg = "array must not contain infs or nans"
321
+ if len (err .args ) > 0 and scipy_msg in err .args [0 ].lower ():
322
+ error_msg = "Infs or nans in scipy.linalg during leapfrog step."
323
+ error = err
324
+ else :
325
+ raise
326
+ else :
358
327
right = Edge (* right )
359
328
energy_change = right .energy - self .start_energy
360
329
if np .isnan (energy_change ):
361
330
energy_change = np .inf
362
331
363
332
if np .abs (energy_change ) > np .abs (self .max_energy_change ):
364
333
self .max_energy_change = energy_change
365
- p_accept = min (1 , np .exp (- energy_change ))
366
-
367
- log_size = - energy_change
368
- diverging = energy_change > self .Emax
334
+ if np .abs (energy_change ) < self .Emax :
335
+ p_accept = min (1 , np .exp (- energy_change ))
336
+ log_size = - energy_change
337
+ proposal = Proposal (right .q , right .energy , p_accept )
338
+ tree = Subtree (right , right , right .p , proposal , log_size , p_accept , 1 )
339
+ return tree , False , False
340
+ else :
341
+ error_msg = ("Energy change in leapfrog step is too large: %s. "
342
+ % energy_change )
343
+ error = None
344
+ tree = Subtree (None , None , None , None , - np .inf , 0 , 1 )
345
+ return tree , (error_msg , error , left ), False
369
346
370
- proposal = Proposal ( right . q , right . energy , p_accept )
371
- tree = Subtree ( right , right , right . p , proposal , log_size , p_accept , 1 )
372
- return tree , diverging , False
347
+ def _build_subtree ( self , left , depth , epsilon ):
348
+ if depth == 0 :
349
+ return self . _single_step ( left , epsilon )
373
350
374
351
tree1 , diverging , turning = self ._build_subtree (left , depth - 1 , epsilon )
375
352
if diverging or turning :
@@ -408,3 +385,93 @@ def stats(self):
408
385
'tree_size' : self .n_proposals ,
409
386
'max_energy_error' : self .max_energy_change ,
410
387
}
388
+
389
+
390
+ class NutsReport (object ):
391
+ def __init__ (self , on_error , max_treedepth , target_accept ):
392
+ if on_error not in ['summary' , 'raise' , 'warn' ]:
393
+ raise ValueError ('Invalid value for on_error.' )
394
+ self ._on_error = on_error
395
+ self ._max_treedepth = max_treedepth
396
+ self ._target_accept = target_accept
397
+ self ._chain_id = None
398
+ self ._divs_tune = []
399
+ self ._divs_after_tune = []
400
+
401
+ def _add_divergence (self , tuning , msg , error = None , point = None ):
402
+ if tuning :
403
+ self ._divs_tune .append ((msg , error , point ))
404
+ else :
405
+ self ._divs_after_tune .append ((msg , error , point ))
406
+ if self ._on_error == 'raise' :
407
+ err = SamplingError ('Divergence after tuning: %s Increase '
408
+ 'target_accept or reparameterize.' % msg )
409
+ six .raise_from (err , error )
410
+ elif self ._on_error == 'warn' :
411
+ warnings .warn ('Divergence detected: %s Increase target_accept '
412
+ 'or reparameterize.' % msg )
413
+
414
+ def _check_len (self , tuning ):
415
+ n = (~ tuning ).sum ()
416
+ if n < 500 :
417
+ warnings .warn ('Chain %s contains only %s samples.'
418
+ % (self ._chain_id , n ))
419
+ if np .all (tuning ):
420
+ warnings .warn ('Step size tuning was enabled throughout the whole '
421
+ 'trace. You might want to specify the number of '
422
+ 'tuning steps.' )
423
+ if n == len (self ._divs_after_tune ):
424
+ warnings .warn ('Chain %s contains only diverging samples. '
425
+ 'The model is probably misspecified.'
426
+ % self ._chain_id )
427
+
428
+ def _check_accept (self , accept ):
429
+ mean_accept = np .mean (accept )
430
+ target_accept = self ._target_accept
431
+ # Try to find a reasonable interval for acceptable acceptance
432
+ # probabilities. Finding this was mostry trial and error.
433
+ n_bound = min (100 , len (accept ))
434
+ n_good , n_bad = mean_accept * n_bound , (1 - mean_accept ) * n_bound
435
+ lower , upper = stats .beta (n_good + 1 , n_bad + 1 ).interval (0.95 )
436
+ if target_accept < lower or target_accept > upper :
437
+ warnings .warn ('The acceptance probability in chain %s does not '
438
+ 'match the target. It is %s, but should be close '
439
+ 'to %s. Try to increase the number of tuning steps.'
440
+ % (self ._chain_id , mean_accept , target_accept ))
441
+
442
+ def _check_depth (self , depth ):
443
+ if len (depth ) == 0 :
444
+ return
445
+ if np .mean (depth == self ._max_treedepth ) > 0.05 :
446
+ warnings .warn ('Chain %s reached the maximum tree depth. Increase '
447
+ 'max_treedepth, increase target_accept or '
448
+ 'reparameterize.' % self ._chain_id )
449
+
450
+ def _check_divergence (self ):
451
+ n_diverging = len (self ._divs_after_tune )
452
+ if n_diverging > 0 :
453
+ warnings .warn ("Chain %s contains %s diverging samples after "
454
+ "tuning. If increasing `target_accept` doesn't help "
455
+ "try to reparameterize."
456
+ % (self ._chain_id , n_diverging ))
457
+
458
+ def _finalize (self , strace ):
459
+ """Print warnings for obviously problematic chains."""
460
+ self ._chain_id = strace .chain
461
+
462
+ tuning = strace .get_sampler_stats ('tune' )
463
+ if tuning .ndim == 2 :
464
+ tuning = np .any (tuning , axis = - 1 )
465
+
466
+ accept = strace .get_sampler_stats ('mean_tree_accept' )
467
+ if accept .ndim == 2 :
468
+ accept = np .mean (accept , axis = - 1 )
469
+
470
+ depth = strace .get_sampler_stats ('depth' )
471
+ if depth .ndim == 2 :
472
+ depth = np .max (depth , axis = - 1 )
473
+
474
+ self ._check_len (tuning )
475
+ self ._check_depth (depth [~ tuning ])
476
+ self ._check_accept (accept [~ tuning ])
477
+ self ._check_divergence ()
0 commit comments