File tree Expand file tree Collapse file tree 2 files changed +37
-2
lines changed Expand file tree Collapse file tree 2 files changed +37
-2
lines changed Original file line number Diff line number Diff line change @@ -19,7 +19,7 @@ def is_array_api_obj(x):
19
19
"""
20
20
return _is_numpy_array (x ) or hasattr (x , '__array_namespace__' )
21
21
22
- def get_namespace (* xs ):
22
+ def get_namespace (* xs , _use_compat = True ):
23
23
"""
24
24
Get the array API compatible namespace for the arrays `xs`.
25
25
@@ -30,7 +30,10 @@ def get_namespace(*xs):
30
30
if hasattr (x , '__array_namespace__' ):
31
31
namespaces .add (x .__array_namespace__ )
32
32
elif _is_numpy_array (x ):
33
- namespaces .add (compat_namespace )
33
+ if _use_compat :
34
+ namespaces .add (compat_namespace )
35
+ else :
36
+ namespaces .add (np )
34
37
else :
35
38
# TODO: Support Python scalars?
36
39
raise ValueError ("The input is not a supported array type" )
Original file line number Diff line number Diff line change
1
+ """
2
+ Internal helpers
3
+ """
4
+
5
+ from functools import wraps
6
+ from inspect import signature
7
+
8
+ from ._helpers import get_namespace
9
+
10
+ def get_xp (f ):
11
+ """
12
+ Decorator to automatically replace xp with the corresponding array module
13
+
14
+ Use like
15
+
16
+ @get_xp
17
+ def func(x, /, xp, kwarg=None):
18
+ return xp.func(x, kwarg=kwarg)
19
+
20
+ Note that xp must be able to be passed as a keyword argument.
21
+ """
22
+ @wraps (f )
23
+ def inner (* args , ** kwargs ):
24
+ xp = get_namespace (* args , _use_compat = False )
25
+ return f (* args , xp = xp , ** kwargs )
26
+
27
+ sig = signature (f )
28
+ new_sig = sig .replace (parameters = [sig .parameters [i ] for i in sig .parameters if i != 'xp' ])
29
+
30
+ inner .__signature__ = new_sig
31
+
32
+ return inner
You can’t perform that action at this time.
0 commit comments