Skip to content

Commit c306452

Browse files
aseyboldtColCarroll
authored and committed
Small NUTS refactor to improve error handling (#2215)
* Small NUTS refactoring * Save position of divergence in nuts * Implement review suggestions * Add tests for nuts reports * Improve nuts report error messages
1 parent c546ad2 commit c306452

File tree

7 files changed

+193
-81
lines changed

7 files changed

+193
-81
lines changed

pymc3/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .tuning import *
1616
from .variational import *
1717
from .vartypes import *
18+
from .exceptions import *
1819
from . import sampling
1920

2021
from .debug import *

pymc3/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
__all__ = ['SamplingError']
2+
3+
4+
class SamplingError(RuntimeError):
    """Error raised when a problem during sampling should abort the run.

    Subclasses `RuntimeError`, so callers that already catch
    `RuntimeError` also catch this. Within this change it is raised by
    the NUTS report when a divergence occurs and `on_error='raise'`.
    """

pymc3/sampling.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,16 +388,16 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
388388
yield strace
389389
except KeyboardInterrupt:
390390
strace.close()
391-
if hasattr(step, 'check_trace'):
392-
step.check_trace(strace)
391+
if hasattr(step, 'report'):
392+
step.report._finalize(strace)
393393
raise
394394
except BaseException:
395395
strace.close()
396396
raise
397397
else:
398398
strace.close()
399-
if hasattr(step, 'check_trace'):
400-
step.check_trace(strace)
399+
if hasattr(step, 'report'):
400+
step.report._finalize(strace)
401401

402402

403403
def _choose_backend(trace, chain, shortcuts=None, **kwds):

pymc3/step_methods/arraystep.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
from numpy.random import uniform
77
from enum import IntEnum, unique
88

9-
__all__ = ['ArrayStep', 'ArrayStepShared', 'metrop_select', 'Competence']
9+
__all__ = [
10+
'ArrayStep', 'ArrayStepShared', 'metrop_select', 'Competence']
1011

1112

1213
@unique

pymc3/step_methods/compound.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,19 @@ def step(self, point):
3131
point = method.step(point)
3232
return point
3333

34-
def check_trace(self, trace):
34+
@property
35+
def report(self):
36+
reports = []
3537
for method in self.methods:
36-
if hasattr(method, 'check_trace'):
37-
method.check_trace(trace)
38+
if hasattr(method, 'report'):
39+
reports.append(method.report)
40+
return _CompoundReport(reports)
41+
42+
43+
class _CompoundReport(object):
    """Bundle the reports of several step methods behind one interface.

    Exposes the same `_finalize(strace)` hook as a single report, so the
    sampling loop can treat a compound step like any other step method.
    """

    def __init__(self, reports):
        """Store the report objects to delegate to.

        Parameters
        ----------
        reports : list
            Objects providing a `_finalize(strace)` method.
        """
        self._reports = reports

    def _finalize(self, strace):
        """Forward `strace` to every wrapped report, in the order given."""
        for report in self._reports:
            report._finalize(strace)

pymc3/step_methods/hmc/nuts.py

Lines changed: 138 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@
22
import warnings
33

44
from ..arraystep import Competence
5+
from pymc3.exceptions import SamplingError
56
from .base_hmc import BaseHMC
67
from pymc3.theanof import floatX
78
from pymc3.vartypes import continuous_types
89

910
import numpy as np
1011
import numpy.random as nr
11-
from scipy import stats
12+
from scipy import stats, linalg
13+
import six
1214

1315
__all__ = ['NUTS']
1416

@@ -87,7 +89,7 @@ class NUTS(BaseHMC):
8789

8890
def __init__(self, vars=None, Emax=1000, target_accept=0.8,
8991
gamma=0.05, k=0.75, t0=10, adapt_step_size=True,
90-
max_treedepth=10, **kwargs):
92+
max_treedepth=10, on_error='summary', **kwargs):
9193
R"""
9294
Parameters
9395
----------
@@ -124,6 +126,12 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
124126
this will be interpreted as the mass or covariance matrix.
125127
is_cov : bool, default=False
126128
Treat the scaling as mass or covariance matrix.
129+
on_error : {'summary', 'warn', 'raise'}, default='summary'
130+
How to report problems during sampling.
131+
132+
* `summary`: Print one warning after sampling.
133+
* `warn`: Print individual warnings as soon as they appear.
134+
* `raise`: Raise an error on the first problem.
127135
potential : Potential, optional
128136
An object that represents the Hamiltonian with methods `velocity`,
129137
`energy`, and `random` methods. It can be specified instead
@@ -156,11 +164,15 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
156164
self.max_treedepth = max_treedepth
157165

158166
self.tune = True
167+
self.report = NutsReport(on_error, max_treedepth, target_accept)
159168

160169
def astep(self, q0):
161170
p0 = self.potential.random()
162171
v0 = self.compute_velocity(p0)
163172
start_energy = self.compute_energy(q0, p0)
173+
if not np.isfinite(start_energy):
174+
raise ValueError('Bad initial energy: %s. The model '
175+
'might be misspecified.' % start_energy)
164176

165177
if not self.adapt_step_size:
166178
step_size = self.step_size
@@ -170,14 +182,16 @@ def astep(self, q0):
170182
step_size = np.exp(self.log_step_size_bar)
171183

172184
start = Edge(q0, p0, v0, self.dlogp(q0), start_energy)
173-
tree = Tree(len(p0), self.leapfrog, start, step_size, self.Emax)
185+
tree = _Tree(len(p0), self.leapfrog, start, step_size, self.Emax)
174186

175187
for _ in range(self.max_treedepth):
176188
direction = logbern(np.log(0.5)) * 2 - 1
177189
diverging, turning = tree.extend(direction)
178190
q = tree.proposal.q
179191

180192
if diverging or turning:
193+
if diverging:
194+
self.report._add_divergence(self.tune, *diverging)
181195
break
182196

183197
w = 1. / (self.m + self.t0)
@@ -208,64 +222,6 @@ def competence(var):
208222
return Competence.IDEAL
209223
return Competence.INCOMPATIBLE
210224

211-
def check_trace(self, strace):
212-
"""Print warnings for obviously problematic chains."""
213-
n = len(strace)
214-
chain = strace.chain
215-
216-
diverging = strace.get_sampler_stats('diverging')
217-
if diverging.ndim == 2:
218-
diverging = np.any(diverging, axis=-1)
219-
220-
tuning = strace.get_sampler_stats('tune')
221-
if tuning.ndim == 2:
222-
tuning = np.any(tuning, axis=-1)
223-
224-
accept = strace.get_sampler_stats('mean_tree_accept')
225-
if accept.ndim == 2:
226-
accept = np.mean(accept, axis=-1)
227-
228-
depth = strace.get_sampler_stats('depth')
229-
if depth.ndim == 2:
230-
depth = np.max(depth, axis=-1)
231-
232-
n_samples = n - (~tuning).sum()
233-
234-
if n < 1000:
235-
warnings.warn('Chain %s contains only %s samples.' % (chain, n))
236-
if np.all(tuning):
237-
warnings.warn('Step size tuning was enabled throughout the whole '
238-
'trace. You might want to specify the number of '
239-
'tuning steps.')
240-
if np.all(diverging):
241-
warnings.warn('Chain %s contains only diverging samples. '
242-
'The model is probably misspecified.' % chain)
243-
return
244-
if np.any(diverging[~tuning]):
245-
warnings.warn("Chain %s contains diverging samples after tuning. "
246-
"If increasing `target_accept` doesn't help, "
247-
"try to reparameterize." % chain)
248-
if n_samples > 0:
249-
depth_samples = depth[~tuning]
250-
else:
251-
depth_samples = depth[n // 2:]
252-
if np.mean(depth_samples == self.max_treedepth) > 0.05:
253-
warnings.warn('Chain %s reached the maximum tree depth. Increase '
254-
'max_treedepth, increase target_accept or '
255-
'reparameterize.' % chain)
256-
257-
mean_accept = np.mean(accept[~tuning])
258-
target_accept = self.target_accept
259-
# Try to find a reasonable interval for acceptable acceptance
260-
# probabilities. Finding this was mostry trial and error.
261-
n_bound = min(100, n)
262-
n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
263-
lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
264-
if target_accept < lower or target_accept > upper:
265-
warnings.warn('The acceptance probability in chain %s does not '
266-
'match the target. It is %s, but should be close '
267-
'to %s. Try to increase the number of tuning steps.'
268-
% (chain, mean_accept, target_accept))
269225

270226
# A node in the NUTS tree that is at the far right or left of the tree
271227
Edge = namedtuple("Edge", 'q, p, v, q_grad, energy')
@@ -279,7 +235,7 @@ def check_trace(self, strace):
279235
"left, right, p_sum, proposal, log_size, accept_sum, n_proposals")
280236

281237

282-
class Tree(object):
238+
class _Tree(object):
283239
def __init__(self, ndim, leapfrog, start, step_size, Emax):
284240
"""Binary tree from the NUTS algorithm.
285241
@@ -352,24 +308,45 @@ def extend(self, direction):
352308

