Skip to content

Commit 2307d87

Browse files
committed
Remove unused or unreachable code in tensor/sort.py
1 parent 7159215 commit 2307d87

File tree

1 file changed

+9
-105
lines changed

1 file changed

+9
-105
lines changed

pytensor/tensor/sort.py

Lines changed: 9 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,14 @@
11
import numpy as np
22

33
from pytensor.gradient import grad_undefined
4-
from pytensor.graph.basic import Apply, Constant
4+
from pytensor.graph.basic import Apply
55
from pytensor.graph.op import Op
66
from pytensor.misc.safe_asarray import _asarray
77
from pytensor.tensor.basic import arange, as_tensor_variable, switch
8-
from pytensor.tensor.math import eq, ge, mul
8+
from pytensor.tensor.math import eq, ge
99
from pytensor.tensor.type import TensorType
1010

1111

12-
def _variable_is_none(var):
13-
return isinstance(var, Constant) and var.data is None
14-
15-
16-
def _check_tensor_is_scalar(var):
17-
"""
18-
Checks if a tensor variable is scalar, raise ValueError otherwise
19-
"""
20-
msg = "%(var)s is expected to be 0d tensor, got %(ndim)d"
21-
if var.ndim != 0:
22-
raise ValueError(msg % (var, var.ndim))
23-
24-
2512
class SortOp(Op):
2613
"""
2714
This class is a wrapper for numpy sort function.
@@ -39,28 +26,16 @@ def __str__(self):
3926

4027
def make_node(self, input, axis=-1):
4128
input = as_tensor_variable(input)
42-
axis = as_tensor_variable(axis)
29+
axis = as_tensor_variable(axis, ndim=0, dtype=int)
4330
out_type = input.type()
4431
return Apply(self, [input, axis], [out_type])
4532

4633
def perform(self, node, inputs, output_storage):
47-
a = inputs[0]
48-
axis = inputs[1]
49-
if axis is not None:
50-
if axis != int(axis):
51-
raise ValueError("sort axis must be an integer or None")
52-
axis = int(axis)
34+
a, axis = inputs
5335
z = output_storage[0]
54-
z[0] = np.sort(a, axis, self.kind, self.order)
36+
z[0] = np.sort(a, int(axis), self.kind, self.order)
5537

5638
def infer_shape(self, fgraph, node, inputs_shapes):
57-
if _variable_is_none(node.inputs[1]):
58-
# That means axis = None,
59-
# So the array is flattened before being sorted
60-
return [(mul(*inputs_shapes[0]),)]
61-
# axis should not be None
62-
# So there should be the same number of dimensions
63-
# in the input and output
6439
assert node.inputs[0].ndim == node.outputs[0].ndim
6540
assert inputs_shapes[1] == ()
6641
return [inputs_shapes[0]]
@@ -172,30 +147,22 @@ def __str__(self):
172147

173148
def make_node(self, input, axis=-1):
174149
input = as_tensor_variable(input)
175-
axis = as_tensor_variable(axis)
150+
axis = as_tensor_variable(axis, ndim=0, dtype=int)
176151
return Apply(
177152
self,
178153
[input, axis],
179154
[TensorType(dtype="int64", shape=input.type.shape)()],
180155
)
181156

182157
def perform(self, node, inputs, output_storage):
183-
a = inputs[0]
184-
axis = inputs[1]
185-
if axis is not None:
186-
if axis != int(axis):
187-
raise ValueError("sort axis must be an integer or None")
188-
axis = int(axis)
158+
a, axis = inputs
189159
z = output_storage[0]
190160
z[0] = _asarray(
191-
np.argsort(a, axis, self.kind, self.order), dtype=node.outputs[0].dtype
161+
np.argsort(a, int(axis), self.kind, self.order),
162+
dtype=node.outputs[0].dtype,
192163
)
193164

194165
def infer_shape(self, fgraph, node, inputs_shapes):
195-
if _variable_is_none(node.inputs[1]):
196-
return [(mul(*inputs_shapes[0]),)]
197-
# axis should not be None, so there should be the same number of
198-
# dimensions in the input and output
199166
assert node.inputs[0].ndim == node.outputs[0].ndim
200167
assert inputs_shapes[1] == ()
201168
return [inputs_shapes[0]]
@@ -239,66 +206,3 @@ def argsort(a, axis=-1, kind="quicksort", order=None):
239206
a = a.flatten()
240207
axis = 0
241208
return ArgSortOp(kind, order)(a, axis)
242-
243-
244-
def _topk_py_impl(op, x, k, axis, idx_dtype):
245-
ndim = x.ndim
246-
assert -ndim <= axis < ndim
247-
axis %= ndim
248-
if k == 0:
249-
raise ValueError("topk: kth cannot be zero")
250-
elif k > x.shape[axis]:
251-
raise ValueError(
252-
f"topk: kth cannot be larger than the size of specified axis {int(axis)}"
253-
)
254-
if abs(k) == 1:
255-
# negative k means min instead of max
256-
fn_max = [None, np.max, np.min][k]
257-
fn_argmax = [None, np.argmax, np.argmin][k]
258-
if not op.return_indices:
259-
return np.expand_dims(fn_max(x, axis=axis), axis)
260-
elif op.return_values:
261-
zi = np.expand_dims(fn_argmax(x, axis=axis), axis)
262-
idx2 = tuple(
263-
np.arange(s).reshape((s,) + (1,) * (ndim - i - 1)) if i != axis else zi
264-
for i, s in enumerate(x.shape)
265-
)
266-
zv = x[idx2]
267-
return zv, zi.astype(idx_dtype)
268-
else:
269-
zi = np.expand_dims(fn_argmax(x, axis=axis), axis)
270-
return zi.astype(idx_dtype)
271-
272-
if x.shape[axis] == abs(k):
273-
if not op.return_indices:
274-
return x.copy()
275-
else:
276-
l = axis
277-
r = ndim - l
278-
reps = list(x.shape)
279-
reps[axis] = 1
280-
zi = np.arange(abs(k), dtype=idx_dtype)
281-
zi = zi.reshape((1,) * l + (k,) + (1,) * (r - 1))
282-
zi = np.tile(zi, reps)
283-
if op.return_values:
284-
return x.copy(), zi
285-
else:
286-
return zi
287-
288-
idx = [slice(None)] * ndim
289-
idx[axis] = slice(-k, None) if k > 0 else slice(-k)
290-
291-
if not op.return_indices:
292-
zv = np.partition(x, -k, axis=axis)[tuple(idx)]
293-
return zv
294-
elif op.return_values:
295-
zi = np.argpartition(x, -k, axis=axis)[tuple(idx)]
296-
idx2 = tuple(
297-
np.arange(s).reshape((s,) + (1,) * (ndim - i - 1)) if i != axis else zi
298-
for i, s in enumerate(x.shape)
299-
)
300-
zv = x[idx2]
301-
return zv, zi.astype(idx_dtype)
302-
else:
303-
zi = np.argpartition(x, -k, axis=axis)[tuple(idx)]
304-
return zi.astype(idx_dtype)

0 commit comments

Comments
 (0)