1
+ import typing
2
+
1
3
import numpy as np
2
4
3
5
from pytensor .gradient import grad_undefined
9
11
from pytensor .tensor .type import TensorType
10
12
11
13
14
+ KIND = typing .Literal ["quicksort" , "mergesort" , "heapsort" , "stable" ]
15
+ KIND_VALUES = typing .get_args (KIND )
16
+
17
+
18
+ def _parse_sort_args (kind : KIND | None , order , stable : bool | None ) -> KIND :
19
+ if order is not None :
20
+ raise ValueError ("The order argument is not applicable to PyTensor graphs" )
21
+ if stable is not None and kind is not None :
22
+ raise ValueError ("kind and stable cannot be set at the same time" )
23
+ if stable :
24
+ kind = "stable"
25
+ elif kind is None :
26
+ kind = "quicksort"
27
+ if kind not in KIND_VALUES :
28
+ raise ValueError (f"kind must be one of { KIND_VALUES } , got { kind } " )
29
+ return kind
30
+
31
+
12
32
class SortOp (Op ):
13
33
"""
14
34
This class is a wrapper for numpy sort function.
15
35
16
36
"""
17
37
18
- __props__ = ("kind" , "order" )
38
+ __props__ = ("kind" ,)
19
39
20
- def __init__ (self , kind , order = None ):
40
+ def __init__ (self , kind : KIND ):
21
41
self .kind = kind
22
- self .order = order
23
-
24
- def __str__ (self ):
25
- return self .__class__ .__name__ + f"{{{ self .kind } , { self .order } }}"
26
42
27
43
def make_node (self , input , axis = - 1 ):
28
44
input = as_tensor_variable (input )
@@ -33,7 +49,7 @@ def make_node(self, input, axis=-1):
33
49
def perform (self , node , inputs , output_storage ):
34
50
a , axis = inputs
35
51
z = output_storage [0 ]
36
- z [0 ] = np .sort (a , int (axis ), self .kind , self . order )
52
+ z [0 ] = np .sort (a , int (axis ), self .kind )
37
53
38
54
def infer_shape (self , fgraph , node , inputs_shapes ):
39
55
assert node .inputs [0 ].ndim == node .outputs [0 ].ndim
@@ -75,9 +91,9 @@ def __get_argsort_indices(self, a, axis):
75
91
76
92
# The goal is to get gradient wrt input from gradient
77
93
# wrt sort(input, axis)
78
- idx = argsort (a , axis , kind = self .kind , order = self . order )
94
+ idx = argsort (a , axis , kind = self .kind )
79
95
# rev_idx is the reverse of previous argsort operation
80
- rev_idx = argsort (idx , axis , kind = self .kind , order = self . order )
96
+ rev_idx = argsort (idx , axis , kind = self .kind )
81
97
indices = []
82
98
axis_data = switch (ge (axis .data , 0 ), axis .data , a .ndim + axis .data )
83
99
for i in range (a .ndim ):
@@ -101,7 +117,9 @@ def R_op(self, inputs, eval_points):
101
117
"""
102
118
103
119
104
- def sort (a , axis = - 1 , kind = "quicksort" , order = None ):
120
+ def sort (
121
+ a , axis = - 1 , kind : KIND | None = None , order = None , * , stable : bool | None = None
122
+ ):
105
123
"""
106
124
107
125
Parameters
@@ -111,23 +129,25 @@ def sort(a, axis=-1, kind="quicksort", order=None):
111
129
axis: TensorVariable
112
130
Axis along which to sort. If None, the array is flattened before
113
131
sorting.
114
- kind: {'quicksort', 'mergesort', 'heapsort'}, optional
115
- Sorting algorithm. Default is 'quicksort'.
132
+ kind: {'quicksort', 'mergesort', 'heapsort' 'stable' }, optional
133
+ Sorting algorithm. Default is 'quicksort' unless stable is defined .
116
134
order: list, optional
117
- When `a` is a structured array, this argument specifies which
118
- fields to compare first, second, and so on. This list does not
119
- need to include all of the fields.
135
+ For compatibility with numpy sort signature. Cannot be specified.
136
+ stable: bool, optional
137
+ Same as specifying kind = 'stable'. Cannot be specified at the same time as kind
120
138
121
139
Returns
122
140
-------
123
141
array
124
142
A sorted copy of an array.
125
143
126
144
"""
145
+ kind = _parse_sort_args (kind , order , stable )
146
+
127
147
if axis is None :
128
148
a = a .flatten ()
129
149
axis = 0
130
- return SortOp (kind , order )(a , axis )
150
+ return SortOp (kind )(a , axis )
131
151
132
152
133
153
class ArgSortOp (Op ):
@@ -136,14 +156,10 @@ class ArgSortOp(Op):
136
156
137
157
"""
138
158
139
- __props__ = ("kind" , "order" )
159
+ __props__ = ("kind" ,)
140
160
141
- def __init__ (self , kind , order = None ):
161
+ def __init__ (self , kind : KIND ):
142
162
self .kind = kind
143
- self .order = order
144
-
145
- def __str__ (self ):
146
- return self .__class__ .__name__ + f"{{{ self .kind } , { self .order } }}"
147
163
148
164
def make_node (self , input , axis = - 1 ):
149
165
input = as_tensor_variable (input )
@@ -158,7 +174,7 @@ def perform(self, node, inputs, output_storage):
158
174
a , axis = inputs
159
175
z = output_storage [0 ]
160
176
z [0 ] = _asarray (
161
- np .argsort (a , int (axis ), self .kind , self . order ),
177
+ np .argsort (a , int (axis ), self .kind ),
162
178
dtype = node .outputs [0 ].dtype ,
163
179
)
164
180
@@ -192,7 +208,9 @@ def R_op(self, inputs, eval_points):
192
208
"""
193
209
194
210
195
- def argsort (a , axis = - 1 , kind = "quicksort" , order = None ):
211
+ def argsort (
212
+ a , axis = - 1 , kind : KIND | None = None , order = None , stable : bool | None = None
213
+ ):
196
214
"""
197
215
Returns the indices that would sort an array.
198
216
@@ -202,7 +220,8 @@ def argsort(a, axis=-1, kind="quicksort", order=None):
202
220
order.
203
221
204
222
"""
223
+ kind = _parse_sort_args (kind , order , stable )
205
224
if axis is None :
206
225
a = a .flatten ()
207
226
axis = 0
208
- return ArgSortOp (kind , order )(a , axis )
227
+ return ArgSortOp (kind )(a , axis )
0 commit comments