Skip to content

Commit 7a7405f

Browse files
committed
Refactor ZeroInflatedBinomial
1 parent ce252fa commit 7a7405f

File tree

3 files changed

+96
-42
lines changed

3 files changed

+96
-42
lines changed

pymc3/distributions/discrete.py

Lines changed: 39 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1343,6 +1343,21 @@ def logcdf(value, psi, theta):
13431343
)
13441344

13451345

1346+
class ZeroInflatedBinomialRV(RandomVariable):
    """Random variable op that draws from a zero-inflated binomial distribution.

    Parameters are ordered ``(psi, n, p)``. Each is declared as a scalar
    parameter (``ndims_params = [0, 0, 0]``) and every draw is a scalar
    (``ndim_supp = 0``) of dtype ``int64``.
    """

    name = "zero_inflated_binomial"
    ndim_supp = 0
    ndims_params = [0, 0, 0]
    dtype = "int64"
    _print_name = ("ZeroInflatedBinom", "\\operatorname{ZeroInflatedBinom}")

    @classmethod
    def rng_fn(cls, rng, psi, n, p, size):
        """Draw zero-inflated binomial samples.

        A binomial draw is kept when an independent uniform draw falls below
        ``psi`` and is replaced by zero otherwise.

        Parameters
        ----------
        rng
            Generator exposing ``binomial`` and ``random`` methods that accept
            a ``size`` keyword.
        psi
            Probability of keeping the binomial draw.
        n, p
            Binomial number of trials and success probability.
        size
            Requested output shape, or ``None`` to follow the parameters.

        Returns
        -------
        The product of the binomial draws and the boolean keep-mask.
        """
        if size is None:
            # With size=None, rng.random() yields a single scalar, so every
            # element of an array-valued parameter set would share one
            # inflation draw. Use the broadcast parameter shape instead so each
            # element gets its own draw. Scalar parameters keep size=None.
            broadcast_shape = np.broadcast(psi, n, p).shape
            if broadcast_shape:
                size = broadcast_shape

        # Draw order (binomial first, then uniform) is kept as-is so seeded
        # results with an explicit size are unchanged.
        return rng.binomial(n=n, p=p, size=size) * (rng.random(size=size) < psi)


zero_inflated_binomial = ZeroInflatedBinomialRV()
1359+
1360+
13461361
class ZeroInflatedBinomial(Discrete):
13471362
R"""
13481363
Zero-inflated Binomial log-likelihood.
@@ -1395,37 +1410,16 @@ class ZeroInflatedBinomial(Discrete):
13951410
13961411
"""
13971412

1398-
def __init__(self, psi, n, p, *args, **kwargs):
1399-
super().__init__(*args, **kwargs)
1400-
self.n = n = at.as_tensor_variable(intX(n))
1401-
self.p = p = at.as_tensor_variable(floatX(p))
1402-
self.psi = psi = at.as_tensor_variable(floatX(psi))
1403-
self.bin = Binomial.dist(n, p)
1404-
self.mode = self.bin.mode
1405-
1406-
def random(self, point=None, size=None):
1407-
r"""
1408-
Draw random values from ZeroInflatedBinomial distribution.
1413+
rv_op = zero_inflated_binomial
14091414

1410-
Parameters
1411-
----------
1412-
point: dict, optional
1413-
Dict of variable values on which random values are to be
1414-
conditioned (uses default point if not specified).
1415-
size: int, optional
1416-
Desired size of random sample (returns one sample if not
1417-
specified).
1418-
1419-
Returns
1420-
-------
1421-
array
1422-
"""
1423-
# n, p, psi = draw_values([self.n, self.p, self.psi], point=point, size=size)
1424-
# g = generate_samples(stats.binom.rvs, n, p, dist_shape=self.shape, size=size)
1425-
# g, psi = broadcast_distribution_samples([g, psi], size=size)
1426-
# return g * (np.random.random(g.shape) < psi)
1415+
@classmethod
def dist(cls, psi, n, p, *args, **kwargs):
    """Convert the parameters to tensor variables and delegate to the parent ``dist``.

    ``psi`` and ``p`` go through ``floatX`` and ``n`` through ``intX`` before
    being wrapped with ``at.as_tensor_variable``. The parent receives them as
    the list ``[psi, n, p]``, followed by any extra positional and keyword
    arguments.
    """
    dist_params = [
        at.as_tensor_variable(floatX(psi)),
        at.as_tensor_variable(intX(n)),
        at.as_tensor_variable(floatX(p)),
    ]
    return super().dist(dist_params, *args, **kwargs)
14271421

1428-
def logp(self, value):
1422+
def logp(value, psi, n, p):
14291423
r"""
14301424
Calculate log-probability of ZeroInflatedBinomial distribution at specified value.
14311425
@@ -1439,19 +1433,24 @@ def logp(self, value):
14391433
-------
14401434
TensorVariable
14411435
"""
1442-
psi = self.psi
1443-
p = self.p
1444-
n = self.n
14451436

14461437
logp_val = at.switch(
14471438
at.gt(value, 0),
1448-
at.log(psi) + self.bin.logp(value),
1439+
at.log(psi) + Binomial.logp(value, n, p),
14491440
logaddexp(at.log1p(-psi), at.log(psi) + n * at.log1p(-p)),
14501441
)
14511442

1452-
return bound(logp_val, 0 <= value, value <= n, 0 <= psi, psi <= 1, 0 <= p, p <= 1)
1443+
return bound(
1444+
logp_val,
1445+
0 <= value,
1446+
value <= n,
1447+
0 <= psi,
1448+
psi <= 1,
1449+
0 <= p,
1450+
p <= 1,
1451+
)
14531452

