Skip to content

Commit 9a830d7

Browse files
committed
Implement review suggestions
1 parent b4bade2 commit 9a830d7

File tree

4 files changed

+18
-12
lines changed

4 files changed

+18
-12
lines changed

pymc3/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .tuning import *
1616
from .variational import *
1717
from .vartypes import *
18+
from .exceptions import *
1819
from . import sampling
1920

2021
from .debug import *

pymc3/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
__all__ = ['SamplingError']
2+
3+
4+
class SamplingError(RuntimeError):
5+
pass

pymc3/step_methods/arraystep.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,7 @@
77
from enum import IntEnum, unique
88

99
__all__ = [
10-
'ArrayStep', 'ArrayStepShared', 'metrop_select',
11-
'Competence', 'SamplingError']
12-
13-
14-
class SamplingError(RuntimeError):
15-
pass
10+
'ArrayStep', 'ArrayStepShared', 'metrop_select', 'Competence']
1611

1712

1813
@unique

pymc3/step_methods/hmc/nuts.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from collections import namedtuple
22
import warnings
33

4-
from ..arraystep import Competence, SamplingError
4+
from ..arraystep import Competence
5+
from pymc3.exceptions import SamplingError
56
from .base_hmc import BaseHMC
67
from pymc3.theanof import floatX
78
from pymc3.vartypes import continuous_types
@@ -170,7 +171,8 @@ def astep(self, q0):
170171
v0 = self.compute_velocity(p0)
171172
start_energy = self.compute_energy(q0, p0)
172173
if not np.isfinite(start_energy):
173-
raise ValueError('The initial energy is inf or nan.')
174+
raise ValueError('Bad initial energy: %s. The model '
175+
'might be misspecified.' % start_energy)
174176

175177
if not self.adapt_step_size:
176178
step_size = self.step_size
@@ -310,12 +312,15 @@ def _single_step(self, left, epsilon):
310312
"""Perform a leapfrog step and handle error cases."""
311313
try:
312314
right = self.leapfrog(left.q, left.p, left.q_grad, epsilon)
313-
except linalg.LinalgError as error:
315+
except linalg.LinAlgError as err:
314316
error_msg = "LinAlgError during leapfrog step."
315-
except ValueError as error:
317+
error = err
318+
except ValueError as err:
316319
# Raised by many scipy.linalg functions
317-
if error.args[0].lower() == 'array must not contain infs or nans':
320+
scipy_msg = "array must not contain infs or nans"
321+
if len(err.args) > 0 and scipy_msg in err.args[0].lower():
318322
error_msg = "Infs or nans in scipy.linalg during leapfrog step."
323+
error = err
319324
else:
320325
raise
321326
else:
@@ -396,7 +401,7 @@ def _add_divergence(self, tuning, msg, error=None, point=None):
396401
if tuning:
397402
self._divs_tune.append((msg, error, point))
398403
else:
399-
self._divs_after_tune((msg, error, point))
404+
self._divs_after_tune.append((msg, error, point))
400405
if self._on_error == 'raise':
401406
err = SamplingError('Divergence after tuning: ' + msg)
402407
six.raise_from(err, error)

0 commit comments

Comments
 (0)