Skip to content

Commit 0974ecc

Browse files
authored
Merge pull request #2251 from pymc-devs/zero_inflated_binomial
Zero inflated binomial
2 parents fc20e31 + 5184768 commit 0974ecc

File tree

6 files changed

+120
-18
lines changed

6 files changed

+120
-18
lines changed

pymc3/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from .external import *
77
from .glm import *
88
from . import gp
9-
from .math import logsumexp, logit, invlogit, expand_packed_triangular, probit, invprobit
9+
from .math import logaddexp, logsumexp, logit, invlogit, expand_packed_triangular, probit, invprobit
1010
from .model import *
1111
from .stats import *
1212
from .sampling import *

pymc3/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from .discrete import Constant
3838
from .discrete import ZeroInflatedPoisson
3939
from .discrete import ZeroInflatedNegativeBinomial
40+
from .discrete import ZeroInflatedBinomial
4041
from .discrete import DiscreteUniform
4142
from .discrete import Geometric
4243
from .discrete import Categorical
@@ -106,6 +107,7 @@
106107
'Constant',
107108
'ZeroInflatedPoisson',
108109
'ZeroInflatedNegativeBinomial',
110+
'ZeroInflatedBinomial',
109111
'DiscreteUniform',
110112
'Geometric',
111113
'Categorical',

pymc3/distributions/discrete.py

Lines changed: 104 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,11 @@
88
from .dist_math import bound, factln, binomln, betaln, logpow
99
from .distribution import Discrete, draw_values, generate_samples, reshape_sampled
1010
from pymc3.math import tround
11+
from ..math import logaddexp
1112

1213
__all__ = ['Binomial', 'BetaBinomial', 'Bernoulli', 'DiscreteWeibull',
1314
'Poisson', 'NegativeBinomial', 'ConstantDist', 'Constant',
14-
'ZeroInflatedPoisson', 'ZeroInflatedNegativeBinomial',
15+
'ZeroInflatedPoisson', 'ZeroInflatedBinomial', 'ZeroInflatedNegativeBinomial',
1516
'DiscreteUniform', 'Geometric', 'Categorical']
1617

1718

