
Commit 938bd8e

Lift Subtensor over Softmax
1 parent 43cad30 commit 938bd8e

File tree

2 files changed: +133, -1 lines changed


pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 94 additions & 1 deletion
@@ -5,7 +5,7 @@
 from pytensor import Variable
 from pytensor.graph import Constant, node_rewriter
 from pytensor.graph.rewriting.basic import copy_stack_trace
-from pytensor.npy_2_compat import normalize_axis_tuple
+from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple
 from pytensor.scalar import basic as ps
 from pytensor.tensor.basic import (
     Alloc,
@@ -32,6 +32,7 @@
     SpecifyShape,
     specify_shape,
 )
+from pytensor.tensor.special import Softmax, softmax
 from pytensor.tensor.subtensor import (
     AdvancedSubtensor1,
     Subtensor,
@@ -51,6 +52,20 @@ def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]
     return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))
 
 
+def _ndim_dropped_left_of_axis_by_basic_index(
+    idxs: Sequence[slice | int], axis: int
+) -> int:
+    return len(_dims_dropped_by_basic_index(idxs[:axis]))
+
+
+def _axis_is_indexed_by_basic_index(
+    idxs: Sequence[slice | int], axis: int | Sequence[int]
+) -> bool:
+    if isinstance(axis, int):
+        axis = (axis,)
+    return any(ax < len(idxs) and not is_full_slice(idxs[ax]) for ax in axis)
+
+
 @register_canonicalize
 @register_stabilize
 @register_specialize
@@ -241,6 +256,84 @@ def local_subtensor_of_reduce(fgraph, node):
     return [out]
 
 
+@register_canonicalize
+@register_specialize
+@node_rewriter([Subtensor])
+def local_subtensor_of_softmax(fgraph, node):
+    """Lift a Subtensor through a Softmax.
+
+    softmax(x, axis=1)[0] -> softmax(x[0], axis=0)
+    softmax(x, axis=1)[:, :, 0] -> softmax(x[:, :, 0], axis=1)
+
+    If part of the indexing acts on the axis of reduction, we split it
+    softmax(x, axis=1)[:, 0, 1:] -> softmax(x[:, :, 1:], axis=1)[:, 0]
+
+    """
+    sm, *idx = node.inputs
+
+    if not (sm.owner and isinstance(sm.owner.op, Softmax)):
+        return None
+
+    if len(fgraph.clients[sm]) > 1:
+        return None
+
+    [x] = sm.owner.inputs
+    axis = sm.owner.op.axis
+
+    if axis is None:
+        if x.type.ndim == 1:
+            axis = 0
+        else:
+            # All dimensions are mixed, we can't lift the subtensor
+            return None
+    else:
+        # Softmax currently only allows None or a single integer axis
+        # Unlike CAReduce it does not normalize negative indices
+        axis = normalize_axis_index(axis, sm.ndim)
+
+    [old_out] = node.outputs
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
+
+    if _axis_is_indexed_by_basic_index(idx_tuple, axis):
+        # If there are more dimensions being indexed, we can split them
+        # And lift the non-axis indexes while keeping the axis index
+        real_indices = [idx for idx in idx_tuple if not is_full_slice(idx)]
+        if len(real_indices) > 1 and sm.type.ndim > 1:
+            # Split the subtensor
+            idx_to_keep = idx_tuple[axis]
+            idxs_to_lift = (*idx_tuple[:axis], slice(None), *idx_tuple[axis + 1 :])
+
+            # Lift the non-axis indexes by calling the rewrite itself
+            opt_sm = sm[idxs_to_lift]
+            [opt_sm] = local_subtensor_of_softmax.transform(fgraph, opt_sm.owner)
+            copy_stack_trace([old_out, sm], opt_sm)
+
+            # Then reintroduce the axis index
+            ndim_reduced_left = _ndim_dropped_left_of_axis_by_basic_index(
+                idx_tuple, axis
+            )
+            new_axis = axis - ndim_reduced_left
+            idxs_to_keep = (*(slice(None),) * new_axis, idx_to_keep)
+            new_out = opt_sm[idxs_to_keep]
+            copy_stack_trace(old_out, new_out)
+            return [new_out]
+
+        else:
+            return None
+
+    # Index input to softmax
+    x_sub = x[idx_tuple]
+
+    # Adjust axis of reduction when indexing drops dimensions (integer indexing as opposed to slice indexing)
+    axis -= len(
+        [idx_item for idx_item in idx_tuple[:axis] if not isinstance(idx_item, slice)]
+    )
+
+    out = softmax(x_sub, axis=axis)
+    copy_stack_trace(old_out, out)
+    return [out]
+
+
 @register_canonicalize("shape_unsafe")
 @register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
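
For illustration only (not part of the commit): a minimal sketch of the new rewrite in action, assuming the same public API the test below uses (pt.matrix, softmax, rewrite_graph, debugprint) and that the rewrite is picked up by the default canonicalize/specialize passes.

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.printing import debugprint
from pytensor.tensor.special import softmax

x = pt.matrix("x", shape=(5, 3))

# Indexing a non-softmax axis: the Subtensor is lifted through the Softmax,
# so the softmax is computed only on the row that is actually used.
opt = rewrite_graph(softmax(x, axis=1)[0])  # equivalent to softmax(x[0], axis=0)

# Indexing the softmax axis as well: only the non-axis index is lifted;
# the axis index stays outside the Softmax.
opt_split = rewrite_graph(softmax(x, axis=1)[1:, 0])  # equivalent to softmax(x[1:], axis=1)[:, 0]

debugprint(opt)  # inspect the rewritten graph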

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 39 additions & 0 deletions
@@ -45,6 +45,7 @@
     local_subtensor_shape_constant,
 )
 from pytensor.tensor.shape import SpecifyShape, _shape
+from pytensor.tensor.special import softmax
 from pytensor.tensor.subtensor import Subtensor
 
 
@@ -211,6 +212,44 @@ def test_local_subtensor_of_reduce(original_fn, expected_fn):
     )
 
 
+@pytest.mark.parametrize(
+    "original_fn, expected_fn",
+    [
+        # Lift single index that does not overlap with axis of softmax
+        (lambda x: softmax(x, axis=1)[0], lambda x: softmax(x[0], axis=0)),
+        (lambda x: softmax(x, axis=1)[1:], lambda x: softmax(x[1:], axis=1)),
+        (lambda x: softmax(x, axis=0)[:, 0], lambda x: softmax(x[:, 0], axis=0)),
+        (lambda x: softmax(x, axis=0)[:, 1:], lambda x: softmax(x[:, 1:], axis=0)),
+        # Do nothing to single index over axis of softmax
+        (lambda x: softmax(x, axis=0)[0], lambda x: softmax(x, axis=0)[0]),
+        (lambda x: softmax(x, axis=1)[:, 1:], lambda x: softmax(x, axis=1)[:, 1:]),
+        # Split indexing on axis of softmax
+        (lambda x: softmax(x, axis=0)[1:, 0], lambda x: softmax(x[:, 0], axis=0)[1:]),
+        (lambda x: softmax(x, axis=1)[1:, 0], lambda x: softmax(x[1:], axis=1)[:, 0]),
+        (
+            lambda x: softmax(x, axis=0)[0, :5:2],
+            lambda x: softmax(x[:, :5:2], axis=0)[0],
+        ),
+        (lambda x: softmax(x, axis=1)[0, :5:2], lambda x: softmax(x[0], axis=0)[:5:2]),
+    ],
+)
+def test_local_subtensor_of_softmax(original_fn, expected_fn):
+    rng = np.random.default_rng(230)
+    x = pt.matrix("x", shape=(5, 3))
+    x_test = rng.normal(size=x.type.shape).astype(x.dtype)
+
+    out = original_fn(x)
+    expected_opt_out = expected_fn(x)
+    opt_out = rewrite_graph(out)
+    assert equal_computations([opt_out], [expected_opt_out]), debugprint(
+        [expected_opt_out, opt_out], print_type=True
+    )
+    np.testing.assert_allclose(
+        opt_out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+        out.eval({x: x_test}, mode=NO_OPTIMIZATION_MODE),
+    )
+
+
 @pytest.mark.parametrize(
     "original_fn, expected_fn",
     [
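
As a quick numpy-only sanity check of the split case asserted above (a standalone sketch, not part of the commit; np_softmax is a local helper defined here): indexing the non-softmax axis commutes with the softmax, so softmax(x, axis=0)[1:, 0] matches softmax(x[:, 0], axis=0)[1:].

import numpy as np

def np_softmax(a, axis):
    # Numerically stable softmax along `axis`
    e = np.exp(a - a.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(230)
x = rng.normal(size=(5, 3))

# Column 0 of a column-wise softmax equals the softmax of column 0,
# and dropping the first row afterwards commutes with the row slice.
np.testing.assert_allclose(
    np_softmax(x, axis=0)[1:, 0],
    np_softmax(x[:, 0], axis=0)[1:],
)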