353309
return diverging, turning
354310

355-
def _build_subtree(self, left, depth, epsilon):
356-
if depth == 0:
311+
def _single_step(self, left, epsilon):
312+
"""Perform a leapfrog step and handle error cases."""
313+
try:
357314
right = self.leapfrog(left.q, left.p, left.q_grad, epsilon)
315+
except linalg.LinAlgError as err:
316+
error_msg = "LinAlgError during leapfrog step."
317+
error = err
318+
except ValueError as err:
319+
# Raised by many scipy.linalg functions
320+
scipy_msg = "array must not contain infs or nans"
321+
if len(err.args) > 0 and scipy_msg in err.args[0].lower():
322+
error_msg = "Infs or nans in scipy.linalg during leapfrog step."
323+
error = err
324+
else:
325+
raise
326+
else:
358327
right = Edge(*right)
359328
energy_change = right.energy - self.start_energy
360329
if np.isnan(energy_change):
361330
energy_change = np.inf
362331

363332
if np.abs(energy_change) > np.abs(self.max_energy_change):
364333
self.max_energy_change = energy_change
365-
p_accept = min(1, np.exp(-energy_change))
366-
367-
log_size = -energy_change
368-
diverging = energy_change > self.Emax
334+
if np.abs(energy_change) < self.Emax:
335+
p_accept = min(1, np.exp(-energy_change))
336+
log_size = -energy_change
337+
proposal = Proposal(right.q, right.energy, p_accept)
338+
tree = Subtree(right, right, right.p, proposal, log_size, p_accept, 1)
339+
return tree, False, False
340+
else:
341+
error_msg = ("Energy change in leapfrog step is too large: %s. "
342+
% energy_change)
343+
error = None
344+
tree = Subtree(None, None, None, None, -np.inf, 0, 1)
345+
return tree, (error_msg, error, left), False
369346

