Skip to content

Commit c3fcc96

Browse files
committed
Fix cutoff value in log1mexp and redundant reimplementation in Exponential.logcdf()
1 parent 3cfee77 commit c3fcc96

File tree

2 files changed

+5
-11
lines changed

2 files changed

+5
-11
lines changed

pymc3/distributions/continuous.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
)
4646
from pymc3.distributions.distribution import Continuous, draw_values, generate_samples
4747
from pymc3.distributions.special import log_i0
48-
from pymc3.math import invlogit, logdiffexp, logit
48+
from pymc3.math import invlogit, log1mexp, logdiffexp, logit
4949
from pymc3.theanof import floatX
5050

5151
__all__ = [
@@ -1513,12 +1513,6 @@ def logcdf(self, value):
15131513
Compute the log of cumulative distribution function for the Exponential distribution
15141514
at the specified value.
15151515
1516-
References
1517-
----------
1518-
.. [Machler2012] Martin Mächler (2012).
1519-
"Accurately computing :math:`\log(1-\exp(-\mid a \mid))` Assessed by the Rmpfr
1520-
package"
1521-
15221516
Parameters
15231517
----------
15241518
value: numeric
@@ -1533,9 +1527,9 @@ def logcdf(self, value):
15331527
lam = self.lam
15341528
a = lam * value
15351529
return tt.switch(
1536-
tt.le(value, 0.0),
1530+
tt.or_(tt.le(value, 0.0), tt.le(lam, 0)),
15371531
-np.inf,
1538-
tt.switch(tt.le(a, tt.log(2.0)), tt.log(-tt.expm1(-a)), tt.log1p(-tt.exp(-a))),
1532+
log1mexp(a),
15391533
)
15401534

15411535

pymc3/math.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def log1mexp(x):
226226
For details, see
227227
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
228228
"""
229-
return tt.switch(tt.lt(x, 0.683), tt.log(-tt.expm1(-x)), tt.log1p(-tt.exp(-x)))
229+
return tt.switch(tt.lt(x, 0.693147), tt.log(-tt.expm1(-x)), tt.log1p(-tt.exp(-x)))
230230

231231

232232
def log1mexp_numpy(x):
@@ -235,7 +235,7 @@ def log1mexp_numpy(x):
235235
For details, see
236236
https://cran.r-project.org/web/packages/Rmpfr/vignettes/log1mexp-note.pdf
237237
"""
238-
return np.where(x < 0.683, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))
238+
return np.where(x < 0.693147, np.log(-np.expm1(-x)), np.log1p(-np.exp(-x)))
239239

240240

241241
def flatten_list(tensors):

0 commit comments

Comments
 (0)