Skip to content

Commit 03e1ae7

Browse files
committed
Add multi-device support to sorting functions
1 parent 405b7e7 commit 03e1ae7

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

array_api_strict/_sorting_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def argsort(
3333
normalised_axis = axis if axis >= 0 else x.ndim + axis
3434
max_i = x.shape[normalised_axis] - 1
3535
res = max_i - res
36-
return Array._new(res)
36+
return Array._new(res, device=x.device)
3737

3838
# Note: the descending keyword argument is new in this function
3939
def sort(
@@ -51,4 +51,4 @@ def sort(
5151
res = np.sort(x._array, axis=axis, kind=kind)
5252
if descending:
5353
res = np.flip(res, axis=axis)
54-
return Array._new(res)
54+
return Array._new(res, device=x.device)

array_api_strict/tests/test_sorting_functions.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,17 @@ def test_stable_desc_argsort(obj, axis, expected):
2121
x = xp.asarray(obj)
2222
out = xp.argsort(x, axis=axis, stable=True, descending=True)
2323
assert xp.all(out == xp.asarray(expected))
24+
25+
26+
def test_argsort_device():
27+
x = xp.asarray([1., 2., -1., 3.141], device=xp.Device("device1"))
28+
y = xp.argsort(x)
29+
30+
assert y.device == x.device
31+
32+
33+
def test_sort_device():
34+
x = xp.asarray([1., 2., -1., 3.141], device=xp.Device("device1"))
35+
y = xp.sort(x)
36+
37+
assert y.device == x.device

0 commit comments

Comments
 (0)