|
1 |
| -from collections.abc import Iterable |
| 1 | +from collections.abc import Iterable, Sequence |
2 | 2 |
|
3 | 3 | import numpy as np
|
4 | 4 |
|
|
17 | 17 | )
|
18 | 18 | from pytensor.tensor.elemwise import DimShuffle, Elemwise
|
19 | 19 | from pytensor.tensor.exceptions import NotScalarConstantError
|
| 20 | +from pytensor.tensor.extra_ops import squeeze |
20 | 21 | from pytensor.tensor.math import Dot, ceil_intdiv, dot
|
21 | 22 | from pytensor.tensor.rewriting.basic import (
|
22 | 23 | register_canonicalize,
|
23 | 24 | register_specialize,
|
24 | 25 | register_stabilize,
|
25 | 26 | )
|
| 27 | +from pytensor.tensor.rewriting.elemwise import local_dimshuffle_lift |
26 | 28 | from pytensor.tensor.rewriting.subtensor import is_full_slice, register_useless
|
27 | 29 | from pytensor.tensor.shape import (
|
28 | 30 | Shape,
|
|
42 | 44 | from pytensor.tensor.type_other import SliceType
|
43 | 45 |
|
44 | 46 |
|
| 47 | +def _dims_dropped_by_basic_index(idxs: Sequence[slice | int]) -> tuple[int, ...]: |
| 48 | + # Inputs can be slice or integer indexes |
| 49 | + # Slices keep the dimensions, integers collapse them |
| 50 | + return tuple(i for i, idx in enumerate(idxs) if not isinstance(idx, slice)) |
| 51 | + |
| 52 | + |
45 | 53 | @register_canonicalize
|
46 | 54 | @register_stabilize
|
47 | 55 | @register_specialize
|
@@ -243,6 +251,55 @@ def local_subtensor_of_expand_dims(fgraph, node):
|
243 | 251 | return [out]
|
244 | 252 |
|
245 | 253 |
|
| 254 | +@register_canonicalize |
| 255 | +@register_specialize |
| 256 | +@node_rewriter([Subtensor]) |
| 257 | +def local_subtensor_of_transpose(fgraph, node): |
| 258 | + """Lift a Subtensor through a DimShuffle that only transposes. |
| 259 | +
|
| 260 | + transpose(x, (1, 0, 2))[i:, j:, k:] -> transpose(x[j:, i:, k:], (1, 0, 2)) |
| 261 | + """ |
| 262 | + ds, *idx = node.inputs |
| 263 | + |
| 264 | + if not (ds.owner and isinstance(ds.owner.op, DimShuffle)): |
| 265 | + return None |
| 266 | + |
| 267 | + ds_op = ds.owner.op |
| 268 | + if not ds_op.is_transpose: |
| 269 | + return None |
| 270 | + |
| 271 | + transposition = ds_op.transposition |
| 272 | + [x] = ds.owner.inputs |
| 273 | + |
| 274 | + idx_tuple = indices_from_subtensor(idx, node.op.idx_list) |
| 275 | + |
| 276 | + # Apply the transposition to the indexes |
| 277 | + ndim = x.type.ndim |
| 278 | + n_implicit_idxs = ndim - len(idx_tuple) |
| 279 | + idx_tuple = idx_tuple + (slice(None),) * n_implicit_idxs |
| 280 | + new_idxs = [idx_tuple[transposition.index(i)] for i in range(ndim)] |
| 281 | + new_x = x[tuple(new_idxs)] |
| 282 | + |
| 283 | + # Reintroduce any dims dropped by indexing so the original transpose still works |
| 284 | + dims_dropped_by_new_idx = _dims_dropped_by_basic_index(new_idxs) |
| 285 | + if dims_dropped_by_new_idx: |
| 286 | + new_x = expand_dims(new_x, axis=dims_dropped_by_new_idx) |
| 287 | + |
| 288 | + # Apply the transpose |
| 289 | + new_out = ds_op(new_x) |
| 290 | + |
| 291 | + # Squeeze dims again now that the transpose is done |
| 292 | + if dims_dropped_by_new_idx: |
| 293 | + dims_dropped_by_original_idx = _dims_dropped_by_basic_index(idx_tuple) |
| 294 | + new_out = squeeze(new_out, axis=dims_dropped_by_original_idx) |
| 295 | + |
| 296 | + # Cleanup consecutive expand_dims / transpose / squeeze (if any) |
| 297 | + if dims_dropped_by_new_idx: |
| 298 | + [new_out] = local_dimshuffle_lift.transform(fgraph, new_out.owner) |
| 299 | + |
| 300 | + return [new_out] |
| 301 | + |
| 302 | + |
246 | 303 | @register_infer_shape
|
247 | 304 | @register_useless
|
248 | 305 | @register_canonicalize
|
|
0 commit comments