Skip to content

Commit 9f2afe8

Browse files
committed
Add a get_xp decorator to support multiple array namespaces
1 parent 391b08b commit 9f2afe8

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

numpy_array_api_compat/_helpers.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def is_array_api_obj(x):
1919
"""
2020
return _is_numpy_array(x) or hasattr(x, '__array_namespace__')
2121

22-
def get_namespace(*xs):
22+
def get_namespace(*xs, _use_compat=True):
2323
"""
2424
Get the array API compatible namespace for the arrays `xs`.
2525
@@ -30,7 +30,10 @@ def get_namespace(*xs):
3030
if hasattr(x, '__array_namespace__'):
3131
namespaces.add(x.__array_namespace__)
3232
elif _is_numpy_array(x):
33-
namespaces.add(compat_namespace)
33+
if _use_compat:
34+
namespaces.add(compat_namespace)
35+
else:
36+
namespaces.add(np)
3437
else:
3538
# TODO: Support Python scalars?
3639
raise ValueError("The input is not a supported array type")

numpy_array_api_compat/_internal.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""
2+
Internal helpers
3+
"""
4+
5+
from functools import wraps
6+
from inspect import signature
7+
8+
from ._helpers import get_namespace
9+
10+
def get_xp(f):
11+
"""
12+
Decorator to automatically replace xp with the corresponding array module
13+
14+
Use like
15+
16+
@get_xp
17+
def func(x, /, xp, kwarg=None):
18+
return xp.func(x, kwarg=kwarg)
19+
20+
Note that xp must be able to be passed as a keyword argument.
21+
"""
22+
@wraps(f)
23+
def inner(*args, **kwargs):
24+
xp = get_namespace(*args, _use_compat=False)
25+
return f(*args, xp=xp, **kwargs)
26+
27+
sig = signature(f)
28+
new_sig = sig.replace(parameters=[sig.parameters[i] for i in sig.parameters if i != 'xp'])
29+
30+
inner.__signature__ = new_sig
31+
32+
return inner

0 commit comments

Comments
 (0)