Skip to content

Commit 2c4157c

Browse files
committed
Calls __array_namespace__ when input already adopts ArrayAPI spec
1 parent 945609e commit 2c4157c

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

array_api_compat/common/_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def get_namespace(*xs, _use_compat=True):
6060
if isinstance(x, (tuple, list)):
6161
namespaces.add(get_namespace(*x, _use_compat=_use_compat))
6262
elif hasattr(x, '__array_namespace__'):
63-
namespaces.add(x.__array_namespace__)
63+
namespaces.add(x.__array_namespace__())
6464
elif _is_numpy_array(x):
6565
if _use_compat:
6666
from .. import numpy as numpy_namespace

tests/test_get_namespace.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,10 @@ def test_get_namespace(library):
1212
expected_namespace = getattr(array_api_compat, library)
1313
assert namespace is expected_namespace
1414

15+
16+
@pytest.mark.parametrize("array_namespace", ["cupy.array_api", "numpy.array_api"])
17+
def test_get_namespace_returns_actual_namespace(array_namespace):
18+
xp = pytest.importorskip(array_namespace)
19+
X = xp.asarray([1, 2, 3])
20+
xp_ = array_api_compat.get_namespace(X)
21+
assert xp_ is xp

0 commit comments

Comments
 (0)