1
1
import numpy as np
2
2
3
3
from pytensor .gradient import grad_undefined
4
- from pytensor .graph .basic import Apply , Constant
4
+ from pytensor .graph .basic import Apply
5
5
from pytensor .graph .op import Op
6
6
from pytensor .misc .safe_asarray import _asarray
7
7
from pytensor .tensor .basic import arange , as_tensor_variable , switch
8
- from pytensor .tensor .math import eq , ge , mul
8
+ from pytensor .tensor .math import eq , ge
9
9
from pytensor .tensor .type import TensorType
10
10
11
11
12
- def _variable_is_none (var ):
13
- return isinstance (var , Constant ) and var .data is None
14
-
15
-
16
- def _check_tensor_is_scalar (var ):
17
- """
18
- Checks if a tensor variable is scalar, raise ValueError otherwise
19
- """
20
- msg = "%(var)s is expected to be 0d tensor, got %(ndim)d"
21
- if var .ndim != 0 :
22
- raise ValueError (msg % (var , var .ndim ))
23
-
24
-
25
12
class SortOp (Op ):
26
13
"""
27
14
This class is a wrapper for numpy sort function.
@@ -39,28 +26,16 @@ def __str__(self):
39
26
40
27
def make_node(self, input, axis=-1):
    """Build the Apply node that sorts `input` along the scalar `axis`."""
    tensor = as_tensor_variable(input)
    # Coerce axis to a 0-d integer tensor so perform() can call int(axis).
    scalar_axis = as_tensor_variable(axis, ndim=0, dtype=int)
    # Sorting preserves dtype and shape, so the output reuses the input type.
    return Apply(self, [tensor, scalar_axis], [tensor.type()])
45
32
46
33
def perform(self, node, inputs, output_storage):
    """Execute the sort by delegating to ``np.sort`` with this Op's settings."""
    arr, ax = inputs
    # ax arrives as a 0-d integer tensor; numpy needs a plain Python int.
    output_storage[0][0] = np.sort(arr, int(ax), self.kind, self.order)
55
37
56
38
def infer_shape(self, fgraph, node, inputs_shapes):
    """Sorting along an axis leaves the input's shape unchanged."""
    in_shape, axis_shape = inputs_shapes
    # axis is a concrete 0-d tensor, so input and output ranks must agree.
    assert axis_shape == ()
    assert node.inputs[0].ndim == node.outputs[0].ndim
    return [in_shape]
@@ -172,30 +147,22 @@ def __str__(self):
172
147
173
148
def make_node(self, input, axis=-1):
    """Build the Apply node producing int64 argsort indices."""
    tensor = as_tensor_variable(input)
    # Coerce axis to a 0-d integer tensor so perform() can call int(axis).
    scalar_axis = as_tensor_variable(axis, ndim=0, dtype=int)
    # Indices share the input's static shape but are always int64.
    out_var = TensorType(dtype="int64", shape=tensor.type.shape)()
    return Apply(self, [tensor, scalar_axis], [out_var])
181
156
182
157
def perform(self, node, inputs, output_storage):
    """Compute argsort indices and cast them to the declared output dtype."""
    arr, ax = inputs
    # ax arrives as a 0-d integer tensor; numpy needs a plain Python int.
    indices = np.argsort(arr, int(ax), self.kind, self.order)
    # _asarray enforces the int64 dtype declared by make_node.
    output_storage[0][0] = _asarray(indices, dtype=node.outputs[0].dtype)
193
164
194
165
def infer_shape(self, fgraph, node, inputs_shapes):
    """Argsort along an axis yields indices with the input's shape."""
    in_shape, axis_shape = inputs_shapes
    # axis is a concrete 0-d tensor, so input and output ranks must agree.
    assert axis_shape == ()
    assert node.inputs[0].ndim == node.outputs[0].ndim
    return [in_shape]
@@ -239,66 +206,3 @@ def argsort(a, axis=-1, kind="quicksort", order=None):
239
206
a = a .flatten ()
240
207
axis = 0
241
208
return ArgSortOp (kind , order )(a , axis )
242
-
243
-
244
- def _topk_py_impl (op , x , k , axis , idx_dtype ):
245
- ndim = x .ndim
246
- assert - ndim <= axis < ndim
247
- axis %= ndim
248
- if k == 0 :
249
- raise ValueError ("topk: kth cannot be zero" )
250
- elif k > x .shape [axis ]:
251
- raise ValueError (
252
- f"topk: kth cannot be larger than the size of specified axis { int (axis )} "
253
- )
254
- if abs (k ) == 1 :
255
- # negative k means min instead of max
256
- fn_max = [None , np .max , np .min ][k ]
257
- fn_argmax = [None , np .argmax , np .argmin ][k ]
258
- if not op .return_indices :
259
- return np .expand_dims (fn_max (x , axis = axis ), axis )
260
- elif op .return_values :
261
- zi = np .expand_dims (fn_argmax (x , axis = axis ), axis )
262
- idx2 = tuple (
263
- np .arange (s ).reshape ((s ,) + (1 ,) * (ndim - i - 1 )) if i != axis else zi
264
- for i , s in enumerate (x .shape )
265
- )
266
- zv = x [idx2 ]
267
- return zv , zi .astype (idx_dtype )
268
- else :
269
- zi = np .expand_dims (fn_argmax (x , axis = axis ), axis )
270
- return zi .astype (idx_dtype )
271
-
272
- if x .shape [axis ] == abs (k ):
273
- if not op .return_indices :
274
- return x .copy ()
275
- else :
276
- l = axis
277
- r = ndim - l
278
- reps = list (x .shape )
279
- reps [axis ] = 1
280
- zi = np .arange (abs (k ), dtype = idx_dtype )
281
- zi = zi .reshape ((1 ,) * l + (k ,) + (1 ,) * (r - 1 ))
282
- zi = np .tile (zi , reps )
283
- if op .return_values :
284
- return x .copy (), zi
285
- else :
286
- return zi
287
-
288
- idx = [slice (None )] * ndim
289
- idx [axis ] = slice (- k , None ) if k > 0 else slice (- k )
290
-
291
- if not op .return_indices :
292
- zv = np .partition (x , - k , axis = axis )[tuple (idx )]
293
- return zv
294
- elif op .return_values :
295
- zi = np .argpartition (x , - k , axis = axis )[tuple (idx )]
296
- idx2 = tuple (
297
- np .arange (s ).reshape ((s ,) + (1 ,) * (ndim - i - 1 )) if i != axis else zi
298
- for i , s in enumerate (x .shape )
299
- )
300
- zv = x [idx2 ]
301
- return zv , zi .astype (idx_dtype )
302
- else :
303
- zi = np .argpartition (x , - k , axis = axis )[tuple (idx )]
304
- return zi .astype (idx_dtype )
0 commit comments