Skip to content

Commit e07a71a

Browse files
Refactored pymc3_random_discrete to use histograms instead of raw counts (#4840)
* Refactored pymc3_random_discrete to use histograms instead of raw counts * Run pre-commit Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent d2b1ed2 commit e07a71a

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

pymc3/tests/test_distributions_random.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,15 +107,12 @@ def pymc3_random_discrete(
107107
e = ref_rand(size=size, **pt)
108108
o = np.atleast_1d(o).flatten()
109109
e = np.atleast_1d(e).flatten()
110-
observed = dict(zip(*np.unique(o, return_counts=True)))
111-
expected = dict(zip(*np.unique(e, return_counts=True)))
112-
for e in expected.keys():
113-
expected[e] = (observed.get(e, 0), expected[e])
114-
k = np.array([v for v in expected.values()])
115-
if np.all(k[:, 0] == k[:, 1]):
110+
observed, _ = np.histogram(o, bins=min(7, len(set(o))))
111+
expected, _ = np.histogram(e, bins=min(7, len(set(o))))
112+
if np.all(observed == expected):
116113
p = 1.0
117114
else:
118-
_, p = st.chisquare(k[:, 0], k[:, 1])
115+
_, p = st.chisquare(observed + 1, expected + 1)
119116
f -= 1
120117
assert p > alpha, str(pt)
121118

0 commit comments

Comments
 (0)