Skip to content

Commit 03e1df5

Browse files
committed
Refactor DiscreteUniform
1 parent fa0e930 commit 03e1df5

File tree

3 files changed

+44
-57
lines changed

3 files changed

+44
-57
lines changed

pymc3/distributions/discrete.py

Lines changed: 23 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,21 @@ def logcdf(value, good, bad, n):
970970
)
971971

972972

973+
class DiscreteUniformRV(RandomVariable):
974+
name = "discrete_uniform"
975+
ndim_supp = 0
976+
ndims_params = [0, 0]
977+
dtype = "int64"
978+
_print_name = ("DiscreteUniform", "\\operatorname{DiscreteUniform}")
979+
980+
@classmethod
981+
def rng_fn(cls, rng, lower, upper, size=None):
982+
return stats.randint.rvs(lower, upper + 1, size=size, random_state=rng)
983+
984+
985+
discrete_uniform = DiscreteUniformRV()
986+
987+
973988
class DiscreteUniform(Discrete):
974989
R"""
975990
Discrete uniform distribution.
@@ -1010,39 +1025,15 @@ class DiscreteUniform(Discrete):
10101025
Upper limit (upper > lower).
10111026
"""
10121027

1013-
def __init__(self, lower, upper, *args, **kwargs):
1014-
super().__init__(*args, **kwargs)
1015-
self.lower = intX(at.floor(lower))
1016-
self.upper = intX(at.floor(upper))
1017-
self.mode = at.maximum(intX(at.floor((upper + lower) / 2.0)), self.lower)
1018-
1019-
def _random(self, lower, upper, size=None):
1020-
# This way seems to be the only to deal with lower and upper
1021-
# as array-like.
1022-
samples = stats.randint.rvs(lower, upper + 1, size=size)
1023-
return samples
1024-
1025-
def random(self, point=None, size=None):
1026-
r"""
1027-
Draw random values from DiscreteUniform distribution.
1028-
1029-
Parameters
1030-
----------
1031-
point: dict, optional
1032-
Dict of variable values on which random values are to be
1033-
conditioned (uses default point if not specified).
1034-
size: int, optional
1035-
Desired size of random sample (returns one sample if not
1036-
specified).
1028+
rv_op = discrete_uniform
10371029

1038-
Returns
1039-
-------
1040-
array
1041-
"""
1042-
# lower, upper = draw_values([self.lower, self.upper], point=point, size=size)
1043-
# return generate_samples(self._random, lower, upper, dist_shape=self.shape, size=size)
1030+
@classmethod
1031+
def dist(cls, lower, upper, *args, **kwargs):
1032+
lower = intX(at.floor(lower))
1033+
upper = intX(at.floor(upper))
1034+
return super().dist([lower, upper], **kwargs)
10441035

1045-
def logp(self, value):
1036+
def logp(value, lower, upper):
10461037
r"""
10471038
Calculate log-probability of DiscreteUniform distribution at specified value.
10481039
@@ -1056,15 +1047,13 @@ def logp(self, value):
10561047
-------
10571048
TensorVariable
10581049
"""
1059-
upper = self.upper
1060-
lower = self.lower
10611050
return bound(
10621051
at.fill(value, -at.log(upper - lower + 1)),
10631052
lower <= value,
10641053
value <= upper,
10651054
)
10661055

1067-
def logcdf(self, value):
1056+
def logcdf(value, lower, upper):
10681057
"""
10691058
Compute the log of the cumulative distribution function for Discrete uniform distribution
10701059
at the specified value.
@@ -1079,8 +1068,6 @@ def logcdf(self, value):
10791068
-------
10801069
TensorVariable
10811070
"""
1082-
upper = self.upper
1083-
lower = self.lower
10841071

