Skip to content

Commit 549c9bd

Browse files
committed
Small NUTS refactoring
1 parent fab6938 commit 549c9bd

File tree

3 files changed

+142
-77
lines changed

3 files changed

+142
-77
lines changed

pymc3/sampling.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -388,16 +388,16 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
388388
yield strace
389389
except KeyboardInterrupt:
390390
strace.close()
391-
if hasattr(step, 'check_trace'):
392-
step.check_trace(strace)
391+
if hasattr(step, 'report'):
392+
step.report._finalize(strace)
393393
raise
394394
except BaseException:
395395
strace.close()
396396
raise
397397
else:
398398
strace.close()
399-
if hasattr(step, 'check_trace'):
400-
step.check_trace(strace)
399+
if hasattr(step, 'report'):
400+
step.report._finalize(strace)
401401

402402

403403
def _choose_backend(trace, chain, shortcuts=None, **kwds):

pymc3/step_methods/arraystep.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66
from numpy.random import uniform
77
from enum import IntEnum, unique
88

9-
__all__ = ['ArrayStep', 'ArrayStepShared', 'metrop_select', 'Competence']
9+
__all__ = [
10+
'ArrayStep', 'ArrayStepShared', 'metrop_select',
11+
'Competence', 'SamplingError']
12+
13+
14+
class SamplingError(RuntimeError):
15+
pass
1016

1117

1218
@unique

pymc3/step_methods/hmc/nuts.py

Lines changed: 131 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
from collections import namedtuple
22
import warnings
33

4-
from ..arraystep import Competence
4+
from ..arraystep import Competence, SamplingError
55
from .base_hmc import BaseHMC
66
from pymc3.theanof import floatX
77
from pymc3.vartypes import continuous_types
88

99
import numpy as np
1010
import numpy.random as nr
11-
from scipy import stats
11+
from scipy import stats, linalg
12+
import six
1213

1314
__all__ = ['NUTS']
1415

@@ -87,7 +88,7 @@ class NUTS(BaseHMC):
8788

