|
8 | 8 | from .dist_math import bound, factln, binomln, betaln, logpow
|
9 | 9 | from .distribution import Discrete, draw_values, generate_samples, reshape_sampled
|
10 | 10 | from pymc3.math import tround
|
11 |
| -from ..math import logsumexp |
| 11 | +from ..math import logaddexp |
12 | 12 |
|
13 | 13 | __all__ = ['Binomial', 'BetaBinomial', 'Bernoulli', 'DiscreteWeibull',
|
14 | 14 | 'Poisson', 'NegativeBinomial', 'ConstantDist', 'Constant',
|
@@ -634,10 +634,10 @@ def logp(self, value):
|
634 | 634 | theta = self.theta
|
635 | 635 |
|
636 | 636 | logp_val = tt.switch(value > 0,
|
637 |
| - logsumexp(tt.log(psi) + self.pois.logp(value)), |
638 |
| - logsumexp(tt.log((1. - psi) + psi * tt.exp(-theta)))) |
| 637 | + tt.log(psi) + self.pois.logp(value), |
| 638 | + logaddexp(tt.log(-psi), tt.log(psi) - theta)) |
639 | 639 |
|
640 |
| - return bound(logp_val.sum(), |
| 640 | + return bound(logp_val, |
641 | 641 | 0 <= value,
|
642 | 642 | 0 <= psi, psi <= 1,
|
643 | 643 | 0 <= theta)
|
@@ -702,10 +702,10 @@ def logp(self, value):
|
702 | 702 | n = self.n
|
703 | 703 |
|
704 | 704 | logp_val = tt.switch(value > 0,
|
705 |
| - logsumexp(tt.log(psi) + self.bin.logp(value)), |
706 |
| - logsumexp(tt.log((1. - psi) + psi * tt.pow(1 - p, n)))) |
| 705 | + tt.log(psi) + self.bin.logp(value), |
| 706 | + logsumexp(tt.log1p(-psi), tt.log(psi) + n * tt.log1p(-p))) |
707 | 707 |
|
708 |
| - return bound(logp_val.sum(), |
| 708 | + return bound(logp_val, |
709 | 709 | 0 <= value, value <= n,
|
710 | 710 | 0 <= psi, psi <= 1,
|
711 | 711 | 0 <= p, p <= 1)
|
@@ -778,10 +778,10 @@ def logp(self, value):
|
778 | 778 | psi = self.psi
|
779 | 779 |
|
780 | 780 | logp_val = tt.switch(value > 0,
|
781 |
| - logsumexp(tt.log(psi) + self.nb.logp(value)), |
782 |
| - logsumexp(tt.log((1. - psi) + psi * (alpha / (alpha + mu))**alpha))) |
| 781 | + tt.log(psi) + self.nb.logp(value), |
| 782 | + logaddexp(tt.log1p(-psi), tt.log(psi) + alpha * (tt.log(alpha) - tt.log(alpha + mu)))) |
783 | 783 |
|
784 |
| - return bound(logp_val.sum(), |
| 784 | + return bound(logp_val, |
785 | 785 | 0 <= value,
|
786 | 786 | 0 <= psi, psi <= 1,
|
787 | 787 | mu > 0, alpha > 0)
|
|
0 commit comments