
Small NUTS refactor to improve error handling #2215

Merged
merged 5 commits into from
May 24, 2017

Conversation

aseyboldt
Member

This does a couple of things:

  • Add an on_error option to NUTS that controls how divergences are handled.
  • Move the warnings about bad traces into a NutsReport object.
  • Catch linalg errors in NUTS and convert them to divergences.

Some tests are still missing.
CC @AustinRochford

@aseyboldt aseyboldt added the WIP label May 23, 2017
Member

@ColCarroll ColCarroll left a comment

I like this! I would probably change the warning messages to suggest solutions in addition to describing the problem, but would be happy to do that in a followup PR so we can quibble about the exact wording.


    def astep(self, q0):
        p0 = self.potential.random()
        v0 = self.compute_velocity(p0)
        start_energy = self.compute_energy(q0, p0)
        if not np.isfinite(start_energy):
            raise ValueError('The initial energy is inf or nan.')
Member

just print repr of start_energy here?

Member Author

that's better, yes.
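
A minimal sketch of what that suggestion could look like, assuming a standalone helper (`check_start_energy` is a hypothetical name, not part of the PR) and using `math.isfinite` in place of `np.isfinite` so the snippet has no dependencies:

```python
import math


def check_start_energy(start_energy):
    """Raise if the starting energy is not finite, showing the actual value."""
    if not math.isfinite(start_energy):
        # %r reports the offending value, so inf, -inf and nan
        # can be told apart in the error message.
        raise ValueError('The initial energy is not finite: %r' % (start_energy,))
    return start_energy
```

With the value in the message, a report such as "The initial energy is not finite: nan" already tells the user which of the failure cases they hit.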

'Competence', 'SamplingError']


class SamplingError(RuntimeError):
Member

👍 we could probably do with either pymc3.exceptions or pymc3.sampling.exceptions soon

@@ -279,7 +233,7 @@ def check_trace(self, strace):
     "left, right, p_sum, proposal, log_size, accept_sum, n_proposals")
 
 
-class Tree(object):
+class _Tree(object):
Member

👍

        error_msg = "LinAlgError during leapfrog step."
    except ValueError as error:
        # Raised by many scipy.linalg functions
        if error.args[0].lower() == 'array must not contain infs or nans':
Member

this is a little dangerous: it is brittle to changes in scipy.linalg, and will throw an IndexError if there is no message. Can you check directly if the array contains infs or nans?

Member Author

I don't like it either, but couldn't think of a better solution. We don't have access to the actual array that led to the problem, it could be hidden somewhere in the theano graph. But I'll add a check that error.args is long enough.

Member Author

And I'll add a test for this, then we should at least notice if this changes.

Member

thanks -- would be super frustrating/confusing if this threw an IndexError :)
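
The guard discussed in this thread could be sketched roughly as follows. `is_inf_nan_error` and `SCIPY_MSG` are hypothetical names introduced for illustration; the matched string is taken from the snippet under review, and whether scipy keeps that exact wording is an assumption that the proposed test would have to pin down.

```python
# Message fragment quoted in the PR; assumed to match what scipy.linalg raises.
SCIPY_MSG = "array must not contain infs or nans"


def is_inf_nan_error(error):
    """Return True if `error` looks like the scipy.linalg inf/nan complaint."""
    args = error.args
    # Exceptions raised without a message have empty args, and the first
    # argument is not guaranteed to be a string, so check both before
    # indexing or calling .lower().
    if not args or not isinstance(args[0], str):
        return False
    return SCIPY_MSG in args[0].lower()
```

Using a substring match on the lower-cased message is still brittle, but it can no longer fail with an IndexError or AttributeError of its own.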

@aseyboldt
Member Author

@ColCarroll If you have suggestions for better error messages, I think we can do that in this PR.

Member

@ColCarroll ColCarroll left a comment

this seemed like a good quiz of diagnosing problems -- not sure of any of my answers!

# Raised by many scipy.linalg functions
scipy_msg = "array must not contain infs or nans"
if len(err.args) > 0 and scipy_msg in err.args[0].lower():
    error_msg = "Infs or nans in scipy.linalg during leapfrog step."
Member

this would probably have to be a discontinuity in the pdf?

Member Author

I got a lot of those when computing gradients of mvnormal with numerically indefinite covariance matrices. I can't think of any good advice if this happens, other than to check anything that involves scipy.linalg.

    tree = Subtree(right, right, right.p, proposal, log_size, p_accept, 1)
    return tree, False, False
else:
    error_msg = "Bad energy after leapfrog step."
Member

it seems like the advice after this is to make the step_scale smaller (and the message should report the current value)

Member Author

The step size should already be adapted by the time you see this message. I think increasing target_accept is the right thing to do (this makes the step size adaptation use smaller step sizes). But otherwise I agree.

else:
    self._divs_after_tune.append((msg, error, point))
    if self._on_error == 'raise':
        err = SamplingError('Divergence after tuning: ' + msg)
Member

this suggests reparametrizing or smaller steps?

Member Author

yep, but larger target_accept again instead of step size.
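
To make the behaviour under discussion concrete, here is a rough, self-contained sketch of recording divergences and raising only when `on_error == 'raise'`. The `DivergenceLog` class, its `record` method, the `tune` flag, the `'summary'` default and the advice text are all assumptions made for illustration; only `SamplingError`, `_divs_after_tune`, `_on_error` and the `'Divergence after tuning: '` prefix come from the snippet above.

```python
class SamplingError(RuntimeError):
    pass


class DivergenceLog(object):
    """Hypothetical stand-in for the divergence bookkeeping in the PR."""

    def __init__(self, on_error='summary'):
        self._on_error = on_error
        self._divs_tune = []
        self._divs_after_tune = []

    def record(self, msg, error=None, point=None, tune=False):
        if tune:
            # Divergences during tuning are only recorded, never raised.
            self._divs_tune.append((msg, error, point))
            return
        self._divs_after_tune.append((msg, error, point))
        if self._on_error == 'raise':
            err = SamplingError(
                'Divergence after tuning: ' + msg +
                ' Increase `target_accept` or reparameterize.')
            # Chain the original exception (if any) to keep its traceback.
            raise err from error
```

Following the thread, the hint points at `target_accept` rather than at the step size, because the step size has already been adapted when this path is reached.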

@aseyboldt
Member Author

I think this is ready to be merged. The coverage dropped a bit after the last commit, but I can't think of a good reason for this. I guess that's some issue with coveralls again...

@aseyboldt aseyboldt removed the WIP label May 24, 2017
@ColCarroll ColCarroll merged commit c306452 into pymc-devs:master May 24, 2017
@aseyboldt aseyboldt deleted the nuts-report branch May 24, 2017 21:15
@AustinRochford
Member

Thanks for this @aseyboldt!
