Skip to content

Commit c30b59e

Browse files
committed
Adds pytorch to get_namespace
1 parent b143d02 commit c30b59e

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

array_api_compat/common/_helpers.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ def get_namespace(*xs, _use_compat=True):
7575
else:
7676
import cupy as cp
7777
namespaces.add(cp)
78+
elif _is_torch_array(x):
79+
if _use_compat:
80+
from .. import torch as torch_namespace
81+
namespaces.add(torch_namespace)
82+
else:
83+
import torch
84+
namespaces.add(torch)
7885
else:
7986
# TODO: Support Python scalars?
8087
raise ValueError("The input is not a supported array type")

tests/test_get_namespace.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import array_api_compat
2+
import pytest
3+
4+
5+
@pytest.mark.parametrize("library", ["cupy", "numpy", "torch"])
6+
def test_get_namespace(library):
7+
lib = pytest.importorskip(library)
8+
9+
array = lib.asarray([1.0, 2.0, 3.0])
10+
namespace = array_api_compat.get_namespace(array)
11+
12+
expected_namespace = getattr(array_api_compat, library)
13+
assert namespace is expected_namespace
14+

0 commit comments

Comments
 (0)