Skip to content

Commit 3dd1f80

Browse files
committed
Remove useless order kwarg from Sort and add numpy 2.0 stable kwarg
1 parent 35f4706 commit 3dd1f80

File tree

3 files changed

+75
-38
lines changed

3 files changed

+75
-38
lines changed

pytensor/link/jax/dispatch/sort.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66

77
@jax_funcify.register(SortOp)
88
def jax_funcify_Sort(op, **kwargs):
9+
stable = op.kind == "stable"
10+
911
def sort(arr, axis):
10-
return jnp.sort(arr, axis=axis)
12+
return jnp.sort(arr, axis=axis, stable=stable)
1113

1214
return sort

pytensor/tensor/sort.py

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import typing
2+
13
import numpy as np
24

35
from pytensor.gradient import grad_undefined
@@ -9,20 +11,34 @@
911
from pytensor.tensor.type import TensorType
1012

1113

14+
KIND = typing.Literal["quicksort", "mergesort", "heapsort", "stable"]
15+
KIND_VALUES = typing.get_args(KIND)
16+
17+
18+
def _parse_sort_args(kind: KIND | None, order, stable: bool | None) -> KIND:
19+
if order is not None:
20+
raise ValueError("The order argument is not applicable to PyTensor graphs")
21+
if stable is not None and kind is not None:
22+
raise ValueError("kind and stable cannot be set at the same time")
23+
if stable:
24+
kind = "stable"
25+
elif kind is None:
26+
kind = "quicksort"
27+
if kind not in KIND_VALUES:
28+
raise ValueError(f"kind must be one of {KIND_VALUES}, got {kind}")
29+
return kind
30+
31+
1232
class SortOp(Op):
1333
"""
1434
This class is a wrapper for numpy sort function.
1535
1636
"""
1737

18-
__props__ = ("kind", "order")
38+
__props__ = ("kind",)
1939

20-
def __init__(self, kind, order=None):
40+
def __init__(self, kind: KIND):
2141
self.kind = kind
22-
self.order = order
23-
24-
def __str__(self):
25-
return self.__class__.__name__ + f"{{{self.kind}, {self.order}}}"
2642

2743
def make_node(self, input, axis=-1):
2844
input = as_tensor_variable(input)
@@ -33,7 +49,7 @@ def make_node(self, input, axis=-1):
3349
def perform(self, node, inputs, output_storage):
3450
a, axis = inputs
3551
z = output_storage[0]
36-
z[0] = np.sort(a, int(axis), self.kind, self.order)
52+
z[0] = np.sort(a, int(axis), self.kind)
3753

3854
def infer_shape(self, fgraph, node, inputs_shapes):
3955
assert node.inputs[0].ndim == node.outputs[0].ndim
@@ -75,9 +91,9 @@ def __get_argsort_indices(self, a, axis):
7591

7692
# The goal is to get gradient wrt input from gradient
7793
# wrt sort(input, axis)
78-
idx = argsort(a, axis, kind=self.kind, order=self.order)
94+
idx = argsort(a, axis, kind=self.kind)
7995
# rev_idx is the reverse of previous argsort operation
80-
rev_idx = argsort(idx, axis, kind=self.kind, order=self.order)
96+
rev_idx = argsort(idx, axis, kind=self.kind)
8197
indices = []
8298
axis_data = switch(ge(axis.data, 0), axis.data, a.ndim + axis.data)
8399
for i in range(a.ndim):
@@ -101,7 +117,9 @@ def R_op(self, inputs, eval_points):
101117
"""
102118

103119

104-
def sort(a, axis=-1, kind="quicksort", order=None):
120+
def sort(
121+
a, axis=-1, kind: KIND | None = None, order=None, *, stable: bool | None = None
122+
):
105123
"""
106124
107125
Parameters
@@ -111,23 +129,25 @@ def sort(a, axis=-1, kind="quicksort", order=None):
111129
axis: TensorVariable
112130
Axis along which to sort. If None, the array is flattened before
113131
sorting.
114-
kind: {'quicksort', 'mergesort', 'heapsort'}, optional
115-
Sorting algorithm. Default is 'quicksort'.
132+
kind: {'quicksort', 'mergesort', 'heapsort' 'stable'}, optional
133+
Sorting algorithm. Default is 'quicksort' unless stable is defined.
116134
order: list, optional
117-
When `a` is a structured array, this argument specifies which
118-
fields to compare first, second, and so on. This list does not
119-
need to include all of the fields.
135+
For compatibility with numpy sort signature. Cannot be specified.
136+
stable: bool, optional
137+
Same as specifying kind = 'stable'. Cannot be specified at the same time as kind
120138
121139
Returns
122140
-------
123141
array
124142
A sorted copy of an array.
125143
126144
"""
145+
kind = _parse_sort_args(kind, order, stable)
146+
127147
if axis is None:
128148
a = a.flatten()
129149
axis = 0
130-
return SortOp(kind, order)(a, axis)
150+
return SortOp(kind)(a, axis)
131151

132152

133153
class ArgSortOp(Op):
@@ -136,14 +156,10 @@ class ArgSortOp(Op):
136156
137157
"""
138158

