Small NUTS refactor to improve error handling #2215
```diff
@@ -1,14 +1,15 @@
 from collections import namedtuple
 import warnings

-from ..arraystep import Competence
+from ..arraystep import Competence, SamplingError
 from .base_hmc import BaseHMC
 from pymc3.theanof import floatX
 from pymc3.vartypes import continuous_types

 import numpy as np
 import numpy.random as nr
-from scipy import stats
+from scipy import stats, linalg
+import six

 __all__ = ['NUTS']
```
```diff
@@ -87,7 +88,7 @@ class NUTS(BaseHMC):

     def __init__(self, vars=None, Emax=1000, target_accept=0.8,
                  gamma=0.05, k=0.75, t0=10, adapt_step_size=True,
-                 max_treedepth=10, **kwargs):
+                 max_treedepth=10, on_error='summary', **kwargs):
         R"""
         Parameters
         ----------
```
```diff
@@ -124,6 +125,12 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
             this will be interpreted as the mass or covariance matrix.
         is_cov : bool, default=False
             Treat the scaling as mass or covariance matrix.
+        on_error : {'summary', 'warn', 'raise'}, default='summary'
+            How to report problems during sampling.
+
+            * `summary`: Print one warning after sampling.
+            * `warn`: Print individual warnings as soon as they appear.
+            * `raise`: Raise an error on the first problem.
         potential : Potential, optional
             An object that represents the Hamiltonian with `velocity`,
             `energy`, and `random` methods. It can be specified instead
```
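For readers skimming the diff, a hedged usage sketch of the new option; the toy model and the `pm.sample` call are illustrative, only the `on_error` keyword comes from this PR:

```python
import pymc3 as pm

with pm.Model():
    x = pm.Normal('x', mu=0, sd=1)
    # 'raise' aborts on the first divergence; 'warn' prints warnings
    # immediately; 'summary' (the default) prints one warning at the end.
    step = pm.NUTS(on_error='raise')
    trace = pm.sample(1000, step=step)
```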
```diff
@@ -156,11 +163,14 @@ def __init__(self, vars=None, Emax=1000, target_accept=0.8,
         self.max_treedepth = max_treedepth

         self.tune = True
+        self.report = NutsReport(on_error, max_treedepth, target_accept)

     def astep(self, q0):
         p0 = self.potential.random()
         v0 = self.compute_velocity(p0)
         start_energy = self.compute_energy(q0, p0)
+        if not np.isfinite(start_energy):
+            raise ValueError('The initial energy is inf or nan.')
```
> **Review comment:** just print
>
> **Reply:** that's better, yes.
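As a reminder of what that check guards (a sketch, not project code): the Hamiltonian at the start point is potential plus kinetic energy, so a zero-probability start point gives a non-finite value.

```python
import numpy as np

# Sketch: start_energy ~ -logp(q0) + kinetic(p0). A start value with
# zero probability has logp = -inf, so the energy is non-finite and
# astep() bails out before doing any leapfrog work.
logp_q0 = -np.inf        # e.g. a start value outside the support
kinetic_p0 = 0.5
start_energy = -logp_q0 + kinetic_p0
print(np.isfinite(start_energy))  # False
```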
```diff

         if not self.adapt_step_size:
             step_size = self.step_size
```
```diff
@@ -170,14 +180,16 @@ def astep(self, q0):
             step_size = np.exp(self.log_step_size_bar)

         start = Edge(q0, p0, v0, self.dlogp(q0), start_energy)
-        tree = Tree(len(p0), self.leapfrog, start, step_size, self.Emax)
+        tree = _Tree(len(p0), self.leapfrog, start, step_size, self.Emax)

         for _ in range(self.max_treedepth):
             direction = logbern(np.log(0.5)) * 2 - 1
             diverging, turning = tree.extend(direction)
             q = tree.proposal.q

             if diverging or turning:
+                if diverging:
+                    self.report._add_divergence(self.tune, *diverging)
                 break

         w = 1. / (self.m + self.t0)
```
```diff
@@ -208,64 +220,6 @@ def competence(var):
             return Competence.IDEAL
         return Competence.INCOMPATIBLE

-    def check_trace(self, strace):
-        """Print warnings for obviously problematic chains."""
-        n = len(strace)
-        chain = strace.chain
-
-        diverging = strace.get_sampler_stats('diverging')
-        if diverging.ndim == 2:
-            diverging = np.any(diverging, axis=-1)
-
-        tuning = strace.get_sampler_stats('tune')
-        if tuning.ndim == 2:
-            tuning = np.any(tuning, axis=-1)
-
-        accept = strace.get_sampler_stats('mean_tree_accept')
-        if accept.ndim == 2:
-            accept = np.mean(accept, axis=-1)
-
-        depth = strace.get_sampler_stats('depth')
-        if depth.ndim == 2:
-            depth = np.max(depth, axis=-1)
-
-        n_samples = n - (~tuning).sum()
-
-        if n < 1000:
-            warnings.warn('Chain %s contains only %s samples.' % (chain, n))
-        if np.all(tuning):
-            warnings.warn('Step size tuning was enabled throughout the whole '
-                          'trace. You might want to specify the number of '
-                          'tuning steps.')
-        if np.all(diverging):
-            warnings.warn('Chain %s contains only diverging samples. '
-                          'The model is probably misspecified.' % chain)
-            return
-        if np.any(diverging[~tuning]):
-            warnings.warn("Chain %s contains diverging samples after tuning. "
-                          "If increasing `target_accept` doesn't help, "
-                          "try to reparameterize." % chain)
-        if n_samples > 0:
-            depth_samples = depth[~tuning]
-        else:
-            depth_samples = depth[n // 2:]
-        if np.mean(depth_samples == self.max_treedepth) > 0.05:
-            warnings.warn('Chain %s reached the maximum tree depth. Increase '
-                          'max_treedepth, increase target_accept or '
-                          'reparameterize.' % chain)
-
-        mean_accept = np.mean(accept[~tuning])
-        target_accept = self.target_accept
-        # Try to find a reasonable interval for acceptable acceptance
-        # probabilities. Finding this was mostry trial and error.
-        n_bound = min(100, n)
-        n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
-        lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
-        if target_accept < lower or target_accept > upper:
-            warnings.warn('The acceptance probability in chain %s does not '
-                          'match the target. It is %s, but should be close '
-                          'to %s. Try to increase the number of tuning steps.'
-                          % (chain, mean_accept, target_accept))
-
 # A node in the NUTS tree that is at the far right or left of the tree
 Edge = namedtuple("Edge", 'q, p, v, q_grad, energy')
```
```diff
@@ -279,7 +233,7 @@ def check_trace(self, strace):
         "left, right, p_sum, proposal, log_size, accept_sum, n_proposals")


-class Tree(object):
+class _Tree(object):
```
> **Review comment** (on the rename to `_Tree`): 👍
```diff

     def __init__(self, ndim, leapfrog, start, step_size, Emax):
         """Binary tree from the NUTS algorithm.
```
```diff
@@ -352,24 +306,41 @@ def extend(self, direction):

         return diverging, turning

-    def _build_subtree(self, left, depth, epsilon):
-        if depth == 0:
-            right = self.leapfrog(left.q, left.p, left.q_grad, epsilon)
+    def _single_step(self, left, epsilon):
+        """Perform a leapfrog step and handle error cases."""
+        try:
+            right = self.leapfrog(left.q, left.p, left.q_grad, epsilon)
+        except linalg.LinAlgError as err:
+            error_msg = "LinAlgError during leapfrog step."
+            error = err
+        except ValueError as err:
+            # Raised by many scipy.linalg functions
+            if err.args[0].lower() == 'array must not contain infs or nans':
+                error_msg = "Infs or nans in scipy.linalg during leapfrog step."
+                error = err
```

> **Review comment** (on the string comparison): this is a little dangerous: it is brittle to changes in `scipy`
>
> **Reply:** I don't like it either, but couldn't think of a better solution. We don't have access to the actual array that led to the problem, it could be hidden somewhere in the theano graph. But I'll add a check that …
>
> **Reply:** And I'll add a test for this, then we should at least notice if this changes.
>
> **Reply:** thanks -- would be super frustrating/confusing if this threw an …

> **Review comment** (on the "Infs or nans in scipy.linalg" message): this would probably have to be a discontinuity in the pdf?
>
> **Reply:** I got a lot of those when computing gradients of mvnormal with numerically indefinite covariance matrices. I can't think of any good advice if this happens, other than check anything that involves …
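One way the guard promised above might look (hypothetical; the final commit may well differ): avoid indexing an empty `args` tuple and match a substring instead of the exact string, so small wording changes in scipy don't break it.

```python
# Hypothetical, more defensive variant of the message check.
SCIPY_MSG = 'array must not contain infs or nans'


def is_scipy_nan_error(err):
    # Guard against empty args before indexing, then match a substring.
    return bool(err.args) and SCIPY_MSG in str(err.args[0]).lower()


try:
    raise ValueError('array must not contain infs or NaNs')
except ValueError as err:
    if is_scipy_nan_error(err):
        print('Infs or nans in scipy.linalg during leapfrog step.')
    else:
        raise  # unrelated ValueError: don't swallow it
```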
```diff
+            else:
+                raise
+        else:
             right = Edge(*right)
             energy_change = right.energy - self.start_energy
             if np.isnan(energy_change):
                 energy_change = np.inf

             if np.abs(energy_change) > np.abs(self.max_energy_change):
                 self.max_energy_change = energy_change
-            p_accept = min(1, np.exp(-energy_change))
-
-            log_size = -energy_change
-            diverging = energy_change > self.Emax
+            if np.abs(energy_change) < self.Emax:
+                p_accept = min(1, np.exp(-energy_change))
+                log_size = -energy_change
+                proposal = Proposal(right.q, right.energy, p_accept)
+                tree = Subtree(right, right, right.p, proposal, log_size, p_accept, 1)
+                return tree, False, False
+            else:
+                error_msg = "Bad energy after leapfrog step."
+                error = None
```
> **Review comment** (on "Bad energy after leapfrog step"): it seems like the advice after this is to make the …
>
> **Reply:** The step size should be adapted already when you see this message. Increasing …
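To make the "bad energy" threshold concrete, a small numeric illustration of the rule above (the energy value is invented):

```python
import numpy as np

Emax = 1000
# Energy error of one leapfrog step, H(right) - H(start):
energy_change = 2.3                          # invented value
if np.abs(energy_change) < Emax:
    p_accept = min(1, np.exp(-energy_change))
    print(p_accept)                          # ~0.100: valid, low-weight proposal
else:
    print("Bad energy after leapfrog step.")  # treated as a divergence
```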
```diff
-            proposal = Proposal(right.q, right.energy, p_accept)
-            tree = Subtree(right, right, right.p, proposal, log_size, p_accept, 1)
-            return tree, diverging, False
+        tree = Subtree(None, None, None, None, -np.inf, 0, 1)
+        return tree, (error_msg, error), False
+
+    def _build_subtree(self, left, depth, epsilon):
+        if depth == 0:
+            return self._single_step(left, epsilon)

         tree1, diverging, turning = self._build_subtree(left, depth - 1, epsilon)
         if diverging or turning:
```
```diff
@@ -408,3 +379,91 @@ def stats(self):
             'tree_size': self.n_proposals,
             'max_energy_error': self.max_energy_change,
         }
+
+
+class NutsReport(object):
+    def __init__(self, on_error, max_treedepth, target_accept):
+        if on_error not in ['summary', 'raise', 'warn']:
+            raise ValueError('Invalid value for on_error.')
+        self._on_error = on_error
+        self._max_treedepth = max_treedepth
+        self._target_accept = target_accept
+        self._chain_id = None
+        self._divs_tune = []
+        self._divs_after_tune = []
+
+    def _add_divergence(self, tuning, msg, error=None):
+        if tuning:
+            self._divs_tune.append((msg, error))
+        else:
+            self._divs_after_tune.append((msg, error))
+        if self._on_error == 'raise':
+            err = SamplingError('Divergence after tuning: ' + msg)
```
> **Review comment:** this suggests reparametrizing or smaller steps?
>
> **Reply:** yep, but larger …
```diff
+            six.raise_from(err, error)
+        elif self._on_error == 'warn':
+            warnings.warn('Divergence detected: ' + msg)
```
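For context, `six.raise_from(err, error)` is the Python 2/3-compatible spelling of `raise err from error`. A self-contained sketch of the chaining; the `SamplingError` stand-in is local to this snippet:

```python
import six


class SamplingError(RuntimeError):
    """Stand-in for pymc3's SamplingError, local to this sketch."""


try:
    try:
        raise ValueError('array must not contain infs or nans')
    except ValueError as error:
        err = SamplingError('Divergence after tuning: bad leapfrog step')
        six.raise_from(err, error)  # py2/3-compatible `raise err from error`
except SamplingError as exc:
    print(exc.__cause__)  # on Python 3, the original ValueError
```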
```diff
+
+    def _check_len(self, tuning):
+        n = (~tuning).sum()
+        if n < 1000:
+            warnings.warn('Chain %s contains only %s samples.'
+                          % (self._chain_id, n))
+        if np.all(tuning):
+            warnings.warn('Step size tuning was enabled throughout the whole '
+                          'trace. You might want to specify the number of '
+                          'tuning steps.')
+        if n == len(self._divs_after_tune):
+            warnings.warn('Chain %s contains only diverging samples. '
+                          'The model is probably misspecified.'
+                          % self._chain_id)
```
```diff
+
+    def _check_accept(self, accept):
+        mean_accept = np.mean(accept)
+        target_accept = self._target_accept
+        # Try to find a reasonable interval for acceptable acceptance
+        # probabilities. Finding this was mostly trial and error.
+        n_bound = min(100, len(accept))
+        n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
+        lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
+        if target_accept < lower or target_accept > upper:
+            warnings.warn('The acceptance probability in chain %s does not '
+                          'match the target. It is %s, but should be close '
+                          'to %s. Try to increase the number of tuning steps.'
+                          % (self._chain_id, mean_accept, target_accept))
```
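A standalone illustration of the interval above (toy numbers): the mean acceptance rate is treated as if it came from at most 100 Bernoulli draws, and a 95% Beta interval decides whether `target_accept` is plausible.

```python
import numpy as np
from scipy import stats

accept = np.full(500, 0.7)            # invented acceptance statistics
mean_accept = np.mean(accept)
n_bound = min(100, len(accept))
n_good, n_bad = mean_accept * n_bound, (1 - mean_accept) * n_bound
lower, upper = stats.beta(n_good + 1, n_bad + 1).interval(0.95)
print(lower, upper)                   # roughly (0.61, 0.78)
# warn if target_accept (e.g. 0.8) falls outside [lower, upper]
```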
```diff
+
+    def _check_depth(self, depth):
+        if len(depth) == 0:
+            return
+        if np.mean(depth == self._max_treedepth) > 0.05:
+            warnings.warn('Chain %s reached the maximum tree depth. Increase '
+                          'max_treedepth, increase target_accept or '
+                          'reparameterize.' % self._chain_id)
```
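The 5% rule above in isolation: the check warns when the fraction of post-tuning samples that saturate the tree depth exceeds 0.05 (depths invented):

```python
import numpy as np

max_treedepth = 10
depth = np.array([4, 5, 10, 10, 6, 10, 3, 5])    # invented tree depths
frac_saturated = np.mean(depth == max_treedepth)
print(frac_saturated)            # 0.375 > 0.05 -> warning is emitted
```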
```diff
+
+    def _check_divergence(self):
+        n_diverging = len(self._divs_after_tune)
+        if n_diverging > 0:
+            warnings.warn("Chain %s contains %s diverging samples after "
+                          "tuning. If increasing `target_accept` doesn't "
+                          "help, try to reparameterize."
+                          % (self._chain_id, n_diverging))
```
```diff
+
+    def _finalize(self, strace):
+        """Print warnings for obviously problematic chains."""
+        self._chain_id = strace.chain
+
+        tuning = strace.get_sampler_stats('tune')
+        if tuning.ndim == 2:
+            tuning = np.any(tuning, axis=-1)
+
+        accept = strace.get_sampler_stats('mean_tree_accept')
+        if accept.ndim == 2:
+            accept = np.mean(accept, axis=-1)
+
+        depth = strace.get_sampler_stats('depth')
+        if depth.ndim == 2:
+            depth = np.max(depth, axis=-1)
+
+        self._check_len(tuning)
+        self._check_depth(depth[~tuning])
+        self._check_accept(accept[~tuning])
+        self._check_divergence()
```
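The `ndim == 2` branches handle compound step methods, where each sampler stat is recorded per step method as well as per draw. A sketch of the per-draw reduction (arrays invented):

```python
import numpy as np

# Stats with shape (n_draws, n_step_methods) are reduced to one value
# per draw, mirroring _finalize above:
tuning = np.array([[True, True], [False, True], [False, False]])
depth = np.array([[3, 10], [4, 5], [2, 6]])

tuning = np.any(tuning, axis=-1)   # a draw counts as tuning if any was
depth = np.max(depth, axis=-1)     # deepest tree across step methods
print(tuning, depth)               # [ True  True False] [10  5  6]
```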
> **Review comment:** 👍 we could probably do with either `pymc3.exceptions` or `pymc3.sampling.exceptions` soon
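A minimal sketch of what the suggested module might contain; the module path and the `RuntimeError` base are assumptions, and in this diff `SamplingError` still lives in `arraystep`:

```python
# pymc3/exceptions.py (hypothetical)

__all__ = ['SamplingError']


class SamplingError(RuntimeError):
    """Raised when a sampler hits an unrecoverable numerical problem."""
```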