Small NUTS refactor to improve error handling #2215

Merged · 5 commits · May 24, 2017
Changes from 1 commit
8 changes: 4 additions & 4 deletions pymc3/sampling.py
@@ -388,16 +388,16 @@ def _iter_sample(draws, step, start=None, trace=None, chain=0, tune=None,
            yield strace
    except KeyboardInterrupt:
        strace.close()
-        if hasattr(step, 'check_trace'):
-            step.check_trace(strace)
+        if hasattr(step, 'report'):
+            step.report._finalize(strace)
        raise
    except BaseException:
        strace.close()
        raise
    else:
        strace.close()
-        if hasattr(step, 'check_trace'):
-            step.check_trace(strace)
+        if hasattr(step, 'report'):
+            step.report._finalize(strace)


def _choose_backend(trace, chain, shortcuts=None, **kwds):
8 changes: 7 additions & 1 deletion pymc3/step_methods/arraystep.py
@@ -6,7 +6,13 @@
from numpy.random import uniform
from enum import IntEnum, unique

-__all__ = ['ArrayStep', 'ArrayStepShared', 'metrop_select', 'Competence']
+__all__ = [
+    'ArrayStep', 'ArrayStepShared', 'metrop_select',
+    'Competence', 'SamplingError']


+class SamplingError(RuntimeError):
Member: 👍 we could probably do with either pymc3.exceptions or pymc3.sampling.exceptions soon
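A rough sketch of what that suggestion might look like (a hypothetical pymc3/exceptions.py; this module does not exist in this PR and the layout is illustrative):

    # pymc3/exceptions.py -- hypothetical sketch, not part of this PR
    __all__ = ['SamplingError']


    class SamplingError(RuntimeError):
        """Raised when a sampler hits a problem it cannot recover from."""
        pass

Step method modules would then use `from pymc3.exceptions import SamplingError` instead of defining the class in arraystep.py.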

+    pass


@unique
203 changes: 131 additions & 72 deletions pymc3/step_methods/hmc/nuts.py
@@ -1,14 +1,15 @@
from collections import namedtuple
import warnings

-from ..arraystep import Competence
+from ..arraystep import Competence, SamplingError
from .base_hmc import BaseHMC
from pymc3.theanof import floatX
from pymc3.vartypes import continuous_types

import numpy as np
import numpy.random as nr
-from scipy import stats
+from scipy import stats, linalg
+import six

__all__ = ['NUTS']

@@ -87,7 +88,7 @@ class NUTS(BaseHMC):

    def __init__(self, vars=None, Emax=1000, target_accept=0.8,
                 gamma=0.05, k=0.75, t0=10, adapt_step_size=True,
-                 max_treedepth=10, **kwargs):
+                 max_treedepth=10, on_error='summary', **kwargs):
R"""
Parameters
----------
@@ -124,6 +125,12 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
            this will be interpreted as the mass or covariance matrix.
        is_cov : bool, default=False
            Treat the scaling as mass or covariance matrix.
+        on_error : {'summary', 'warn', 'raise'}, default='summary'
+            How to report problems during sampling.
+
+            * `summary`: Print one warning after sampling.
+            * `warn`: Print individual warnings as soon as they appear.
+            * `raise`: Raise an error on the first problem.
        potential : Potential, optional
            An object that represents the Hamiltonian with methods `velocity`,
            `energy`, and `random`. It can be specified instead
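A minimal usage sketch for the new on_error argument (assuming the rest of this PR; the model is made up):

    import pymc3 as pm

    with pm.Model():
        x = pm.Normal('x', mu=0, sd=1)
        # Fail on the first divergence instead of summarizing afterwards.
        step = pm.NUTS(on_error='raise')
        trace = pm.sample(1000, step=step)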
@@ -156,11 +163,14 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
        self.max_treedepth = max_treedepth

        self.tune = True
+        self.report = NutsReport(on_error, max_treedepth, target_accept)

    def astep(self, q0):
        p0 = self.potential.random()
        v0 = self.compute_velocity(p0)
        start_energy = self.compute_energy(q0, p0)
+        if not np.isfinite(start_energy):
+            raise ValueError('The initial energy is inf or nan.')
Member: just print repr of start_energy here?

Member Author: that's better, yes.
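The agreed follow-up would presumably be a one-line change along these lines (a sketch, not in this commit):

    if not np.isfinite(start_energy):
        raise ValueError('The initial energy is inf or nan: %s.'
                         % repr(start_energy))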


        if not self.adapt_step_size:
            step_size = self.step_size
@@ -170,14 +180,16 @@ def astep(self, q0):
            step_size = np.exp(self.log_step_size_bar)

        start = Edge(q0, p0, v0, self.dlogp(q0), start_energy)
-        tree = Tree(len(p0), self.leapfrog, start, step_size, self.Emax)
+        tree = _Tree(len(p0), self.leapfrog, start, step_size, self.Emax)

        for _ in range(self.max_treedepth):
            direction = logbern(np.log(0.5)) * 2 - 1
            diverging, turning = tree.extend(direction)
            q = tree.proposal.q

            if diverging or turning:
+                if diverging:
+                    self.report._add_divergence(self.tune, *diverging)
                break

            w = 1. / (self.m + self.t0)
@@ -208,64 +220,6 @@ def competence(var):
            return Competence.IDEAL
        return Competence.INCOMPATIBLE

-    def check_trace(self, strace):
-        """Print warnings for obviously problematic chains."""
-        n = len(strace)
-        chain = strace.chain
-
-        diverging = strace.get_sampler_stats('diverging')
-        if diverging.ndim == 2:
-            diverging = np.any(diverging, axis=-1)
-
-        tuning = strace.get_sampler_stats('tune')
-        if tuning.ndim == 2:
-            tuning = np.any(tuning, axis=-1)
-
-        accept = strace.get_sampler_stats('mean_tree_accept')
-        if accept.ndim == 2:
-            accept = np.mean(accept, axis=-1)
-
-        depth = strace.get_sampler_stats('depth')
-        if depth.ndim == 2:
-            depth = np.max(depth, axis=-1)
-
-        n_samples = n - (~tuning).sum()
-
-        if n < 1000:
-            warnings.warn('Chain %s contains only %s samples.' % (chain, n))
-        if np.all(tuning):
-            warnings.warn('Step size tuning was enabled throughout the whole '
-                          'trace. You might want to specify the number of '
-                          'tuning steps.')
-        if np.all(diverging):
-            warnings.warn('Chain %s contains only diverging samples. '
-                          'The model is probably misspecified.' % chain)
-            return
-        if np.any(diverging[~tuning]):
-            warnings.warn("Chain %s contains diverging samples after tuning. "
-                          "If increasing `target_accept` doesn't help, "
-                          "try to reparameterize." % chain)
-        if n_samples > 0:
-            depth_samples = depth[~tuning]
-        else:
-            depth_samples = depth[n // 2:]
-        if np.mean(depth_samples == self.max_treedepth) > 0.05:
-            warnings.warn('Chain %s reached the maximum tree depth. Increase '
-                          'max_treedepth, increase target_accept or '
-                          'reparameterize.' % chain)
-
-        mean_accept = np.mean(accept[~tuning])
-        target_accept = self.target_accept
-        # Try to find a reasonable interval for acceptable acceptance
-        # probabilities. Finding this was mostly trial and error.
-        n_bound = min(100, n)
-        n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
-        lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
-        if target_accept < lower or target_accept > upper:
-            warnings.warn('The acceptance probability in chain %s does not '
-                          'match the target. It is %s, but should be close '
-                          'to %s. Try to increase the number of tuning steps.'
-                          % (chain, mean_accept, target_accept))

# A node in the NUTS tree that is at the far right or left of the tree
Edge = namedtuple("Edge", 'q, p, v, q_grad, energy')
@@ -279,7 +233,7 @@ def check_trace(self, strace):
"left, right, p_sum, proposal, log_size, accept_sum, n_proposals")


-class Tree(object):
+class _Tree(object):
Member: 👍

    def __init__(self, ndim, leapfrog, start, step_size, Emax):
        """Binary tree from the NUTS algorithm.

@@ -352,24 +306,41 @@ def extend(self, direction):

        return diverging, turning

-    def _build_subtree(self, left, depth, epsilon):
-        if depth == 0:
+    def _single_step(self, left, epsilon):
+        """Perform a leapfrog step and handle error cases."""
+        try:
            right = self.leapfrog(left.q, left.p, left.q_grad, epsilon)
+        except linalg.LinAlgError as error:
+            error_msg = "LinAlgError during leapfrog step."
+        except ValueError as error:
+            # Raised by many scipy.linalg functions
+            if error.args[0].lower() == 'array must not contain infs or nans':
Member: this is a little dangerous: it is brittle to changes in scipy.linalg, and will throw an IndexError if there is no message. Can you check directly if the array contains infs or nans?

Member Author: I don't like it either, but couldn't think of a better solution. We don't have access to the actual array that led to the problem, it could be hidden somewhere in the theano graph. But I'll add a check that error.args is long enough.

Member Author: And I'll add a test for this, then we should at least notice if this changes.

Member: thanks -- would be super frustrating/confusing if this threw an IndexError :)
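The guard promised above could look roughly like this (a sketch, written under the constraint discussed here that only the exception message is available, not the offending array):

    except ValueError as error:
        # Raised by many scipy.linalg functions
        scipy_msg = 'array must not contain infs or nans'
        if len(error.args) > 0 and str(error.args[0]).lower() == scipy_msg:
            error_msg = "Infs or nans in scipy.linalg during leapfrog step."
        else:
            raise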

error_msg = "Infs or nans in scipy.linalg during leapfrog step."
Member: this would probably have to be a discontinuity in the pdf?

Member Author: I got a lot of those when computing gradients of mvnormal with numerically indefinite covariance matrices. I can't think of any good advice if this happens, other than to check anything that involves scipy.linalg.

+            else:
+                raise
+        else:
            right = Edge(*right)
            energy_change = right.energy - self.start_energy
            if np.isnan(energy_change):
                energy_change = np.inf

            if np.abs(energy_change) > np.abs(self.max_energy_change):
                self.max_energy_change = energy_change
-            p_accept = min(1, np.exp(-energy_change))

-            log_size = -energy_change
-            diverging = energy_change > self.Emax
+            if np.abs(energy_change) < self.Emax:
+                p_accept = min(1, np.exp(-energy_change))
+                log_size = -energy_change
+                proposal = Proposal(right.q, right.energy, p_accept)
+                tree = Subtree(right, right, right.p, proposal, log_size, p_accept, 1)
+                return tree, False, False
+            else:
+                error_msg = "Bad energy after leapfrog step."
Member: it seems like the advice after this is to make the step_scale smaller (and the message should report the current value)

Member Author: The step size should be adapted already when you see this message. Increasing target_accept is, I think, the right thing (this will change the step size adaptation to use smaller step sizes). But otherwise I agree.
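In user code that advice amounts to something like this (illustrative values):

    # A higher target_accept makes the step size adaptation use smaller steps.
    step = pm.NUTS(target_accept=0.95)
    trace = pm.sample(2000, step=step)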

+        error = None
+        tree = Subtree(None, None, None, None, -np.inf, 0, 1)
+        return tree, (error_msg, error), False

-            proposal = Proposal(right.q, right.energy, p_accept)
-            tree = Subtree(right, right, right.p, proposal, log_size, p_accept, 1)
-            return tree, diverging, False
+    def _build_subtree(self, left, depth, epsilon):
+        if depth == 0:
+            return self._single_step(left, epsilon)

        tree1, diverging, turning = self._build_subtree(left, depth - 1, epsilon)
        if diverging or turning:
@@ -408,3 +379,91 @@ def stats(self):
            'tree_size': self.n_proposals,
            'max_energy_error': self.max_energy_change,
        }


+class NutsReport(object):
+    def __init__(self, on_error, max_treedepth, target_accept):
+        if on_error not in ['summary', 'raise', 'warn']:
+            raise ValueError('Invalid value for on_error.')
+        self._on_error = on_error
+        self._max_treedepth = max_treedepth
+        self._target_accept = target_accept
+        self._chain_id = None
+        self._divs_tune = []
+        self._divs_after_tune = []
+
+    def _add_divergence(self, tuning, msg, error=None):
+        if tuning:
+            self._divs_tune.append((msg, error))
+        else:
+            self._divs_after_tune.append((msg, error))
+        if self._on_error == 'raise':
+            err = SamplingError('Divergence after tuning: ' + msg)
Member: this suggests reparametrizing or smaller steps?

Member Author: yep, but larger target_accept again instead of step size.
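With on_error='raise' a caller could then handle divergences explicitly, e.g. (a sketch, using the import location from this PR):

    from pymc3.step_methods.arraystep import SamplingError

    try:
        trace = pm.sample(1000, step=pm.NUTS(on_error='raise'))
    except SamplingError as err:
        print('Sampling diverged: %s' % err)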

+            six.raise_from(err, error)
+        elif self._on_error == 'warn':
+            warnings.warn('Divergence detected: ' + msg)
+
+    def _check_len(self, tuning):
+        n = (~tuning).sum()
+        if n < 1000:
+            warnings.warn('Chain %s contains only %s samples.'
+                          % (self._chain_id, n))
+        if np.all(tuning):
+            warnings.warn('Step size tuning was enabled throughout the whole '
+                          'trace. You might want to specify the number of '
+                          'tuning steps.')
+        if n == len(self._divs_after_tune):
+            warnings.warn('Chain %s contains only diverging samples. '
+                          'The model is probably misspecified.'
+                          % self._chain_id)
+
+    def _check_accept(self, accept):
+        mean_accept = np.mean(accept)
+        target_accept = self._target_accept
+        # Try to find a reasonable interval for acceptable acceptance
+        # probabilities. Finding this was mostly trial and error.
+        n_bound = min(100, len(accept))
+        n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
+        lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
+        if target_accept < lower or target_accept > upper:
+            warnings.warn('The acceptance probability in chain %s does not '
+                          'match the target. It is %s, but should be close '
+                          'to %s. Try to increase the number of tuning steps.'
+                          % (self._chain_id, mean_accept, target_accept))
+
+    def _check_depth(self, depth):
+        if len(depth) == 0:
+            return
+        if np.mean(depth == self._max_treedepth) > 0.05:
+            warnings.warn('Chain %s reached the maximum tree depth. Increase '
+                          'max_treedepth, increase target_accept or '
+                          'reparameterize.' % self._chain_id)
+
+    def _check_divergence(self):
+        n_diverging = len(self._divs_after_tune)
+        if n_diverging > 0:
+            warnings.warn("Chain %s contains %s diverging samples after "
+                          "tuning. If increasing `target_accept` doesn't help, "
+                          "try to reparameterize."
+                          % (self._chain_id, n_diverging))
+
+    def _finalize(self, strace):
+        """Print warnings for obviously problematic chains."""
+        self._chain_id = strace.chain
+
+        tuning = strace.get_sampler_stats('tune')
+        if tuning.ndim == 2:
+            tuning = np.any(tuning, axis=-1)
+
+        accept = strace.get_sampler_stats('mean_tree_accept')
+        if accept.ndim == 2:
+            accept = np.mean(accept, axis=-1)
+
+        depth = strace.get_sampler_stats('depth')
+        if depth.ndim == 2:
+            depth = np.max(depth, axis=-1)
+
+        self._check_len(tuning)
+        self._check_depth(depth[~tuning])
+        self._check_accept(accept[~tuning])
+        self._check_divergence()
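Standalone, the acceptance-rate check in _check_accept amounts to the following (a self-contained sketch with made-up numbers):

    import numpy as np
    from scipy import stats

    accept = np.full(500, 0.6)  # stand-in for the mean_tree_accept stats
    target_accept = 0.8

    mean_accept = np.mean(accept)
    n_bound = min(100, len(accept))
    n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
    # 95% interval of a Beta posterior around the observed acceptance rate
    lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
    print(lower <= target_accept <= upper)  # False here, so the warning fires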