Skip to content

Commit ae9fb7c

Browse files
committed
Add ref_rand helper function. Clip lower in logp
1 parent 3426e36 commit ae9fb7c

File tree

2 files changed

+11
-2
lines changed

2 files changed

+11
-2
lines changed

pymc3/distributions/discrete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -913,7 +913,7 @@ def logp(self, value):
913913
- betaln(n - value + 1, bad - n + value + 1)
914914
- betaln(tot + 1, 1)
915915
)
916-
lower = tt.switch(tt.gt(n - N + k, 0), n - N + k, 0)
916+
lower = tt.clip(n - N + k, 0, n - N + k)
917917
upper = tt.switch(tt.lt(k, n), k, n)
918918
nonint_value = (value != intX(tt.floor(value)))
919919
return bound(result, lower <= value, value <= upper, nonint_value)

pymc3/tests/test_distributions_random.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,7 +745,16 @@ def test_geometric(self):
745745
pymc3_random_discrete(pm.Geometric, {"p": Unit}, size=500, fails=50, ref_rand=nr.geometric)
746746

747747
def test_hypergeometric(self):
748-
pymc3_random_discrete(pm.HyperGeometric, {"N": Nat, "n": Nat, "k": Nat}, size=500, fails=50, ref_rand=nr.hypergeometric)
748+
def ref_rand(size, N, n, k):
749+
return nr.hypergeometric(ngood=k, nbad=N-k, nsample=n, size=size)
750+
751+
pymc3_random_discrete(
752+
pm.HyperGeometric,
753+
{"N": Nat, "n": Nat, "k": Nat},
754+
size=500,
755+
fails=50,
756+
ref_rand=ref_rand,
757+
)
749758

750759
def test_discrete_uniform(self):
751760
def ref_rand(size, lower, upper):

0 commit comments

Comments
 (0)