Skip to content

Commit 7b99449

Browse files
restore code
1 parent c5b82db commit 7b99449

File tree

6 files changed

+43
-64
lines changed

6 files changed

+43
-64
lines changed

tests/_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,12 @@
33

44
import pytest
55

6-
wrapped_libraries = ["numpy", "paddle", "torch"]
7-
all_libraries = wrapped_libraries + []
6+
wrapped_libraries = ["numpy", "cupy", "torch", "dask.array", "paddle"]
7+
all_libraries = wrapped_libraries + ["jax.numpy"]
88

99
# `sparse` added array API support as of Python 3.10.
10-
# if sys.version_info >= (3, 10):
11-
# all_libraries.append('sparse')
10+
if sys.version_info >= (3, 10):
11+
all_libraries.append('sparse')
1212

1313
def import_(library, wrapper=False):
1414
if library == 'cupy':

tests/test_all.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,5 @@ def test_all(library):
4040
all_names = module.__all__
4141

4242
if set(dir_names) != set(all_names):
43-
assert set(dir_names) - set(all_names) == set(), f"Failed in library '{library}', some dir() names not included in __all__ for {mod_name}"
44-
assert set(all_names) - set(dir_names) == set(), f"Failed in library '{library}', some __all__ names not in dir() for {mod_name}"
43+
assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
44+
assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"

tests/test_array_namespace.py

Lines changed: 17 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
# import jax
66
import numpy as np
77
import pytest
8-
# import torch
8+
import torch
99
import paddle
1010

