Skip to content

Commit be6286e

Browse files
committed
Support testing np and tnp on the same test
1 parent 1272d80 commit be6286e

File tree

1 file changed

+50
-4
lines changed

1 file changed

+50
-4
lines changed

torch_np/tests/conftest.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import pytest
44

5+
import torch_np as tnp
6+
57

68
def pytest_configure(config):
79
config.addinivalue_line("markers", "slow: very slow tests")
@@ -12,14 +14,58 @@ def pytest_addoption(parser):
1214
parser.addoption("--nonp", action="store_true", help="error when NumPy is accessed")
1315

1416

17+
class Inaccessible:
18+
def __getattribute__(self, attr):
19+
raise RuntimeError(f"Using --nonp but accessed np.{attr}")
20+
21+
1522
def pytest_sessionstart(session):
1623
if session.config.getoption("--nonp"):
24+
sys.modules["numpy"] = Inaccessible()
1725

18-
class Inaccessible:
19-
def __getattribute__(self, attr):
20-
raise RuntimeError(f"Using --nonp but accessed np.{attr}")
2126

22-
sys.modules["numpy"] = Inaccessible()
27+
def pytest_generate_tests(metafunc):
28+
"""
29+
Hook to test with both NumPy-proper and torch_np
30+
31+
Normally we'd just test torch_np, e.g.
32+
33+
import torch_np as np
34+
...
35+
def test_foo():
36+
np.array([42])
37+
...
38+
39+
But this hook allows us to test both NumPy-proper as well, e.g.
40+
41+
def test_foo(np):
42+
np.array([42])
43+
...
44+
45+
np is a pytest parameter, which is either NumPy-proper or torch_np. This
46+
allows us to sanity check our own tests, so that tested behaviour is
47+
consistent with NumPy-proper.
48+
49+
pytest will have test names respective to the library being tested, e.g.
50+
51+
$ pytest --collect-only
52+
test_foo[torch_np]
53+
test_foo[numpy]
54+
55+
See https://docs.pytest.org/en/6.2.x/parametrize.html#pytest-generate-tests
56+
"""
57+
np_params = [tnp]
58+
59+
try:
60+
import numpy as np
61+
except ImportError:
62+
pass
63+
else:
64+
if not isinstance(np, Inaccessible): # i.e. --nonp was used
65+
np_params.append(np)
66+
67+
if "np" in metafunc.fixturenames:
68+
metafunc.parametrize("np", np_params)
2369

2470

2571
def pytest_collection_modifyitems(config, items):

0 commit comments

Comments
 (0)