Skip to content

Commit 84c1fe1

Browse files
authored
Merge pull request #77 from honno/parametrize-np
Support testing `np` and `tnp` on the same test
2 parents 1272d80 + 352154a commit 84c1fe1

File tree

1 file changed

+51
-4
lines changed

1 file changed

+51
-4
lines changed

torch_np/tests/conftest.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import pytest
44

5+
import torch_np as tnp
6+
57

68
def pytest_configure(config):
79
config.addinivalue_line("markers", "slow: very slow tests")
@@ -12,14 +14,59 @@ def pytest_addoption(parser):
1214
parser.addoption("--nonp", action="store_true", help="error when NumPy is accessed")
1315

1416

17+
class Inaccessible:
18+
def __getattribute__(self, attr):
19+
raise RuntimeError(f"Using --nonp but accessed np.{attr}")
20+
21+
1522
def pytest_sessionstart(session):
1623
if session.config.getoption("--nonp"):
24+
sys.modules["numpy"] = Inaccessible()
1725

18-
class Inaccessible:
19-
def __getattribute__(self, attr):
20-
raise RuntimeError(f"Using --nonp but accessed np.{attr}")
2126

22-
sys.modules["numpy"] = Inaccessible()
27+
def pytest_generate_tests(metafunc):
28+
"""
29+
Hook to parametrize test cases
30+
See https://docs.pytest.org/en/6.2.x/parametrize.html#pytest-generate-tests
31+
32+
The logic here allows us to test with both NumPy-proper and torch_np.
33+
Normally we'd just test torch_np, e.g.
34+
35+
import torch_np as np
36+
...
37+
def test_foo():
38+
np.array([42])
39+
...
40+
41+
but this hook allows us to test NumPy-proper as well, e.g.
42+
43+
def test_foo(np):
44+
np.array([42])
45+
...
46+
47+
np is a pytest parameter, which is either NumPy-proper or torch_np. This
48+
allows us to sanity check our own tests, so that tested behaviour is
49+
consistent with NumPy-proper.
50+
51+
pytest will have test names respective to the library being tested, e.g.
52+
53+
$ pytest --collect-only
54+
test_foo[torch_np]
55+
test_foo[numpy]
56+
57+
"""
58+
np_params = [tnp]
59+
60+
try:
61+
import numpy as np
62+
except ImportError:
63+
pass
64+
else:
65+
if not isinstance(np, Inaccessible): # i.e. --nonp was used
66+
np_params.append(np)
67+
68+
if "np" in metafunc.fixturenames:
69+
metafunc.parametrize("np", np_params)
2370

2471

2572
def pytest_collection_modifyitems(config, items):

0 commit comments

Comments
 (0)