10851072
return bound(
10861073
at.switch(

pymc3/tests/test_distributions.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -929,7 +929,6 @@ def test_bound_normal(self):
929929
x = PositiveNormal("x", mu=0, sigma=1, transform=None)
930930
assert np.isinf(logpt(x, -1).eval())
931931

932-
@pytest.mark.xfail(reason="Distribution not refactored yet")
933932
def test_discrete_unif(self):
934933
self.check_logp(
935934
DiscreteUniform,
@@ -2817,17 +2816,16 @@ def test_issue_3051(self, dims, dist_cls, kwargs):
28172816
assert isinstance(actual_a, np.ndarray)
28182817
assert actual_a.shape == (X.shape[0],)
28192818

2820-
@pytest.mark.xfail(reason="Distribution not refactored yet")
28212819
def test_issue_4499(self):
28222820
# Test for bug in Uniform and DiscreteUniform logp when setting check_bounds = False
28232821
# https://github.com/pymc-devs/pymc3/issues/4499
28242822
with pm.Model(check_bounds=False) as m:
28252823
x = pm.Uniform("x", 0, 2, shape=10, transform=None)
2826-
assert_almost_equal(m.logp_array(np.ones(10)), -np.log(2) * 10)
2824+
assert_almost_equal(m.logp({"x": np.ones(10)}), -np.log(2) * 10)
28272825

28282826
with pm.Model(check_bounds=False) as m:
2829-
x = pm.DiscreteUniform("x", 0, 1, shape=10)
2830-
assert_almost_equal(m.logp_array(np.ones(10)), -np.log(2) * 10)
2827+
x = pm.DiscreteUniform("x", 0, 1, size=10)
2828+
assert_almost_equal(m.logp({"x": np.ones(10)}), -np.log(2) * 10)
28312829

28322830

28332831
@pytest.mark.xfail(reason="DensityDist no longer supported")

pymc3/tests/test_distributions_random.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
Domain,
4343
I,
4444
Nat,
45-
NatSmall,
4645
PdMatrix,
4746
PdMatrixChol,
4847
R,
@@ -339,12 +338,6 @@ class TestZeroInflatedBinomial(BaseTestCases.BaseTestCase):
339338
params = {"n": 10, "p": 0.6, "psi": 0.3}
340339

341340

342-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
343-
class TestDiscreteUniform(BaseTestCases.BaseTestCase):
344-
distribution = pm.DiscreteUniform
345-
params = {"lower": 0.0, "upper": 10.0}
346-
347-
348341
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
349342
class TestMoyal(BaseTestCases.BaseTestCase):
350343
distribution = pm.Moyal
@@ -907,6 +900,24 @@ class TestBetaBinomial(BaseTestDistribution):
907900
]
908901

909902

903+
class TestDiscreteUniform(BaseTestDistribution):
904+
def discrete_uniform_rng_fn(self, size, lower, upper, rng):
905+
return st.randint.rvs(lower, upper + 1, size=size, random_state=rng)
906+
907+
pymc_dist = pm.DiscreteUniform
908+
pymc_dist_params = {"lower": -1, "upper": 9}
909+
expected_rv_op_params = {"lower": -1, "upper": 9}
910+
reference_dist_params = {"lower": -1, "upper": 9}
911+
reference_dist = lambda self: functools.partial(
912+
self.discrete_uniform_rng_fn, rng=self.get_random_state()
913+
)
914+
tests_to_run = [
915+
"check_pymc_params_match_rv_op",
916+
"check_pymc_draws_match_reference",
917+
"check_rv_size",
918+
]
919+
920+
910921
class TestScalarParameterSamples(SeededTest):
911922
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
912923
def test_bounded(self):
@@ -1016,15 +1027,6 @@ def test_half_flat(self):
10161027
with pytest.raises(ValueError):
10171028
f.random(1)
10181029

1019-
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
1020-
def test_discrete_uniform(self):
1021-
def ref_rand(size, lower, upper):
1022-
return st.randint.rvs(lower, upper + 1, size=size)
1023-
1024-
pymc3_random_discrete(
1025-
pm.DiscreteUniform, {"lower": -NatSmall, "upper": NatSmall}, ref_rand=ref_rand
1026-
)
1027-
10281030
@pytest.mark.xfail(reason="This distribution has not been refactored for v4")
10291031
def test_constant_dist(self):
10301032
def ref_rand(size, c):

0 commit comments

Comments
 (0)