1111
import array_api_compat
@@ -73,11 +73,11 @@ def test_array_namespace(library, api_version, use_compat):
7373
"""
7474
subprocess.run([sys.executable, "-c", code], check=True)
7575

76-
# def test_jax_zero_gradient():
77-
# jx = jax.numpy.arange(4)
78-
# jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
79-
# assert (array_api_compat.get_namespace(jax_zero) is
80-
# array_api_compat.get_namespace(jx))
76+
def test_jax_zero_gradient():
77+
jx = jax.numpy.arange(4)
78+
jax_zero = jax.vmap(jax.grad(jax.numpy.float32, allow_int=True))(jx)
79+
assert (array_api_compat.get_namespace(jax_zero) is
80+
array_api_compat.get_namespace(jx))
8181

8282
def test_array_namespace_errors():
8383
pytest.raises(TypeError, lambda: array_namespace([1]))
@@ -87,53 +87,32 @@ def test_array_namespace_errors():
8787
pytest.raises(TypeError, lambda: array_namespace((x, x)))
8888
pytest.raises(TypeError, lambda: array_namespace(x, (x, x)))
8989

90-
# def test_array_namespace_errors_torch():
91-
# y = torch.asarray([1, 2])
92-
# x = np.asarray([1, 2])
93-
# pytest.raises(TypeError, lambda: array_namespace(x, y))
90+
def test_array_namespace_errors_torch():
91+
y = torch.asarray([1, 2])
92+
x = np.asarray([1, 2])
93+
pytest.raises(TypeError, lambda: array_namespace(x, y))
9494

9595

9696
def test_array_namespace_errors_paddle():
9797
y = paddle.to_tensor([1, 2])
9898
x = np.asarray([1, 2])
9999
pytest.raises(TypeError, lambda: array_namespace(x, y))
100100

101-
102-
# def test_api_version():
103-
# x = torch.asarray([1, 2])
104-
# torch_ = import_("torch", wrapper=True)
105-
# assert array_namespace(x, api_version="2023.12") == torch_
106-
# assert array_namespace(x, api_version=None) == torch_
107-
# assert array_namespace(x) == torch_
108-
# # Should issue a warning
109-
# with warnings.catch_warnings(record=True) as w:
110-
# assert array_namespace(x, api_version="2021.12") == torch_
111-
# assert len(w) == 1
112-
# assert "2021.12" in str(w[0].message)
113-
114-
# # Should issue a warning
115-
# with warnings.catch_warnings(record=True) as w:
116-
# assert array_namespace(x, api_version="2022.12") == torch_
117-
# assert len(w) == 1
118-
# assert "2022.12" in str(w[0].message)
119-
120-
# pytest.raises(ValueError, lambda: array_namespace(x, api_version="2020.12"))
121-
122101
def test_api_version():
123-
x = paddle.asarray([1, 2])
124-
paddle_ = import_("paddle", wrapper=True)
125-
assert array_namespace(x, api_version="2023.12") == paddle_
126-
assert array_namespace(x, api_version=None) == paddle_
127-
assert array_namespace(x) == paddle_
102+
x = torch.asarray([1, 2])
103+
torch_ = import_("torch", wrapper=True)
104+
assert array_namespace(x, api_version="2023.12") == torch_
105+
assert array_namespace(x, api_version=None) == torch_
106+
assert array_namespace(x) == torch_
128107
# Should issue a warning
129108
with warnings.catch_warnings(record=True) as w:
130-
assert array_namespace(x, api_version="2021.12") == paddle_
109+
assert array_namespace(x, api_version="2021.12") == torch_
131110
assert len(w) == 1
132111
assert "2021.12" in str(w[0].message)
133112

134113
# Should issue a warning
135114
with warnings.catch_warnings(record=True) as w:
136-
assert array_namespace(x, api_version="2022.12") == paddle_
115+
assert array_namespace(x, api_version="2022.12") == torch_
137116
assert len(w) == 1
138117
assert "2022.12" in str(w[0].message)
139118

tests/test_common.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,21 @@
1616

1717
is_array_functions = {
1818
'numpy': 'is_numpy_array',
19-
# 'cupy': 'is_cupy_array',
19+
'cupy': 'is_cupy_array',
2020
'torch': 'is_torch_array',
21-
# 'dask.array': 'is_dask_array',
22-
# 'jax.numpy': 'is_jax_array',
23-
# 'sparse': 'is_pydata_sparse_array',
21+
'dask.array': 'is_dask_array',
22+
'jax.numpy': 'is_jax_array',
23+
'sparse': 'is_pydata_sparse_array',
2424
'paddle': 'is_paddle_array',
2525
}
2626

2727
is_namespace_functions = {
2828
'numpy': 'is_numpy_namespace',
29-
# 'cupy': 'is_cupy_namespace',
29+
'cupy': 'is_cupy_namespace',
3030
'torch': 'is_torch_namespace',
31-
# 'dask.array': 'is_dask_namespace',
32-
# 'jax.numpy': 'is_jax_namespace',
33-
# 'sparse': 'is_pydata_sparse_namespace',
31+
'dask.array': 'is_dask_namespace',
32+
'jax.numpy': 'is_jax_namespace',
33+
'sparse': 'is_pydata_sparse_namespace',
3434
'paddle': 'is_paddle_namespace',
3535
}
3636

tests/test_no_dependencies.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def _test_dependency(mod):
5151

5252
@pytest.mark.parametrize("library",
5353
[
54-
"numpy",
55-
"paddle", "array_api_strict",
54+
"numpy", "cupy", "numpy", "torch", "dask.array",
55+
"jax.numpy", "sparse", "paddle", "array_api_strict"
5656
]
5757
)
5858
def test_numpy_dependency(library):

tests/test_vendoring.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,23 +7,23 @@ def test_vendoring_numpy():
77
uses_numpy._test_numpy()
88

99

10-
# def test_vendoring_cupy():
11-
# pytest.importorskip("cupy")
10+
def test_vendoring_cupy():
11+
pytest.importorskip("cupy")
1212

13-
# from vendor_test import uses_cupy
13+
from vendor_test import uses_cupy
1414

15-
# uses_cupy._test_cupy()
15+
uses_cupy._test_cupy()
1616

1717

18-
# def test_vendoring_torch():
19-
# from vendor_test import uses_torch
18+
def test_vendoring_torch():
19+
from vendor_test import uses_torch
2020

21-
# uses_torch._test_torch()
21+
uses_torch._test_torch()
2222

2323

24-
# def test_vendoring_dask():
25-
# from vendor_test import uses_dask
26-
# uses_dask._test_dask()
24+
def test_vendoring_dask():
25+
from vendor_test import uses_dask
26+
uses_dask._test_dask()
2727

2828

2929
def test_vendoring_paddle():

0 commit comments

Comments
 (0)