Skip to content

Commit ce252fa

Browse files
committed
Refactor ZeroInflatedPoisson
1 parent 5c723b3 commit ce252fa

File tree

3 files changed

+89
-45
lines changed

3 files changed

+89
-45
lines changed

pymc3/distributions/discrete.py

Lines changed: 33 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1215,6 +1215,21 @@ def logp(value, c):
12151215
)
12161216

12171217

1218+
class ZeroInflatedPoissonRV(RandomVariable):
1219+
name = "zero_inflated_poisson"
1220+
ndim_supp = 0
1221+
ndims_params = [0, 0]
1222+
dtype = "int64"
1223+
_print_name = ("ZeroInflatedPois", "\\operatorname{ZeroInflatedPois}")
1224+
1225+
@classmethod
1226+
def rng_fn(cls, rng, psi, lam, size):
1227+
return rng.poisson(lam, size=size) * (rng.random(size=size) < psi)
1228+
1229+
1230+
zero_inflated_poisson = ZeroInflatedPoissonRV()
1231+
1232+
12181233
class ZeroInflatedPoisson(Discrete):
12191234
R"""
12201235
Zero-inflated Poisson log-likelihood.
@@ -1266,36 +1281,15 @@ class ZeroInflatedPoisson(Discrete):
12661281
(theta >= 0).
12671282
"""
12681283

1269-
def __init__(self, psi, theta, *args, **kwargs):
1270-
super().__init__(*args, **kwargs)
1271-
self.theta = theta = at.as_tensor_variable(floatX(theta))
1272-
self.psi = at.as_tensor_variable(floatX(psi))
1273-
self.pois = Poisson.dist(theta)
1274-
self.mode = self.pois.mode
1275-
1276-
def random(self, point=None, size=None):
1277-
r"""
1278-
Draw random values from ZeroInflatedPoisson distribution.
1279-
1280-
Parameters
1281-
----------
1282-
point: dict, optional
1283-
Dict of variable values on which random values are to be
1284-
conditioned (uses default point if not specified).
1285-
size: int, optional
1286-
Desired size of random sample (returns one sample if not
1287-
specified).
1284+
rv_op = zero_inflated_poisson
12881285

1289-
Returns
1290-
-------
1291-
array
1292-
"""
1293-
# theta, psi = draw_values([self.theta, self.psi], point=point, size=size)
1294-
# g = generate_samples(stats.poisson.rvs, theta, dist_shape=self.shape, size=size)
1295-
# g, psi = broadcast_distribution_samples([g, psi], size=size)
1296-
# return g * (np.random.random(g.shape) < psi)
1286+
@classmethod
1287+
def dist(cls, psi, theta, *args, **kwargs):
1288+
psi = at.as_tensor_variable(floatX(psi))
1289+
theta = at.as_tensor_variable(floatX(theta))
1290+
return super().dist([psi, theta], *args, **kwargs)
12971291

1298-
def logp(self, value):
1292+
def logp(value, psi, theta):
12991293
r"""
13001294
Calculate log-probability of ZeroInflatedPoisson distribution at specified value.
13011295
@@ -1309,18 +1303,22 @@ def logp(self, value):
13091303
-------
13101304
TensorVariable
13111305
"""
1312-
psi = self.psi
1313-
theta = self.theta
13141306

13151307
logp_val = at.switch(
13161308
at.gt(value, 0),
1317-
at.log(psi) + self.pois.logp(value),
1309+
at.log(psi) + Poisson.logp(value, theta),
13181310
logaddexp(at.log1p(-psi), at.log(psi) - theta),
13191311
)
13201312

1321-
return bound(logp_val, 0 <= value, 0 <= psi, psi <= 1, 0 <= theta)
1313+
return bound(
1314+
logp_val,
1315+
0 <= value,
1316+
0 <= psi,
1317+
psi <= 1,
1318+
0 <= theta,
1319+
)
13221320

1323-
def logcdf(self, value):
1321+
def logcdf(value, psi, theta):
13241322
"""
13251323
Compute the log of the cumulative distribution function for ZeroInflatedPoisson distribution
13261324
at the specified value.
@@ -1335,13 +1333,13 @@ def logcdf(self, value):
13351333
-------
13361334
TensorVariable
13371335
"""
1338-
psi = self.psi
13391336

13401337
return bound(
1341-
logaddexp(at.log1p(-psi), at.log(psi) + self.pois.logcdf(value)),
1338+
logaddexp(at.log1p(-psi), at.log(psi) + Poisson.logcdf(value, theta)),
13421339
0 <= value,
13431340
0 <= psi,
13441341
psi <= 1,
1342+
0 <= theta,
13451343
)
13461344

