Skip to content

Commit c8e9d84

Browse files
committed
Readd TestFlat
1 parent 8ebb09a commit c8e9d84

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

pymc/tests/test_distributions_random.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,29 @@ def seeded_numpy_distribution_builder(dist_name: str) -> Callable:
314314
)
315315

316316

317+
class TestFlat(BaseTestDistributionRandom):
318+
pymc_dist = pm.Flat
319+
pymc_dist_params = {}
320+
expected_rv_op_params = {}
321+
checks_to_run = [
322+
"check_pymc_params_match_rv_op",
323+
"check_rv_inferred_size",
324+
"check_not_implemented",
325+
]
326+
327+
def check_rv_inferred_size(self):
328+
sizes_to_check = [None, (), 1, (1,), 5, (4, 5), (2, 4, 2)]
329+
sizes_expected = [(), (), (1,), (1,), (5,), (4, 5), (2, 4, 2)]
330+
for size, expected in zip(sizes_to_check, sizes_expected):
331+
pymc_rv = self.pymc_dist.dist(**self.pymc_dist_params, size=size)
332+
expected_symbolic = tuple(pymc_rv.shape.eval())
333+
assert expected_symbolic == expected
334+
335+
def check_not_implemented(self):
336+
with pytest.raises(NotImplementedError):
337+
self.pymc_rv.eval()
338+
339+
317340
class TestHalfFlat(BaseTestDistributionRandom):
318341
pymc_dist = pm.HalfFlat
319342
pymc_dist_params = {}

0 commit comments

Comments
 (0)