Skip to content

Commit 9294a4b

Browse files
committed
Use importorskip
1 parent 5923a1d commit 9294a4b

File tree

3 files changed

+14
-16
lines changed

3 files changed

+14
-16
lines changed

tests/test_array_namespace.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1+
import pytest
2+
13
import array_api_compat
24
from array_api_compat import array_namespace
35

4-
from ._helpers import import_
5-
6-
import pytest
7-
86

97
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
108
@pytest.mark.parametrize("api_version", [None, '2021.12'])
119
def test_array_namespace(library, api_version):
12-
lib = import_(library)
10+
lib = pytest.importorskip(library)
1311

1412
array = lib.asarray([1.0, 2.0, 3.0])
1513
namespace = array_api_compat.array_namespace(array, api_version=api_version)
@@ -21,20 +19,21 @@ def test_array_namespace(library, api_version):
2119

2220

2321
def test_array_namespace_errors():
22+
np = pytest.importorskip("numpy")
23+
2424
pytest.raises(TypeError, lambda: array_namespace([1]))
2525
pytest.raises(TypeError, lambda: array_namespace())
2626

27-
import numpy as np
2827
x = np.asarray([1, 2])
29-
3028
pytest.raises(TypeError, lambda: array_namespace((x, x)))
3129
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
3230

33-
import torch
34-
y = torch.asarray([1, 2])
3531

32+
def test_array_namespace_errors_torch():
33+
torch = pytest.importorskip("torch")
34+
35+
y = torch.asarray([1, 2])
3636
pytest.raises(TypeError, lambda: array_namespace(x, y))
37-
3837
pytest.raises(ValueError, lambda: array_namespace(x, api_version='2022.12'))
3938

4039

tests/test_common.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from ._helpers import import_
21
from array_api_compat import to_device
32

43
import pytest
@@ -11,7 +10,7 @@ def test_to_device_host(library):
1110
# for DtoH transfers; ensure that we support a portable
1211
# shim for common array libs
1312
# see: https://github.com/scipy/scipy/issues/18286#issuecomment-1527552919
14-
xp = import_('array_api_compat.' + library)
13+
xp = pytest.importorskip('array_api_compat.' + library)
1514
expected = np.array([1, 2, 3])
1615
x = xp.asarray([1, 2, 3])
1716
x = to_device(x, "cpu")

tests/test_isdtype.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
non-spec dtypes
44
"""
55

6-
from ._helpers import import_
7-
86
import pytest
97

8+
from ._helpers import import_
9+
1010
# Check the known dtypes by their string names
1111

1212
def _spec_dtypes(library):
@@ -66,7 +66,7 @@ def isdtype_(dtype_, kind):
6666

6767
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
6868
def test_isdtype_spec_dtypes(library):
69-
xp = import_('array_api_compat.' + library)
69+
xp = pytest.importorskip('array_api_compat.' + library)
7070

7171
isdtype = xp.isdtype
7272

@@ -101,7 +101,7 @@ def test_isdtype_spec_dtypes(library):
101101
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
102102
@pytest.mark.parametrize("dtype_", additional_dtypes)
103103
def test_isdtype_additional_dtypes(library, dtype_):
104-
xp = import_('array_api_compat.' + library)
104+
xp = pytest.importorskip('array_api_compat.' + library)
105105

106106
isdtype = xp.isdtype
107107

0 commit comments

Comments
 (0)