Skip to content

Commit a41d524

Browse files
committed
Bring back testing utilities used in downstream packages
Follow up to * 534a9ae * e1d36ca
1 parent 49aacf4 commit a41d524

37 files changed

+207
-194
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ jobs:
4747
tests/test_func_utils.py
4848
tests/distributions/test_shape_utils.py
4949
tests/distributions/test_mixture.py
50+
tests/test_testing.py
5051
5152
- |
5253
tests/distributions/test_continuous.py

docs/source/contributing/implementing_distribution.md

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,11 @@ Most tests can be accommodated by the default `BaseTestDistributionRandom` class
240240
1. Shape variable inference is correct, via `check_rv_size`
241241

242242
```python
243-
from tests.distributions.util import BaseTestDistributionRandom, seeded_scipy_distribution_builder
244243

245-
class TestBlah(BaseTestDistributionRandom):
244+
from pymc.testing import BaseTestDistributionRandom, seeded_scipy_distribution_builder
245+
246246

247+
class TestBlah(BaseTestDistributionRandom):
247248
pymc_dist = pm.Blah
248249
# Parameters with which to test the blah pymc Distribution
249250
pymc_dist_params = {"param1": 0.25, "param2": 2.0}
@@ -311,38 +312,36 @@ Tests for the `logp` and `logcdf` mostly make use of the helpers `check_logp`, `
311312
`check_selfconsistency_discrete_logcdf` implemented in `~tests.distributions.util`
312313

313314
```python
314-
from tests.helpers import select_by_precision
315-
from tests.distributions.util import check_logp, check_logcdf, Domain
315+
316+
from pymc.testing import Domain, check_logp, check_logcdf, select_by_precision
316317

317318
R = Domain([-np.inf, -2.1, -1, -0.01, 0.0, 0.01, 1, 2.1, np.inf])
318319
Rplus = Domain([0, 0.01, 0.1, 0.9, 0.99, 1, 1.5, 2, 100, np.inf])
319320

320321

321-
322322
def test_blah():
323-
324-
check_logp(
325-
pymc_dist=pm.Blah,
326-
# Domain of the distribution values
327-
domain=R,
328-
# Domains of the distribution parameters
329-
paramdomains={"mu": R, "sigma": Rplus},
330-
# Reference scipy (or other) logp function
331-
scipy_logp = lambda value, mu, sigma: sp.norm.logpdf(value, mu, sigma),
332-
# Number of decimal points expected to match between the pymc and reference functions
333-
decimal=select_by_precision(float64=6, float32=3),
334-
# Maximum number of combinations of domain * paramdomains to test
335-
n_samples=100,
336-
)
337-
338-
check_logcdf(
339-
pymc_dist=pm.Blah,
340-
domain=R,
341-
paramdomains={"mu": R, "sigma": Rplus},
342-
scipy_logcdf=lambda value, mu, sigma: sp.norm.logcdf(value, mu, sigma),
343-
decimal=select_by_precision(float64=6, float32=1),
344-
n_samples=-1,
345-
)
323+
check_logp(
324+
pymc_dist=pm.Blah,
325+
# Domain of the distribution values
326+
domain=R,
327+
# Domains of the distribution parameters
328+
paramdomains={"mu": R, "sigma": Rplus},
329+
# Reference scipy (or other) logp function
330+
scipy_logp=lambda value, mu, sigma: sp.norm.logpdf(value, mu, sigma),
331+
# Number of decimal points expected to match between the pymc and reference functions
332+
decimal=select_by_precision(float64=6, float32=3),
333+
# Maximum number of combinations of domain * paramdomains to test
334+
n_samples=100,
335+
)
336+
337+
check_logcdf(
338+
pymc_dist=pm.Blah,
339+
domain=R,
340+
paramdomains={"mu": R, "sigma": Rplus},
341+
scipy_logcdf=lambda value, mu, sigma: sp.norm.logcdf(value, mu, sigma),
342+
decimal=select_by_precision(float64=6, float32=1),
343+
n_samples=-1,
344+
)
346345

347346
```
348347

@@ -382,7 +381,8 @@ which checks if:
382381

383382
import pytest
384383
from pymc.distributions import Blah
385-
from tests.distributions.util import assert_moment_is_expected
384+
from pymc.testing import assert_moment_is_expected
385+
386386

387387
@pytest.mark.parametrize(
388388
"param1, param2, size, expected",

0 commit comments

Comments
 (0)