13471345

pymc3/tests/test_distributions.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1618,8 +1618,7 @@ def test_bound_poisson(self):
16181618
def test_constantdist(self):
16191619
self.check_logp(Constant, I, {"c": I}, lambda value, c: np.log(c == value))
16201620

1621-
# Too lazy to propagate decimal parameter through the whole chain of deps
1622-
@pytest.mark.xfail(reason="Distribution not refactored yet")
1621+
@pytest.mark.xfail(reason="Test has not been refactored")
16231622
@pytest.mark.xfail(
16241623
condition=(aesara.config.floatX == "float32"),
16251624
reason="Fails on float32 due to inf issues",
@@ -1631,8 +1630,30 @@ def test_zeroinflatedpoisson_distribution(self):
16311630
{"theta": Rplus, "psi": Unit},
16321631
)
16331632

1634-
@pytest.mark.xfail(reason="Distribution not refactored yet")
1635-
def test_zeroinflatedpoisson_logcdf(self):
1633+
def test_zeroinflatedpoisson(self):
1634+
def logp_fn(value, psi, theta):
1635+
if value == 0:
1636+
return np.log((1 - psi) * sp.poisson.pmf(0, theta))
1637+
else:
1638+
return np.log(psi * sp.poisson.pmf(value, theta))
1639+
1640+
def logcdf_fn(value, psi, theta):
1641+
return np.log((1 - psi) + psi * sp.poisson.cdf(value, theta))
1642+
1643+
self.check_logp(
1644+
ZeroInflatedPoisson,
1645+
Nat,
1646+
{"psi": Unit, "theta": Rplus},
1647+
logp_fn,
1648+
)
1649+
1650+
self.check_logcdf(
1651+
ZeroInflatedPoisson,
1652+
Nat,
1653+
{"psi": Unit, "theta": Rplus},
1654+
logcdf_fn,
1655+
)
1656+
16361657
self.check_selfconsistency_discrete_logcdf(
16371658
ZeroInflatedPoisson,
16381659
Nat,

pymc3/tests/test_distributions_random.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -313,12 +313,6 @@ class TestLogitNormal(BaseTestCases.BaseTestCase):
313313
params = {"mu": 0.0, "sigma": 1.0}
314314

315315

316-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
317-
class TestZeroInflatedPoisson(BaseTestCases.BaseTestCase):
318-
distribution = pm.ZeroInflatedPoisson
319-
params = {"theta": 1.0, "psi": 0.3}
320-
321-
322316
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
323317
class TestZeroInflatedNegativeBinomial(BaseTestCases.BaseTestCase):
324318
distribution = pm.ZeroInflatedNegativeBinomial
@@ -929,6 +923,37 @@ def constant_rng_fn(self, size, c):
929923
]
930924

931925

926+
class TestZeroInflatedPoisson(BaseTestDistribution):
927+
def zero_inflated_poisson_rng_fn(self, size, psi, theta, poisson_rng_fct, random_rng_fct):
928+
return poisson_rng_fct(theta, size=size) * (random_rng_fct(size=size) < psi)
929+
930+
def seeded_zero_inflated_poisson_rng_fn(self):
931+
poisson_rng_fct = functools.partial(
932+
getattr(np.random.RandomState, "poisson"), self.get_random_state()
933+
)
934+
935+
random_rng_fct = functools.partial(
936+
getattr(np.random.RandomState, "random"), self.get_random_state()
937+
)
938+
939+
return functools.partial(
940+
self.zero_inflated_poisson_rng_fn,
941+
poisson_rng_fct=poisson_rng_fct,
942+
random_rng_fct=random_rng_fct,
943+
)
944+
945+
pymc_dist = pm.ZeroInflatedPoisson
946+
pymc_dist_params = {"psi": 0.9, "theta": 4.0}
947+
expected_rv_op_params = {"psi": 0.9, "theta": 4.0}
948+
reference_dist_params = {"psi": 0.9, "theta": 4.0}
949+
reference_dist = seeded_zero_inflated_poisson_rng_fn
950+
tests_to_run = [
951+
"check_pymc_params_match_rv_op",
952+
"check_pymc_draws_match_reference",
953+
"check_rv_size",
954+
]
955+
956+
932957
class TestOrderedLogistic(BaseTestDistribution):
933958
pymc_dist = pm.OrderedLogistic
934959
pymc_dist_params = {"eta": 0, "cutpoints": np.array([-2, 0, 2])}

0 commit comments

Comments
 (0)