1454-
def logcdf(self, value):
1453+
def logcdf(value, psi, n, p):
14551454
"""
14561455
Compute the log of the cumulative distribution function for ZeroInflatedBinomial distribution
14571456
at the specified value.
@@ -1465,19 +1464,21 @@ def logcdf(self, value):
14651464
-------
14661465
TensorVariable
14671466
"""
1467+
14681468
# logcdf can only handle scalar values due to limitation in Binomial.logcdf
14691469
if np.ndim(value):
14701470
raise TypeError(
14711471
f"ZeroInflatedBinomial.logcdf expects a scalar value but received a {np.ndim(value)}-dimensional object."
14721472
)
14731473

1474-
psi = self.psi
1475-
14761474
return bound(
1477-
logaddexp(at.log1p(-psi), at.log(psi) + self.bin.logcdf(value)),
1475+
logaddexp(at.log1p(-psi), at.log(psi) + Binomial.logcdf(value, n, p)),
14781476
0 <= value,
1477+
value <= n,
14791478
0 <= psi,
14801479
psi <= 1,
1480+
0 <= p,
1481+
p <= 1,
14811482
)
14821483

14831484

pymc3/tests/test_distributions.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,17 +1682,39 @@ def test_zeroinflatednegativebinomial_logcdf(self):
16821682
n_samples=10,
16831683
)
16841684

1685-
# Too lazy to propagate decimal parameter through the whole chain of deps
1686-
@pytest.mark.xfail(reason="Distribution not refactored yet")
1685+
@pytest.mark.xfail(reason="Test not refactored yet")
16871686
def test_zeroinflatedbinomial_distribution(self):
    """Run the ``checkd`` helper for ZeroInflatedBinomial over the ``Nat`` value domain."""
    param_domains = {"n": NatSmall, "p": Unit, "psi": Unit}
    self.checkd(ZeroInflatedBinomial, Nat, param_domains)
16931692

1694-
@pytest.mark.xfail(reason="Distribution not refactored yet")
1695-
def test_zeroinflatedbinomial_logcdf(self):
1693+
def test_zeroinflatedbinomial(self):
1694+
def logp_fn(value, psi, n, p):
    """Reference log-probability of a zero-inflated binomial at a scalar ``value``.

    ``psi`` is the weight of the binomial component, and ``n`` and ``p`` are
    the binomial parameters.
    """
    binom_pmf = sp.binom.pmf(value, n, p)
    if value == 0:
        # A zero arises either from the inflation component (mass 1 - psi)
        # or from the binomial component itself, so the two masses add.
        return np.log((1 - psi) + psi * binom_pmf)
    return np.log(psi * binom_pmf)
1699+
1700+
def logcdf_fn(value, psi, n, p):
    """Reference log-CDF of a zero-inflated binomial evaluated at ``value``."""
    binom_cdf = sp.binom.cdf(value, n, p)
    mixture_cdf = (1 - psi) + psi * binom_cdf
    return np.log(mixture_cdf)
1702+
1703+
self.check_logp(
1704+
ZeroInflatedBinomial,
1705+
Nat,
1706+
{"psi": Unit, "n": NatSmall, "p": Unit},
1707+
logp_fn,
1708+
)
1709+
1710+
self.check_logcdf(
1711+
ZeroInflatedBinomial,
1712+
Nat,
1713+
{"psi": Unit, "n": NatSmall, "p": Unit},
1714+
logcdf_fn,
1715+
n_samples=10,
1716+
)
1717+
16961718
self.check_selfconsistency_discrete_logcdf(
16971719
ZeroInflatedBinomial,
16981720
Nat,

pymc3/tests/test_distributions_random.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,37 @@ def seeded_zero_inflated_poisson_rng_fn(self):
954954
]
955955

956956

957+
class TestZeroInflatedBinomial(BaseTestDistribution):
    """Random-draw checks for ``pm.ZeroInflatedBinomial``.

    The reference sampler multiplies a binomial draw by the boolean mask
    ``uniform < psi``.
    """

    def zero_inflated_binomial_rng_fn(self, size, psi, n, p, binomial_rng_fct, random_rng_fct):
        """Reference draw: a binomial sample, zeroed wherever the uniform draw is >= ``psi``."""
        return binomial_rng_fct(n, p, size=size) * (random_rng_fct(size=size) < psi)

    def seeded_zero_inflated_binomial_rng_fn(self):
        """Return the reference sampler with both draw functions bound to a random state.

        NOTE(review): each partial binds the result of its own
        ``self.get_random_state()`` call; whether the two calls return the
        same state object depends on the base class. Confirm there.
        """
        binomial_rng_fct = functools.partial(
            np.random.RandomState.binomial, self.get_random_state()
        )

        random_rng_fct = functools.partial(
            np.random.RandomState.random, self.get_random_state()
        )

        return functools.partial(
            self.zero_inflated_binomial_rng_fn,
            binomial_rng_fct=binomial_rng_fct,
            random_rng_fct=random_rng_fct,
        )

    pymc_dist = pm.ZeroInflatedBinomial
    pymc_dist_params = {"psi": 0.9, "n": 12, "p": 0.7}
    expected_rv_op_params = {"psi": 0.9, "n": 12, "p": 0.7}
    reference_dist_params = {"psi": 0.9, "n": 12, "p": 0.7}
    # Presumably called by the base class to build the reference sampler; verify there.
    reference_dist = seeded_zero_inflated_binomial_rng_fn
    tests_to_run = [
        "check_pymc_params_match_rv_op",
        "check_pymc_draws_match_reference",
        "check_rv_size",
    ]
986+
987+
957988
class TestOrderedLogistic(BaseTestDistribution):
958989
pymc_dist = pm.OrderedLogistic
959990
pymc_dist_params = {"eta": 0, "cutpoints": np.array([-2, 0, 2])}

0 commit comments

Comments
 (0)