-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Small NUTS refactor to improve error handling #2215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this! I would probably change the warning messages to suggest solutions in addition to describing the problem, but would be happy to do that in a followup PR so we can quibble about the exact wording.
pymc3/step_methods/hmc/nuts.py
Outdated
|
||
def astep(self, q0): | ||
p0 = self.potential.random() | ||
v0 = self.compute_velocity(p0) | ||
start_energy = self.compute_energy(q0, p0) | ||
if not np.isfinite(start_energy): | ||
raise ValueError('The initial energy is inf or nan.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just print repr
of start_energy
here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
that's better, yes.
pymc3/step_methods/arraystep.py
Outdated
'Competence', 'SamplingError'] | ||
|
||
|
||
class SamplingError(RuntimeError): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍 we could probably do with either pymc3.exceptions
or pymc3.sampling.exceptions
soon
@@ -279,7 +233,7 @@ def check_trace(self, strace): | |||
"left, right, p_sum, proposal, log_size, accept_sum, n_proposals") | |||
|
|||
|
|||
class Tree(object): | |||
class _Tree(object): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
pymc3/step_methods/hmc/nuts.py
Outdated
error_msg = "LinAlgError during leapfrog step." | ||
except ValueError as error: | ||
# Raised by many scipy.linalg functions | ||
if error.args[0].lower() == 'array must not contain infs or nans': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is a little dangerous: it is brittle to changes in scipy.linalg
, and will throw an IndexError
if there is no message. Can you check directly if the array contains infs or nans?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't like it either, but couldn't think of a better solution. We don't have access to the actual array that led to the problem, it could be hidden somewhere in the theano graph. But I'll add a check that error.args
is long enough.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And I'll add a test for this, then we should at least notice if this changes.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks -- would be super frustrating/confusing if this threw an IndexError
:)
@ColCarroll If you have suggestions for better error messages, I think we can do that in this PR. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seemed like a good quiz of diagnosing problems -- not sure of any of my answers!
# Raised by many scipy.linalg functions | ||
scipy_msg = "array must not contain infs or nans" | ||
if len(err.args) > 0 and scipy_msg in err.args[0].lower(): | ||
error_msg = "Infs or nans in scipy.linalg during leapfrog step." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this would probably have to be a discontinuity in the pdf?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A got a lot of those when computing gradients of mvnormal with numerically indefinite covariance matrices. I can't think of any good advice if this happens, other than check anything that involves scipy.linalg
.
pymc3/step_methods/hmc/nuts.py
Outdated
tree = Subtree(right, right, right.p, proposal, log_size, p_accept, 1) | ||
return tree, False, False | ||
else: | ||
error_msg = "Bad energy after leapfrog step." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it seems like the advice after this is to make the step_scale
smaller (and the message should report the current value)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The step size should be adapted already when you see this message. Increasing target_accept
is I think the right thing (this will change the step size adaptation to use smaller step sizes). But otherwise I agree.
pymc3/step_methods/hmc/nuts.py
Outdated
else: | ||
self._divs_after_tune.append((msg, error, point)) | ||
if self._on_error == 'raise': | ||
err = SamplingError('Divergence after tuning: ' + msg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this suggests reparametrizing or smaller steps?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yep, but larger target_accept
again instead of step size.
I think this is ready to be merged. The coverage dropped a bit after the last commit, but I can't think of a good reason for this. I guess that's some issue with coveralls again... |
Thanks for this @aseyboldt! |
This does a couple of things:
on_error
option to NUTS, that controls how it should handle divergences.NutsReport
objectSome tests are still missing.
CC @AustinRochford