139-
__props__ = ("kind", "order")
159+
__props__ = ("kind",)
140160

141-
def __init__(self, kind, order=None):
161+
def __init__(self, kind: KIND):
142162
self.kind = kind
143-
self.order = order
144-
145-
def __str__(self):
146-
return self.__class__.__name__ + f"{{{self.kind}, {self.order}}}"
147163

148164
def make_node(self, input, axis=-1):
149165
input = as_tensor_variable(input)
@@ -158,7 +174,7 @@ def perform(self, node, inputs, output_storage):
158174
a, axis = inputs
159175
z = output_storage[0]
160176
z[0] = _asarray(
161-
np.argsort(a, int(axis), self.kind, self.order),
177+
np.argsort(a, int(axis), self.kind),
162178
dtype=node.outputs[0].dtype,
163179
)
164180

@@ -192,7 +208,9 @@ def R_op(self, inputs, eval_points):
192208
"""
193209

194210

195-
def argsort(a, axis=-1, kind="quicksort", order=None):
211+
def argsort(
212+
a, axis=-1, kind: KIND | None = None, order=None, stable: bool | None = None
213+
):
196214
"""
197215
Returns the indices that would sort an array.
198216
@@ -202,7 +220,8 @@ def argsort(a, axis=-1, kind="quicksort", order=None):
202220
order.
203221
204222
"""
223+
kind = _parse_sort_args(kind, order, stable)
205224
if axis is None:
206225
a = a.flatten()
207226
axis = 0
208-
return ArgSortOp(kind, order)(a, axis)
227+
return ArgSortOp(kind)(a, axis)

tests/tensor/test_sort.py

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import pytest
23

34
import pytensor
45
from pytensor.tensor.sort import ArgSortOp, SortOp, argsort, sort
@@ -65,13 +66,12 @@ def test4(self):
6566
utt.assert_allclose(gv, gt)
6667

6768
def test5(self):
68-
a1 = SortOp("mergesort", [])
69-
a2 = SortOp("quicksort", [])
69+
a1 = SortOp("mergesort")
70+
a2 = SortOp("quicksort")
7071

71-
# All the below should give true
7272
assert a1 != a2
73-
assert a1 == SortOp("mergesort", [])
74-
assert a2 == SortOp("quicksort", [])
73+
assert a1 == SortOp("mergesort")
74+
assert a2 == SortOp("quicksort")
7575

7676
def test_None(self):
7777
a = dmatrix()
@@ -208,14 +208,11 @@ def test_argsort():
208208
utt.assert_allclose(gv, gt)
209209

210210
# Example 5
211-
a = dmatrix()
212-
axis = lscalar()
213-
a1 = ArgSortOp("mergesort", [])
214-
a2 = ArgSortOp("quicksort", [])
215-
# All the below should give true
211+
a1 = ArgSortOp("mergesort")
212+
a2 = ArgSortOp("quicksort")
216213
assert a1 != a2
217-
assert a1 == ArgSortOp("mergesort", [])
218-
assert a2 == ArgSortOp("quicksort", [])
214+
assert a1 == ArgSortOp("mergesort")
215+
assert a2 == ArgSortOp("quicksort")
219216

220217
# Example 6: Testing axis=None
221218
a = dmatrix()
@@ -237,3 +234,22 @@ def test_argsort_grad():
237234

238235
data = rng.random((2, 3, 3)).astype(pytensor.config.floatX)
239236
utt.verify_grad(lambda x: argsort(x, axis=2), [data])
237+
238+
239+
@pytest.mark.parametrize("func", (sort, argsort))
240+
def test_parse_sort_args(func):
241+
x = matrix("x")
242+
243+
assert func(x).owner.op.kind == "quicksort"
244+
assert func(x, stable=True).owner.op.kind == "stable"
245+
246+
with pytest.raises(ValueError, match="kind must be one of"):
247+
func(x, kind="hanoi")
248+
249+
with pytest.raises(
250+
ValueError, match="kind and stable cannot be set at the same time"
251+
):
252+
func(x, kind="quicksort", stable=True)
253+
254+
with pytest.raises(ValueError, match="order argument is not applicable"):
255+
func(x, order=[])

0 commit comments

Comments
 (0)