8889
def __init__(self, vars=None, Emax=1000, target_accept=0.8,
8990
gamma=0.05, k=0.75, t0=10, adapt_step_size=True,
90-
max_treedepth=10, **kwargs):
91+
max_treedepth=10, on_error='summary', **kwargs):
9192
R"""
9293
Parameters
9394
----------
@@ -124,6 +125,12 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
124125
this will be interpreted as the mass or covariance matrix.
125126
is_cov : bool, default=False
126127
Treat the scaling as mass or covariance matrix.
128+
on_error : {'summary', 'warn', 'raise'}, default='summary'
129+
How to report problems during sampling.
130+
131+
* `summary`: Print one warning after sampling.
132+
* `warn`: Print individual warnings as soon as they appear.
133+
* `raise`: Raise an error on the first problem.
127134
potential : Potential, optional
128135
An object that represents the Hamiltonian with methods `velocity`,
129136
`energy`, and `random` methods. It can be specified instead
@@ -156,11 +163,14 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
156163
self.max_treedepth = max_treedepth
157164

158165
self.tune = True
166+
self.report = NutsReport(on_error, max_treedepth, target_accept)
159167

160168
def astep(self, q0):
161169
p0 = self.potential.random()
162170
v0 = self.compute_velocity(p0)
163171
start_energy = self.compute_energy(q0, p0)
172+
if not np.isfinite(start_energy):
173+
raise ValueError('The initial energy is inf or nan.')
164174

165175
if not self.adapt_step_size:
166176
step_size = self.step_size
@@ -170,14 +180,16 @@ def astep(self, q0):
170180
step_size = np.exp(self.log_step_size_bar)
171181

172182
start = Edge(q0, p0, v0, self.dlogp(q0), start_energy)
173-
tree = Tree(len(p0), self.leapfrog, start, step_size, self.Emax)
183+
tree = _Tree(len(p0), self.leapfrog, start, step_size, self.Emax)
174184

175185
for _ in range(self.max_treedepth):
176186
direction = logbern(np.log(0.5)) * 2 - 1
177187
diverging, turning = tree.extend(direction)
178188
q = tree.proposal.q
179189

180190
if diverging or turning:
191+
if diverging:
192+
self.report._add_divergence(self.tune, *diverging)
181193
break
182194

183195
w = 1. / (self.m + self.t0)
@@ -208,64 +220,6 @@ def competence(var):
208220
return Competence.IDEAL
209221
return Competence.INCOMPATIBLE
210222

211-
def check_trace(self, strace):
212-
"""Print warnings for obviously problematic chains."""
213-
n = len(strace)
214-
chain = strace.chain
215-
216-
diverging = strace.get_sampler_stats('diverging')
217-
if diverging.ndim == 2:
218-
diverging = np.any(diverging, axis=-1)
219-
220-
tuning = strace.get_sampler_stats('tune')
221-
if tuning.ndim == 2:
222-
tuning = np.any(tuning, axis=-1)
223-
224-
accept = strace.get_sampler_stats('mean_tree_accept')
225-
if accept.ndim == 2:
226-
accept = np.mean(accept, axis=-1)
227-
228-
depth = strace.get_sampler_stats('depth')
229-
if depth.ndim == 2:
230-
depth = np.max(depth, axis=-1)
231-
232-
n_samples = n - (~tuning).sum()
233-
234-
if n < 1000:
235-
warnings.warn('Chain %s contains only %s samples.' % (chain, n))
236-
if np.all(tuning):
237-
warnings.warn('Step size tuning was enabled throughout the whole '
238-
'trace. You might want to specify the number of '
239-
'tuning steps.')
240-
if np.all(diverging):
241-
warnings.warn('Chain %s contains only diverging samples. '
242-
'The model is probably misspecified.' % chain)
243-
return
244-
if np.any(diverging[~tuning]):
245-
warnings.warn("Chain %s contains diverging samples after tuning. "
246-
"If increasing `target_accept` doesn't help, "
247-
"try to reparameterize." % chain)
248-
if n_samples > 0:
249-
depth_samples = depth[~tuning]
250-
else:
251-
depth_samples = depth[n // 2:]
252-
if np.mean(depth_samples == self.max_treedepth) > 0.05:
253-
warnings.warn('Chain %s reached the maximum tree depth. Increase '
254-
'max_treedepth, increase target_accept or '
255-
'reparameterize.' % chain)
256-
257-
mean_accept = np.mean(accept[~tuning])
258-
target_accept = self.target_accept
259-
# Try to find a reasonable interval for acceptable acceptance
260-
# probabilities. Finding this was mostry trial and error.
261-
n_bound = min(100, n)
262-
n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
263-
lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
264-
if target_accept < lower or target_accept > upper:
265-
warnings.warn('The acceptance probability in chain %s does not '
266-
'match the target. It is %s, but should be close '
267-
'to %s. Try to increase the number of tuning steps.'
268-
% (chain, mean_accept, target_accept))
269223

270224
# A node in the NUTS tree that is at the far right or left of the tree
271225
Edge = namedtuple("Edge", 'q, p, v, q_grad, energy')
@@ -279,7 +233,7 @@ def check_trace(self, strace):
279233
"left, right, p_sum, proposal, log_size, accept_sum, n_proposals")
280234

281235

282-
class Tree(object):
236+
class _Tree(object):
283237
def __init__(self, ndim, leapfrog, start, step_size, Emax):
284238
"""Binary tree from the NUTS algorithm.
285239
@@ -352,24 +306,41 @@ def extend(self, direction):
352306

353307
return diverging, turning
354308

355-
def _build_subtree(self, left, depth, epsilon):
356-
if depth == 0:
309+
def _single_step(self, left, epsilon):
310+
"""Perform a leapfrog step and handle error cases."""
311+
try:
357312
right = self.leapfrog(left.q, left.p, left.q_grad, epsilon)
313+
except linalg.LinAlgError as error:
314+
error_msg = "LinAlgError during leapfrog step."
315+
except ValueError as error:
316+
# Raised by many scipy.linalg functions
317+
if error.args[0].lower() == 'array must not contain infs or nans':
318+
error_msg = "Infs or nans in scipy.linalg during leapfrog step."
319+
else:
320+
raise
321+
else:
358322
right = Edge(*right)
359323
energy_change = right.energy - self.start_energy
360324
if np.isnan(energy_change):
361325
energy_change = np.inf
362326

363327
if np.abs(energy_change) > np.abs(self.max_energy_change):
364328
self.max_energy_change = energy_change
365-
p_accept = min(1, np.exp(-energy_change))
366-
367-
log_size = -energy_change
368-
diverging = energy_change > self.Emax
329+
if np.abs(energy_change) < self.Emax:
330+
p_accept = min(1, np.exp(-energy_change))
331+
log_size = -energy_change
332+
proposal = Proposal(right.q, right.energy, p_accept)
333+
tree = Subtree(right, right, right.p, proposal, log_size, p_accept, 1)
334+
return tree, False, False
335+
else:
336+
error_msg = "Bad energy after leapfrog step."
337+
error = None
338+
tree = Subtree(None, None, None, None, -np.inf, 0, 1)
339+
return tree, (error_msg, error), False
369340

