|
1 | 1 | from collections import namedtuple
|
2 | 2 | import warnings
|
3 | 3 |
|
4 |
| -from ..arraystep import Competence, SamplingError |
| 4 | +from ..arraystep import Competence |
| 5 | +from pymc3.exceptions import SamplingError |
5 | 6 | from .base_hmc import BaseHMC
|
6 | 7 | from pymc3.theanof import floatX
|
7 | 8 | from pymc3.vartypes import continuous_types
|
@@ -170,7 +171,8 @@ def astep(self, q0):
|
170 | 171 | v0 = self.compute_velocity(p0)
|
171 | 172 | start_energy = self.compute_energy(q0, p0)
|
172 | 173 | if not np.isfinite(start_energy):
|
173 |
| - raise ValueError('The initial energy is inf or nan.') |
| 174 | + raise ValueError('Bad initial energy: %s. The model ' |
| 175 | + 'might be misspecified.' % start_energy) |
174 | 176 |
|
175 | 177 | if not self.adapt_step_size:
|
176 | 178 | step_size = self.step_size
|
@@ -310,12 +312,15 @@ def _single_step(self, left, epsilon):
|
310 | 312 | """Perform a leapfrog step and handle error cases."""
|
311 | 313 | try:
|
312 | 314 | right = self.leapfrog(left.q, left.p, left.q_grad, epsilon)
|
313 |
| - except linalg.LinalgError as error: |
| 315 | + except linalg.LinAlgError as err: |
314 | 316 | error_msg = "LinAlgError during leapfrog step."
|
315 |
| - except ValueError as error: |
| 317 | + error = err |
| 318 | + except ValueError as err: |
316 | 319 | # Raised by many scipy.linalg functions
|
317 |
| - if error.args[0].lower() == 'array must not contain infs or nans': |
| 320 | + scipy_msg = "array must not contain infs or nans" |
| 321 | + if len(err.args) > 0 and scipy_msg in err.args[0].lower(): |
318 | 322 | error_msg = "Infs or nans in scipy.linalg during leapfrog step."
|
| 323 | + error = err |
319 | 324 | else:
|
320 | 325 | raise
|
321 | 326 | else:
|
@@ -396,7 +401,7 @@ def _add_divergence(self, tuning, msg, error=None, point=None):
|
396 | 401 | if tuning:
|
397 | 402 | self._divs_tune.append((msg, error, point))
|
398 | 403 | else:
|
399 |
| - self._divs_after_tune((msg, error, point)) |
| 404 | + self._divs_after_tune.append((msg, error, point)) |
400 | 405 | if self._on_error == 'raise':
|
401 | 406 | err = SamplingError('Divergence after tuning: ' + msg)
|
402 | 407 | six.raise_from(err, error)
|
|
0 commit comments