Skip to content

Commit 2737206

Browse files
committed
Fix tests
1 parent 88cacc4 commit 2737206

File tree

4 files changed

+37
-22
lines changed

4 files changed

+37
-22
lines changed

array_api_compat/_internal.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,29 @@ def wrapped_f(*args, **kwargs):
4444
return inner
4545

4646

47-
def _get_all_public_members(module, filter_=None):
48-
"""Get all public members of a module."""
49-
try:
50-
return getattr(module, '__all__')
51-
except AttributeError:
52-
pass
47+
def _get_all_public_members(module, exclude=None, extend_all=False):
48+
"""Get all public members of a module.
5349
54-
if filter_ is None:
55-
filter_ = lambda name: name.startswith('_') # noqa: E731
50+
Parameters
51+
----------
52+
module : module
53+
The module to get members from.
54+
exclude : callable, optional
55+
A callable that takes a name and returns True if the name should be
56+
excluded from the list of members.
57+
extend_all : bool, optional
58+
If True, extend the module's __all__ attribute with the members of the
59+
module derive from dir(module)
60+
"""
61+
members = getattr(module, '__all__', [])
62+
63+
if members and not extend_all:
64+
return members
65+
66+
if exclude is None:
67+
exclude = lambda name: name.startswith('_') # noqa: E731
68+
69+
members += [_ for _ in dir(module) if not exclude(_)]
5670

57-
return map(filter_, dir(module))
71+
# remove duplicates
72+
return list(set(members))

array_api_compat/torch/__init__.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,26 @@
11
# Several names are not included in the above import *
2-
import torch
2+
import torch as _torch
33
from torch import * # noqa: F401, F403
44

55
from .._internal import _get_all_public_members
66

77

8-
def filter_(name):
8+
def exlcude(name):
99
if (
1010
name.startswith("_")
1111
or name.endswith("_")
1212
or "cuda" in name
1313
or "cpu" in name
1414
or "backward" in name
1515
):
16-
return False
17-
return True
16+
return True
17+
return False
1818

1919

20-
_torch_all = _get_all_public_members(torch, filter_=filter_)
20+
_torch_all = _get_all_public_members(_torch, exclude=exlcude, extend_all=True)
2121

2222
for _name in _torch_all:
23-
globals()[_name] = getattr(torch, _name)
23+
globals()[_name] = getattr(_torch, _name)
2424

2525

2626
from ..common._helpers import ( # noqa: E402

tests/test_array_namespace.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,22 @@
55

66

77
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
8-
@pytest.mark.parametrize("api_version", [None, '2021.12'])
8+
@pytest.mark.parametrize("api_version", [None, "2021.12"])
99
def test_array_namespace(library, api_version):
1010
lib = pytest.importorskip(library)
1111

1212
array = lib.asarray([1.0, 2.0, 3.0])
1313
namespace = array_api_compat.array_namespace(array, api_version=api_version)
1414

15-
if 'array_api' in library:
15+
if "array_api" in library:
1616
assert namespace == lib
1717
else:
1818
assert namespace == getattr(array_api_compat, library)
1919

2020

2121
def test_array_namespace_errors():
2222
np = pytest.importorskip("numpy")
23-
23+
2424
pytest.raises(TypeError, lambda: array_namespace([1]))
2525
pytest.raises(TypeError, lambda: array_namespace())
2626

@@ -31,10 +31,12 @@ def test_array_namespace_errors():
3131

3232
def test_array_namespace_errors_torch():
3333
torch = pytest.importorskip("torch")
34-
34+
np = pytest.importorskip("numpy")
35+
3536
y = torch.asarray([1, 2])
37+
x = np.asarray([1, 2])
3638
pytest.raises(TypeError, lambda: array_namespace(x, y))
37-
pytest.raises(ValueError, lambda: array_namespace(x, api_version='2022.12'))
39+
pytest.raises(ValueError, lambda: array_namespace(x, api_version="2022.12"))
3840

3941

4042
def test_get_namespace():

tests/test_isdtype.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55

66
import pytest
77

8-
from ._helpers import import_
9-
108
# Check the known dtypes by their string names
119

1210
def _spec_dtypes(library):

0 commit comments

Comments
 (0)