@@ -5,7 +5,7 @@
 from pytensor import Variable
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.rewriting.basic import copy_stack_trace
-from pytensor.npy_2_compat import normalize_axis_tuple
+from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
 from pytensor.tensor.basic import (
     Alloc,
@@ -32,6 +32,7 @@
     SpecifyShape,
     specify_shape,
 )
+from pytensor.tensor.special import Softmax, softmax
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor1,
     Subtensor,
@@ -51,6 +52,20 @@ def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]
     return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))


+def _ndim_dropped_left_of_axis_by_basic_index(
+    idxs: Sequence[slice | int], axis: int
+) -> int:
+    return len(_dims_dropped_by_basic_index(idxs[:axis]))
+
+
+def _axis_is_indexed_by_basic_index(
+    idxs: Sequence[slice | int], axis: int | Sequence[int]
+) -> bool:
+    if isinstance(axis, int):
+        axis = (axis,)
+    return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
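As a rough illustration of what the two new helpers compute (illustration only, not part of the diff; it reuses the module's own helpers, and the commented values are the results I would expect from the definitions above):

idxs = (0, slice(None), slice(1, None))              # basic indexes of x[0, :, 1:]

_dims_dropped_by_basic_index(idxs)                   # (0,): only the integer index drops a dim
_ndim_dropped_left_of_axis_by_basic_index(idxs, 2)   # 1: one dim dropped to the left of axis 2
_axis_is_indexed_by_basic_index(idxs, axis=1)        # False: axis 1 gets a full slice
_axis_is_indexed_by_basic_index(idxs, axis=2)        # True: axis 2 is sliced with 1: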
@@ -241,6 +256,84 @@ def local_subtensor_of_reduce(fgraph, node):
     return [out]


+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_softmax(fgraph, node):
+    """Lift a Subtensor through a Softmax.
+
+    softmax(x, axis=1)[0] -> softmax(x[0], axis=0)
+    softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1)
+
+    If part of the indexing acts on the axis of reduction, we split it:
+    softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[0]
+
+    """
+    sm, *idx = node.inputs
+
+    if not (sm.owner and isinstance(sm.owner.op, Softmax)):
+        return None
+
+    if len(fgraph.clients[sm]) > 1:
+        return None
+
+    [x] = sm.owner.inputs
+    axis = sm.owner.op.axis
+
+    if axis is None:
+        if x.type.ndim == 1:
+            axis = 0
+        else:
+            # All dimensions are mixed, we can't lift the subtensor
+            return None
+    else:
+        # Softmax currently only allows None or a single integer axis
+        # Unlike CAReduce it does not normalize negative indices
+        axis = normalize_axis_index(axis, sm.ndim)
+
+    [old_out] = node.outputs
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        # If dimensions other than the softmax axis are also indexed, we can split
+        # the indexing and lift the non-axis indexes while keeping the axis index
+        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+        if len(real_indices) > 1 and sm.type.ndim > 1:
+            # Split the subtensor
+            idx_to_keep = idx_tuple[axis]
+            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+            # Lift the non-axis indexes by calling the rewrite itself
+            opt_sm = sm[idxs_to_lift]
+            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
+            copy_stack_trace([old_out, sm], opt_sm)
+
+            # Then reintroduce the axis index
+            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
+                idx_tuple, axis
+            )
+            new_axis = axis - ndim_reduced_left
+            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+            new_out = opt_sm[idxs_to_keep]
+            copy_stack_trace(old_out, new_out)
+            return [new_out]
+
+        else:
+            return None
+
+    # Index input to softmax
+    x_sub = x[idx_tuple]
+
+    # Adjust axis of reduction when indexing drops dimensions (integer indexing as opposed to slice indexing)
+    axis -= len(
+        [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
+    )
+
+    out = softmax(x_sub, axis=axis)
+    copy_stack_trace(old_out, out)
+    return [out]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
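For context, a minimal sketch of how the new rewrite could be exercised, assuming a PyTensor build that includes this change (the debug-print output described in the comments is what the rewrite is expected to produce, not verified here):

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.special import softmax

x = pt.matrix("x")
out = softmax(x, axis=1)[0]   # index a dimension other than the softmax axis
fn = pytensor.function([x], out)
pytensor.dprint(fn)           # graph should now apply softmax to x[0] along axis 0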