370-
proposal = Proposal(right.q, right.energy, p_accept)
371-
tree = Subtree(right, right, right.p, proposal, log_size, p_accept, 1)
372-
return tree, diverging, False
347+
def _build_subtree(self, left, depth, epsilon):
348+
if depth == 0:
349+
return self._single_step(left, epsilon)
373350

374351
tree1, diverging, turning = self._build_subtree(left, depth - 1, epsilon)
375352
if diverging or turning:
@@ -408,3 +385,93 @@ def stats(self):
408385
'tree_size': self.n_proposals,
409386
'max_energy_error': self.max_energy_change,
410387
}
388+
389+
390+
class NutsReport(object):
    """Collect problems seen by NUTS in one chain and report them.

    Divergences are recorded while sampling via `_add_divergence`; the
    remaining checks (chain length, tree depth, acceptance rate) run once
    in `_finalize`, using the sampler statistics stored in the trace.

    Parameters
    ----------
    on_error : {'summary', 'warn', 'raise'}
        How divergences are reported while sampling. Any other value
        raises `ValueError`.
    max_treedepth : int
        Tree depth that counts as "maximum reached" in `_check_depth`.
    target_accept : float
        Acceptance rate the mean tree acceptance is compared against.
    """

    def __init__(self, on_error, max_treedepth, target_accept):
        if on_error not in ['summary', 'raise', 'warn']:
            raise ValueError('Invalid value for on_error.')
        self._on_error = on_error
        self._max_treedepth = max_treedepth
        self._target_accept = target_accept
        # Filled in by `_finalize` from the trace; only used in messages.
        self._chain_id = None
        # One `(msg, error, point)` tuple per recorded divergence.
        self._divs_tune = []
        self._divs_after_tune = []

    def _add_divergence(self, tuning, msg, error=None, point=None):
        """Record a divergence and, depending on `on_error`, report it now.

        Parameters
        ----------
        tuning : bool
            Whether the sampler was still tuning when the divergence hit.
        msg : str
            Description of the problem.
        error : Exception, optional
            Underlying exception, used as the cause of the raised
            `SamplingError`.
        point : optional
            Position associated with the divergence; only stored.

        Raises
        ------
        SamplingError
            If `on_error` is 'raise' and the divergence happened after
            tuning.
        """
        if tuning:
            self._divs_tune.append((msg, error, point))
            return

        self._divs_after_tune.append((msg, error, point))
        # NOTE(review): the original indentation was lost; immediate
        # reporting is placed under the post-tuning branch because the
        # error text says "after tuning" -- confirm against upstream.
        if self._on_error == 'raise':
            err = SamplingError('Divergence after tuning: %s Increase '
                                'target_accept or reparameterize.' % msg)
            six.raise_from(err, error)
        elif self._on_error == 'warn':
            warnings.warn('Divergence detected: %s Increase target_accept '
                          'or reparameterize.' % msg)

    def _check_len(self, tuning):
        """Warn about short, tuning-only, or all-divergent chains.

        `tuning` is the per-draw 'tune' flag; it is inverted with `~`, so
        it is assumed to be a boolean array.
        """
        n = (~tuning).sum()
        if n < 500:
            warnings.warn('Chain %s contains only %s samples.'
                          % (self._chain_id, n))
        if np.all(tuning):
            warnings.warn('Step size tuning was enabled throughout the whole '
                          'trace. You might want to specify the number of '
                          'tuning steps.')
        # Require n > 0: with no post-tuning draws there are usually no
        # post-tuning divergences either, and 0 == 0 would wrongly report
        # a chain made up only of diverging samples.
        if n > 0 and n == len(self._divs_after_tune):
            warnings.warn('Chain %s contains only diverging samples. '
                          'The model is probably misspecified.'
                          % self._chain_id)

    def _check_accept(self, accept):
        """Warn if the mean acceptance rate is far from the target.

        `accept` holds the mean tree acceptance of the post-tuning draws.
        """
        if len(accept) == 0:
            # Nothing to compare; also avoids taking the mean of an empty
            # slice and building a beta distribution from NaN.
            return
        mean_accept = np.mean(accept)
        target_accept = self._target_accept
        # Try to find a reasonable interval for acceptable acceptance
        # probabilities. Finding this was mostly trial and error.
        n_bound = min(100, len(accept))
        n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
        lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
        if target_accept < lower or target_accept > upper:
            warnings.warn('The acceptance probability in chain %s does not '
                          'match the target. It is %s, but should be close '
                          'to %s. Try to increase the number of tuning steps.'
                          % (self._chain_id, mean_accept, target_accept))

    def _check_depth(self, depth):
        """Warn if more than 5% of the draws hit the maximum tree depth."""
        if len(depth) == 0:
            return
        if np.mean(depth == self._max_treedepth) > 0.05:
            warnings.warn('Chain %s reached the maximum tree depth. Increase '
                          'max_treedepth, increase target_accept or '
                          'reparameterize.' % self._chain_id)

    def _check_divergence(self):
        """Warn with the number of divergences recorded after tuning."""
        n_diverging = len(self._divs_after_tune)
        if n_diverging > 0:
            warnings.warn("Chain %s contains %s diverging samples after "
                          "tuning. If increasing `target_accept` doesn't help "
                          "try to reparameterize."
                          % (self._chain_id, n_diverging))

    def _finalize(self, strace):
        """Print warnings for obviously problematic chains.

        Reads the 'tune', 'mean_tree_accept' and 'depth' sampler stats
        from `strace`. Two-dimensional stats are reduced along the last
        axis (any / mean / max respectively) before the checks run.
        """
        self._chain_id = strace.chain

        tuning = strace.get_sampler_stats('tune')
        if tuning.ndim == 2:
            tuning = np.any(tuning, axis=-1)

        accept = strace.get_sampler_stats('mean_tree_accept')
        if accept.ndim == 2:
            accept = np.mean(accept, axis=-1)

        depth = strace.get_sampler_stats('depth')
        if depth.ndim == 2:
            depth = np.max(depth, axis=-1)

        self._check_len(tuning)
        # Depth and acceptance are only judged on post-tuning draws.
        self._check_depth(depth[~tuning])
        self._check_accept(accept[~tuning])
        self._check_divergence()

0 commit comments

Comments
 (0)