370-
proposal = Proposal(right.q, right.energy, p_accept)
371-
tree = Subtree(right, right, right.p, proposal, log_size, p_accept, 1)
372-
return tree, diverging, False
341+
def _build_subtree(self, left, depth, epsilon):
342+
if depth == 0:
343+
return self._single_step(left, epsilon)
373344

374345
tree1, diverging, turning = self._build_subtree(left, depth - 1, epsilon)
375346
if diverging or turning:
@@ -408,3 +379,91 @@ def stats(self):
408379
'tree_size': self.n_proposals,
409380
'max_energy_error': self.max_energy_change,
410381
}
382+
383+
384+
class NutsReport(object):
385+
def __init__(self, on_error, max_treedepth, target_accept):
386+
if on_error not in ['summary', 'raise', 'warn']:
387+
raise ValueError('Invalid value for on_error.')
388+
self._on_error = on_error
389+
self._max_treedepth = max_treedepth
390+
self._target_accept = target_accept
391+
self._chain_id = None
392+
self._divs_tune = []
393+
self._divs_after_tune = []
394+
395+
def _add_divergence(self, tuning, msg, error=None):
396+
if tuning:
397+
self._divs_tune.append((msg, error))
398+
else:
399+
self._divs_after_tune.append((msg, error))
400+
if self._on_error == 'raise':
401+
err = SamplingError('Divergence after tuning: ' + msg)
402+
six.raise_from(err, error)
403+
elif self._on_error == 'warn':
404+
warnings.warn('Divergence detected: ' + msg)
405+
406+
def _check_len(self, tuning):
407+
n = (~tuning).sum()
408+
if n < 1000:
409+
warnings.warn('Chain %s contains only %s samples.'
410+
% (self._chain_id, n))
411+
if np.all(tuning):
412+
warnings.warn('Step size tuning was enabled throughout the whole '
413+
'trace. You might want to specify the number of '
414+
'tuning steps.')
415+
if n == len(self._divs_after_tune):
416+
warnings.warn('Chain %s contains only diverging samples. '
417+
'The model is probably misspecified.'
418+
% self._chain_id)
419+
420+
def _check_accept(self, accept):
421+
mean_accept = np.mean(accept)
422+
target_accept = self._target_accept
423+
# Try to find a reasonable interval for acceptable acceptance
424+
# probabilities. Finding this was mostly trial and error.
425+
n_bound = min(100, len(accept))
426+
n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
427+
lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
428+
if target_accept < lower or target_accept > upper:
429+
warnings.warn('The acceptance probability in chain %s does not '
430+
'match the target. It is %s, but should be close '
431+
'to %s. Try to increase the number of tuning steps.'
432+
% (self._chain_id, mean_accept, target_accept))
433+
434+
def _check_depth(self, depth):
435+
if len(depth) == 0:
436+
return
437+
if np.mean(depth == self._max_treedepth) > 0.05:
438+
warnings.warn('Chain %s reached the maximum tree depth. Increase '
439+
'max_treedepth, increase target_accept or '
440+
'reparameterize.' % self._chain_id)
441+
442+
def _check_divergence(self):
443+
n_diverging = len(self._divs_after_tune)
444+
if n_diverging > 0:
445+
warnings.warn("Chain %s contains %s diverging samples after "
446+
"tuning. If increasing `target_accept` doesn't help "
447+
"try to reparameterize."
448+
% (self._chain_id, n_diverging))
449+
450+
def _finalize(self, strace):
451+
"""Print warnings for obviously problematic chains."""
452+
self._chain_id = strace.chain
453+
454+
tuning = strace.get_sampler_stats('tune')
455+
if tuning.ndim == 2:
456+
tuning = np.any(tuning, axis=-1)
457+
458+
accept = strace.get_sampler_stats('mean_tree_accept')
459+
if accept.ndim == 2:
460+
accept = np.mean(accept, axis=-1)
461+
462+
depth = strace.get_sampler_stats('depth')
463+
if depth.ndim == 2:
464+
depth = np.max(depth, axis=-1)
465+
466+
self._check_len(tuning)
467+
self._check_depth(depth[~tuning])
468+
self._check_accept(accept[~tuning])
469+
self._check_divergence()

0 commit comments

Comments
 (0)