Skip to content

Commit f1db1bd

Browse files
committed
Lift Subtensor over transpose
1 parent db7b988 commit f1db1bd

File tree

2 files changed

+88
-2
lines changed

2 files changed

+88
-2
lines changed

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Iterable
1+
from collections.abc import Iterable, Sequence
22

33
import numpy as np
44

@@ -17,12 +17,14 @@
1717
)
1818
from pytensor.tensor.elemwise import DimShuffle, Elemwise
1919
from pytensor.tensor.exceptions import NotScalarConstantError
20+
from pytensor.tensor.extra_ops import squeeze
2021
from pytensor.tensor.math import Dot, ceil_intdiv, dot
2122
from pytensor.tensor.rewriting.basic import (
2223
register_canonicalize,
2324
register_specialize,
2425
register_stabilize,
2526
)
27+
from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift
2628
from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
2729
from pytensor.tensor.shape import (
2830
Shape,
@@ -42,6 +44,12 @@
4244
from pytensor.tensor.type_other import SliceType
4345

4446

47+
def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]:
48+
# Inputs can be slice or integer indexes
49+
# Slices keep the dimensions, integers collapse them
50+
return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice))
51+
52+
4553
@register_canonicalize
4654
@register_stabilize
4755
@register_specialize
@@ -243,6 +251,55 @@ def local_subtensor_of_expand_dims(fgraph, node):
243251
return [out]
244252

245253

254+
@register_canonicalize
@register_specialize
@node_rewriter([Subtensor])
def local_subtensor_of_transpose(fgraph, node):
    """Lift a Subtensor through a DimShuffle that only transposes.

    transpose(x, (1, 0, 2))[i:, j:, k:] -> transpose(x[j:, i:, k:], (1, 0, 2))

    Indexing the untransposed input directly lets the (cheaper) Subtensor
    happen first and can enable further lifts of the indexing operation.
    Returns ``None`` (no rewrite) when the indexed input is not a
    pure-transpose DimShuffle.
    """
    # node is `transposed_x[idx...]`: first input is the tensor being
    # indexed, the rest are the symbolic index inputs of the Subtensor.
    ds, *idx = node.inputs

    if not (ds.owner and isinstance(ds.owner.op, DimShuffle)):
        return None

    ds_op = ds.owner.op
    # Only handle DimShuffles that purely permute axes (no expand/drop dims).
    if not ds_op.is_transpose:
        return None

    transposition = ds_op.transposition
    [x] = ds.owner.inputs

    # Combine the symbolic index inputs with the op's static idx_list to
    # recover the actual (slice | scalar) index tuple.
    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    # Apply the transposition to the indexes
    ndim = x.type.ndim
    # Trailing dims not explicitly indexed behave as full slices.
    n_implicit_idxs = ndim - len(idx_tuple)
    idx_tuple = idx_tuple + (slice(None),) * n_implicit_idxs
    # Output axis j of the transpose is input axis transposition[j], so the
    # index meant for output axis j must be applied to that input axis:
    # new index for input axis i comes from output position transposition.index(i).
    new_idxs = [idx_tuple[transposition.index(i)] for i in range(ndim)]
    new_x = x[tuple(new_idxs)]

    # Reintroduce any dims dropped by indexing so the original transpose still works
    dims_dropped_by_new_idx = _dims_dropped_by_basic_index(new_idxs)
    if dims_dropped_by_new_idx:
        new_x = expand_dims(new_x, axis=dims_dropped_by_new_idx)

    # Apply the transpose
    new_out = ds_op(new_x)

    # Squeeze dims again now that the transpose is done
    if dims_dropped_by_new_idx:
        # The dropped axes sit at the positions the *original* index tuple
        # collapsed (i.e. in the transposed output's axis order).
        dims_dropped_by_original_idx = _dims_dropped_by_basic_index(idx_tuple)
        new_out = squeeze(new_out, axis=dims_dropped_by_original_idx)

    # Cleanup consecutive expand_dims / transpose / squeeze (if any)
    # local_dimshuffle_lift fuses the DimShuffles we just chained together.
    if dims_dropped_by_new_idx:
        [new_out] = local_dimshuffle_lift.transform(fgraph, new_out.owner)

    return [new_out]
301+
302+
246303
@register_infer_shape
247304
@register_useless
248305
@register_canonicalize

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
252252

253253
out = original_fn(x)
254254
expected_opt_out = expected_fn(x)
255-
opt_out = rewrite_graph(out, exclude=["local_uint_constant_indices"])
255+
opt_out = rewrite_graph(out)
256256
assert equal_computations([opt_out], [expected_opt_out]), debugprint(
257257
[opt_out, expected_opt_out], print_type=True
258258
)
@@ -262,6 +262,35 @@ def test_local_subtensor_of_expand_dims(original_fn, expected_fn):
262262
)
263263

264264

265+
@pytest.mark.parametrize(
    "original_fn, expected_fn",
    [
        (lambda x: x.transpose(2, 1, 0)[0], lambda x: x[:, :, 0].transpose(1, 0)),
        (lambda x: x.transpose(2, 1, 0)[:, :, 1:], lambda x: x[1:].transpose(2, 1, 0)),
        (
            lambda x: x.transpose(2, 1, 0)[0, :1, 1:],
            lambda x: x[1:, :1, 0].transpose(1, 0),
        ),
        (lambda x: x.transpose(2, 1, 0)[0, :1, 1], lambda x: x[1, :1, 0]),
    ],
)
def test_local_subtensor_of_transpose(original_fn, expected_fn):
    """Subtensor should be lifted above a transposing DimShuffle."""
    x = tensor("x", shape=(7, 5, 3))

    graph = original_fn(x)
    expected = expected_fn(x)
    rewritten = rewrite_graph(graph)

    # Symbolic check: the rewritten graph equals the hand-lifted form.
    assert equal_computations([rewritten], [expected]), debugprint(
        [expected, rewritten], print_type=True
    )

    # Numerical check: the rewrite did not change the computed values.
    rng = np.random.default_rng(232)
    x_val = rng.normal(size=x.type.shape).astype(x.dtype)
    np.testing.assert_allclose(
        rewritten.eval({x: x_val}, mode=NO_OPTIMIZATION_MODE),
        graph.eval({x: x_val}, mode=NO_OPTIMIZATION_MODE),
    )
292+
293+
265294
def test_local_subtensor_of_alloc():
266295
# DebugMode should detect if something goes wrong.
267296
# test shape combination of odd and event shape.

0 commit comments

Comments
 (0)