Skip to content

Commit 1eace5d

Browse files
committed
Update the function stubs with device support
1 parent d80641c commit 1eace5d

File tree

3 files changed

+17
-14
lines changed

3 files changed

+17
-14
lines changed

array_api_tests/function_stubs/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99

1010
__all__ = []
1111

12-
from .array_object import __abs__, __add__, __and__, __eq__, __floordiv__, __ge__, __getitem__, __gt__, __invert__, __le__, __len__, __lshift__, __lt__, __matmul__, __mod__, __mul__, __ne__, __neg__, __or__, __pos__, __pow__, __rshift__, __setitem__, __sub__, __truediv__, __xor__, dtype, ndim, shape, size, T
12+
from .array_object import __abs__, __add__, __and__, __eq__, __floordiv__, __ge__, __getitem__, __gt__, __invert__, __le__, __len__, __lshift__, __lt__, __matmul__, __mod__, __mul__, __ne__, __neg__, __or__, __pos__, __pow__, __rshift__, __setitem__, __sub__, __truediv__, __xor__, dtype, device, ndim, shape, size, T
1313

14-
__all__ += ['__abs__', '__add__', '__and__', '__eq__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', 'dtype', 'ndim', 'shape', 'size', 'T']
14+
__all__ += ['__abs__', '__add__', '__and__', '__eq__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', 'dtype', 'device', 'ndim', 'shape', 'size', 'T']
1515

1616
from .constants import e, inf, nan, pi
1717

array_api_tests/function_stubs/array_object.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ def __xor__(x1, x2):
173173
# Note: dtype is an attribute of the array object.
174174
dtype = None
175175

176+
# Note: device is an attribute of the array object.
177+
device = None
178+
176179
# Note: ndim is an attribute of the array object.
177180
ndim = None
178181

@@ -185,4 +188,4 @@ def __xor__(x1, x2):
185188
# Note: T is an attribute of the array object.
186189
T = None
187190

188-
__all__ = ['__abs__', '__add__', '__and__', '__eq__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', 'dtype', 'ndim', 'shape', 'size', 'T']
191+
__all__ = ['__abs__', '__add__', '__and__', '__eq__', '__floordiv__', '__ge__', '__getitem__', '__gt__', '__invert__', '__le__', '__len__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__mul__', '__ne__', '__neg__', '__or__', '__pos__', '__pow__', '__rshift__', '__setitem__', '__sub__', '__truediv__', '__xor__', 'dtype', 'device', 'ndim', 'shape', 'size', 'T']

array_api_tests/function_stubs/creation_functions.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,37 +14,37 @@
1414
2. There is no real way to test that anyway.
1515
"""
1616

17-
def arange(start, *, stop=None, step=1, dtype=None):
17+
def arange(start, *, stop=None, step=1, dtype=None, device=None):
1818
pass
1919

20-
def empty(shape, *, dtype=None):
20+
def empty(shape, *, dtype=None, device=None):
2121
pass
2222

23-
def empty_like(x, *, dtype=None):
23+
def empty_like(x, *, dtype=None, device=None):
2424
pass
2525

26-
def eye(N, *, M=None, k=0, dtype=None):
26+
def eye(N, *, M=None, k=0, dtype=None, device=None):
2727
pass
2828

29-
def full(shape, fill_value, *, dtype=None):
29+
def full(shape, fill_value, *, dtype=None, device=None):
3030
pass
3131

32-
def full_like(x, fill_value, *, dtype=None):
32+
def full_like(x, fill_value, *, dtype=None, device=None):
3333
pass
3434

35-
def linspace(start, stop, num, *, dtype=None, endpoint=True):
35+
def linspace(start, stop, num, *, dtype=None, device=None, endpoint=True):
3636
pass
3737

38-
def ones(shape, *, dtype=None):
38+
def ones(shape, *, dtype=None, device=None):
3939
pass
4040

41-
def ones_like(x, *, dtype=None):
41+
def ones_like(x, *, dtype=None, device=None):
4242
pass
4343

44-
def zeros(shape, *, dtype=None):
44+
def zeros(shape, *, dtype=None, device=None):
4545
pass
4646

47-
def zeros_like(x, *, dtype=None):
47+
def zeros_like(x, *, dtype=None, device=None):
4848
pass
4949

5050
__all__ = ['arange', 'empty', 'empty_like', 'eye', 'full', 'full_like', 'linspace', 'ones', 'ones_like', 'zeros', 'zeros_like']

0 commit comments

Comments
 (0)