Skip to content

Commit b7daa3e

Browse files
committed
Fixing logp methods of zero-inflated distributions
1 parent beb489d commit b7daa3e

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

pymc3/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .external import *
77
from .glm import *
88
from . import gp
9-
from .math import logsumexp, logit, invlogit, expand_packed_triangular, probit, invprobit
9+
from .math import logaddexp, logsumexp, logit, invlogit, expand_packed_triangular, probit, invprobit
1010
from .model import *
1111
from .stats import *
1212
from .sampling import *

pymc3/distributions/discrete.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .dist_math import bound, factln, binomln, betaln, logpow
99
from .distribution import Discrete, draw_values, generate_samples, reshape_sampled
1010
from pymc3.math import tround
11-
from ..math import logsumexp
11+
from ..math import logaddexp
1212

1313
__all__ = ['Binomial', 'BetaBinomial', 'Bernoulli', 'DiscreteWeibull',
1414
'Poisson', 'NegativeBinomial', 'ConstantDist', 'Constant',
@@ -634,10 +634,10 @@ def logp(self, value):
634634
theta = self.theta
635635

636636
logp_val = tt.switch(value > 0,
637-
logsumexp(tt.log(psi) + self.pois.logp(value)),
638-
logsumexp(tt.log((1. - psi) + psi * tt.exp(-theta))))
637+
tt.log(psi) + self.pois.logp(value),
638+
logaddexp(tt.log(-psi), tt.log(psi) - theta))
639639

640-
return bound(logp_val.sum(),
640+
return bound(logp_val,
641641
0 <= value,
642642
0 <= psi, psi <= 1,
643643
0 <= theta)
@@ -702,10 +702,10 @@ def logp(self, value):
702702
n = self.n
703703

704704
logp_val = tt.switch(value > 0,
705-
logsumexp(tt.log(psi) + self.bin.logp(value)),
706-
logsumexp(tt.log((1. - psi) + psi * tt.pow(1 - p, n))))
705+
tt.log(psi) + self.bin.logp(value),
706+
logsumexp(tt.log1p(-psi), tt.log(psi) + n * tt.log1p(-p)))
707707

708-
return bound(logp_val.sum(),
708+
return bound(logp_val,
709709
0 <= value, value <= n,
710710
0 <= psi, psi <= 1,
711711
0 <= p, p <= 1)
@@ -778,10 +778,10 @@ def logp(self, value):
778778
psi = self.psi
779779

780780
logp_val = tt.switch(value > 0,
781-
logsumexp(tt.log(psi) + self.nb.logp(value)),
782-
logsumexp(tt.log((1. - psi) + psi * (alpha / (alpha + mu))**alpha)))
781+
tt.log(psi) + self.nb.logp(value),
782+
logaddexp(tt.log1p(-psi), tt.log(psi) + alpha * (tt.log(alpha) - tt.log(alpha + mu))))
783783

784-
return bound(logp_val.sum(),
784+
return bound(logp_val,
785785
0 <= value,
786786
0 <= psi, psi <= 1,
787787
mu > 0, alpha > 0)

0 commit comments

Comments
 (0)