@@ -593,7 +594,7 @@ class ZeroInflatedPoisson(Discrete):
593594
594595
.. math::
595596
596-
f(x \mid \theta, \psi) = \left\{ \begin{array}{l}
597+
f(x \mid \psi, \theta) = \left\{ \begin{array}{l}
597598
(1-\psi) + \psi e^{-\theta}, \text{if } x = 0 \\
598599
\psi \frac{e^{-\theta}\theta^x}{x!}, \text{if } x=1,2,3,\ldots
599600
\end{array} \right.
@@ -606,15 +607,14 @@ class ZeroInflatedPoisson(Discrete):
606607
607608
Parameters
608609
----------
610+
psi : float
611+
Expected proportion of Poisson variates (0 < psi < 1)
609612
theta : float
610613
Expected number of occurrences during the given interval
611614
(theta >= 0).
612-
psi : float
613-
Expected proportion of Poisson variates (0 < psi < 1)
614-
615615
"""
616616

617-
def __init__(self, theta, psi, *args, **kwargs):
617+
def __init__(self, psi, theta, *args, **kwargs):
618618
super(ZeroInflatedPoisson, self).__init__(*args, **kwargs)
619619
self.theta = theta = tt.as_tensor_variable(theta)
620620
self.psi = psi = tt.as_tensor_variable(psi)
@@ -630,9 +630,17 @@ def random(self, point=None, size=None, repeat=None):
630630
return reshape_sampled(sampled, size, self.shape)
631631

632632
def logp(self, value):
633-
return tt.switch(value > 0,
634-
tt.log(self.psi) + self.pois.logp(value),
635-
tt.log((1. - self.psi) + self.psi * tt.exp(-self.theta)))
633+
psi = self.psi
634+
theta = self.theta
635+
636+
logp_val = tt.switch(value > 0,
637+
tt.log(psi) + self.pois.logp(value),
638+
logaddexp(tt.log1p(-psi), tt.log(psi) - theta))
639+
640+
return bound(logp_val,
641+
0 <= value,
642+
0 <= psi, psi <= 1,
643+
0 <= theta)
636644

637645
def _repr_latex_(self, name=None, dist=None):
638646
if dist is None:
@@ -644,6 +652,76 @@ def _repr_latex_(self, name=None, dist=None):
644652
get_variable_name(psi))
645653

646654

655+
class ZeroInflatedBinomial(Discrete):
656+
R"""
657+
Zero-inflated Binomial log-likelihood.
658+
659+
.. math::
660+
661+
f(x \mid \psi, n, p) = \left\{ \begin{array}{l}
662+
(1-\psi) + \psi (1-p)^{n}, \text{if } x = 0 \\
663+
\psi {n \choose x} p^x (1-p)^{n-x}, \text{if } x=1,2,3,\ldots,n
664+
\end{array} \right.
665+
666+
======== ==========================
667+
Support :math:`x \in \mathbb{N}_0`
668+
Mean :math:`(1 - \psi) n p`
669+
Variance :math:`(1-\psi) n p [1 - p(1 - \psi n)].`
670+
======== ==========================
671+
672+
Parameters
673+
----------
674+
psi : float
675+
Expected proportion of Binomial variates (0 < psi < 1)
676+
n : int
677+
Number of Bernoulli trials (n >= 0).
678+
p : float
679+
Probability of success in each trial (0 < p < 1).
680+
681+
"""
682+
683+
def __init__(self, psi, n, p, *args, **kwargs):
684+
super(ZeroInflatedBinomial, self).__init__(*args, **kwargs)
685+
self.n = n = tt.as_tensor_variable(n)
686+
self.p = p = tt.as_tensor_variable(p)
687+
self.psi = psi = tt.as_tensor_variable(psi)
688+
self.bin = Binomial.dist(n, p)
689+
self.mode = self.bin.mode
690+
691+
def random(self, point=None, size=None, repeat=None):
692+
n, p, psi = draw_values([self.n, self.p, self.psi], point=point)
693+
g = generate_samples(stats.binom.rvs, n, p,
694+
dist_shape=self.shape,
695+
size=size)
696+
sampled = g * (np.random.random(np.squeeze(g.shape)) < psi)
697+
return reshape_sampled(sampled, size, self.shape)
698+
699+
def logp(self, value):
700+
psi = self.psi
701+
p = self.p
702+
n = self.n
703+
704+
logp_val = tt.switch(value > 0,
705+
tt.log(psi) + self.bin.logp(value),
706+
logaddexp(tt.log1p(-psi), tt.log(psi) + n * tt.log1p(-p)))
707+
708+
return bound(logp_val,
709+
0 <= value, value <= n,
710+
0 <= psi, psi <= 1,
711+
0 <= p, p <= 1)
712+
713+
def _repr_latex_(self, name=None, dist=None):
714+
if dist is None:
715+
dist = self
716+
n = dist.n
717+
p = dist.p
718+
psi = dist.psi
719+
return r'${} \sim \text{{ZeroInflatedBinomial}}(\mathit{{n}}={}, \mathit{{p}}={}, \mathit{{psi}}={})$'.format(name,
720+
get_variable_name(n),
721+
get_variable_name(p),
722+
get_variable_name(psi))
723+
724+
647725
class ZeroInflatedNegativeBinomial(Discrete):
648726
R"""
649727
Zero-Inflated Negative binomial log-likelihood.
@@ -654,7 +732,7 @@ class ZeroInflatedNegativeBinomial(Discrete):
654732
655733
.. math::
656734
657-
f(x \mid \mu, \alpha, \psi) = \left\{ \begin{array}{l}
735+
f(x \mid \psi, \mu, \alpha) = \left\{ \begin{array}{l}
658736
(1-\psi) + \psi \left (\frac{\alpha}{\alpha+\mu} \right) ^\alpha, \text{if } x = 0 \\
659737
\psi \frac{\Gamma(x+\alpha)}{x! \Gamma(\alpha)} \left (\frac{\alpha}{\mu+\alpha} \right)^\alpha \left( \frac{\mu}{\mu+\alpha} \right)^x, \text{if } x=1,2,3,\ldots
660738
\end{array} \right.
@@ -667,15 +745,16 @@ class ZeroInflatedNegativeBinomial(Discrete):
667745
668746
Parameters
669747
----------
748+
psi : float
749+
Expected proportion of NegativeBinomial variates (0 < psi < 1)
670750
mu : float
671751
Poission distribution parameter (mu > 0).
672752
alpha : float
673753
Gamma distribution parameter (alpha > 0).
674-
psi : float
675-
Expected proportion of NegativeBinomial variates (0 < psi < 1)
754+
676755
"""
677756

678-
def __init__(self, mu, alpha, psi, *args, **kwargs):
757+
def __init__(self, psi, mu, alpha, *args, **kwargs):
679758
super(ZeroInflatedNegativeBinomial, self).__init__(*args, **kwargs)
680759
self.mu = mu = tt.as_tensor_variable(mu)
681760
self.alpha = alpha = tt.as_tensor_variable(alpha)
@@ -694,9 +773,18 @@ def random(self, point=None, size=None, repeat=None):
694773
return reshape_sampled(sampled, size, self.shape)
695774

696775
def logp(self, value):
697-
return tt.switch(value > 0,
698-
tt.log(self.psi) + self.nb.logp(value),
699-
tt.log((1. - self.psi) + self.psi * (self.alpha / (self.alpha + self.mu))**self.alpha))
776+
alpha = self.alpha
777+
mu = self.mu
778+
psi = self.psi
779+
780+
logp_val = tt.switch(value > 0,
781+
tt.log(psi) + self.nb.logp(value),
782+
logaddexp(tt.log1p(-psi), tt.log(psi) + alpha * (tt.log(alpha) - tt.log(alpha + mu))))
783+
784+
return bound(logp_val,
785+
0 <= value,
786+
0 <= psi, psi <= 1,
787+
mu > 0, alpha > 0)
700788

701789
def _repr_latex_(self, name=None, dist=None):
702790
if dist is None:

pymc3/math.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,11 @@ def logsumexp(x, axis=None):
3131
x_max = tt.max(x, axis=axis, keepdims=True)
3232
return tt.log(tt.sum(tt.exp(x - x_max), axis=axis, keepdims=True)) + x_max
3333

34+
def logaddexp(a, b):
35+
diff = b - a
36+
return tt.switch(diff > 0,
37+
b + tt.log1p(tt.exp(-diff)),
38+
a + tt.log1p(tt.exp(diff)))
3439

3540
def invlogit(x, eps=sys.float_info.epsilon):
3641
return (1. - 2. * eps) / (1. + tt.exp(-x)) + eps

pymc3/tests/test_distributions.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
NegativeBinomial, Geometric, Exponential, ExGaussian, Normal,
1515
Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform,
1616
Bound, Uniform, Triangular, Binomial, SkewNormal, DiscreteWeibull, Gumbel,
17-
Interpolated)
17+
Interpolated, ZeroInflatedBinomial)
1818
from ..distributions import continuous
1919
from pymc3.theanof import floatX
2020
from numpy import array, inf, log, exp
@@ -591,6 +591,10 @@ def test_zeroinflatednegativebinomial(self):
591591
self.checkd(ZeroInflatedNegativeBinomial, Nat,
592592
{'mu': Rplusbig, 'alpha': Rplusbig, 'psi': Unit})
593593

594+
def test_zeroinflatedbinomial(self):
595+
self.checkd(ZeroInflatedBinomial, Nat,
596+
{'n': NatSmall, 'p': Unit, 'psi': Unit})
597+
594598
@pytest.mark.parametrize('n', [1, 2, 3])
595599
def test_mvnormal(self, n):
596600
self.pymc3_matches_scipy(MvNormal, RealMatrix(5, n),

pymc3/tests/test_distributions_random.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,9 @@ class TestZeroInflatedNegativeBinomial(BaseTestCases.BaseTestCase):
338338
distribution = pm.ZeroInflatedNegativeBinomial
339339
params = {'mu': 1., 'alpha': 1., 'psi': 0.3}
340340

341+
class TestZeroInflatedBinomial(BaseTestCases.BaseTestCase):
342+
distribution = pm.ZeroInflatedBinomial
343+
params = {'n': 10, 'p': 0.6, 'psi': 0.3}
341344

342345
class TestDiscreteUniform(BaseTestCases.BaseTestCase):
343346
distribution = pm.DiscreteUniform

0 commit comments

Comments
 (0)