Skip to content

Commit 164a8f2

Browse files
committed
Fix bug. Now pymc3_matches_scipy runs without error but pymc3_random_discrete diverges from expected value
1 parent ae9fb7c commit 164a8f2

File tree

3 files changed

+13
-18
lines changed

3 files changed

+13
-18
lines changed

pymc3/distributions/discrete.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -851,10 +851,10 @@ class HyperGeometric(Discrete):
851851
----------
852852
N : integer
853853
Total size of the population
854-
n : integer
855-
Number of samples drawn from the population
856854
k : integer
857855
Number of successful individuals in the population
856+
n : integer
857+
Number of samples drawn from the population
858858
"""
859859

860860
def __init__(self, N, k, n, *args, **kwargs):
@@ -881,10 +881,8 @@ def random(self, point=None, size=None):
881881
-------
882882
array
883883
"""
884-
N, n, k = draw_values([self.N, self.n, self.k], point=point, size=size)
885-
return generate_samples(
886-
np.random.hypergeometric, N, n, k, dist_shape=self.shape, size=size
887-
)
884+
N, k, n = draw_values([self.N, self.k, self.n], point=point, size=size)
885+
return generate_samples(np.random.hypergeometric, N, k, n, dist_shape=self.shape, size=size)
888886

889887
def logp(self, value):
890888
r"""
@@ -913,10 +911,7 @@ def logp(self, value):
913911
- betaln(n - value + 1, bad - n + value + 1)
914912
- betaln(tot + 1, 1)
915913
)
916-
lower = tt.clip(n - N + k, 0, n - N + k)
917-
upper = tt.switch(tt.lt(k, n), k, n)
918-
nonint_value = (value != intX(tt.floor(value)))
919-
return bound(result, lower <= value, value <= upper, nonint_value)
914+
return result
920915

921916

922917
class DiscreteUniform(Discrete):

pymc3/tests/test_distributions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -795,8 +795,8 @@ def test_hypergeometric(self):
795795
self.pymc3_matches_scipy(
796796
HyperGeometric,
797797
Nat,
798-
{"N": NatSmall, "n": NatSmall, "k": NatSmall},
799-
lambda value, N, n, k: sp.hypergeom.logpmf(value, N, k, n),
798+
{"N": NatSmall, "k": NatSmall, "n": NatSmall},
799+
lambda value, N, k, n: sp.hypergeom.logpmf(value, N, k, n),
800800
)
801801

802802
def test_negative_binomial(self):

pymc3/tests/test_distributions_random.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -745,14 +745,14 @@ def test_geometric(self):
745745
pymc3_random_discrete(pm.Geometric, {"p": Unit}, size=500, fails=50, ref_rand=nr.geometric)
746746

747747
def test_hypergeometric(self):
748-
def ref_rand(size, N, n, k):
749-
return nr.hypergeometric(ngood=k, nbad=N-k, nsample=n, size=size)
748+
def ref_rand(size, N, k, n):
749+
return st.hypergeom.rvs(M=N, n=k, N=n, size=size)
750750

751751
pymc3_random_discrete(
752-
pm.HyperGeometric,
753-
{"N": Nat, "n": Nat, "k": Nat},
754-
size=500,
755-
fails=50,
752+
pm.HyperGeometric,
753+
{"N": Nat, "k": Nat, "n": Nat},
754+
size=100,
755+
fails=50,
756756
ref_rand=ref_rand,
757757
)
758758

0 commit comments

Comments
 (0)