Skip to content

Commit 36e44ea

Browse files
Allow removing permute pairs in addition to transpose pairs (#10501) (#10566)
Summary: Pull Request resolved: As titled. Gets us 27% better cycles on Activity Classification (at opt level 3). Can be improved further (to handle the case where the fused permutations are not an identity); tracked in task T222295719. Differential Revision: D73619452 Co-authored-by: Thomas Jannaud <[email protected]>
1 parent 82c42b6 commit 36e44ea

File tree

4 files changed

+207
-127
lines changed

4 files changed

+207
-127
lines changed

backends/cadence/aot/compiler_utils.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -109,12 +109,12 @@ def get_cascaded_ops(
109109
return nodes
110110

111111

112-
# Capture the effect of transpose op on incoming dimension order
113-
def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
112+
def get_transposed_dims(
113+
node: torch.fx.Node, dims: Optional[List[int]] = None
114+
) -> List[int]:
114115
"""
115-
Given a transpose node, and the incoming dimension ordering of the input
116-
tensor to the transpose node, return the net effect of transpose op on the
117-
dimension order.
116+
Applies the transposition as given by node onto the dimensions given in input
117+
e.g (1, 2) on [a, b, c, d] would return [a, c, b, d]
118118
"""
119119
assert node.target == exir_ops.edge.aten.transpose_copy.int
120120
# Assert that the dims is not empty
@@ -127,28 +127,22 @@ def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
127127
assert isinstance(transpose_dims1, int)
128128
dim0 = transpose_dims0 if transpose_dims0 >= 0 else transpose_dims0 + dim_len
129129
dim1 = transpose_dims1 if transpose_dims1 >= 0 else transpose_dims1 + dim_len
130-
# Perform transpose on dimmension ordering (dims)
131-
dims[dim0], dims[dim1] = dims[dim1], dims[dim0]
132-
return dims
130+
new_dims = list(dims)
131+
new_dims[dim0], new_dims[dim1] = dims[dim1], dims[dim0]
132+
return new_dims
133133

134134

135-
# Capture the effect of permute op on incoming dimension order
136-
def get_permuted_dims(node: torch.fx.Node, dims: Optional[Sequence[int]]) -> List[int]:
135+
def get_permuted_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
137136
"""
138-
Given a permute node, and the incoming dimension ordering of the input
139-
tensor to the permute node, return the net effect of permute op on the
140-
dimension order.
137+
Applies the permutation as given by node onto the dimensions given in input
138+
e.g (1, 2, 0) on [a, b, c] would return [b, c, a]
141139
"""
142140
assert node.target == exir_ops.edge.aten.permute_copy.default
143141
# Permute each index of the dimension ordering (dims)
144142
# pyre-fixme[6]: This combined typecheck isn't supported yet.
145143
permute_dims: List[int] = list(node.args[1])
146144
assert all(isinstance(x, int) for x in permute_dims)
147-
# If the dims is empty, we can simply return the permute order
148-
if not dims:
149-
return permute_dims
150-
dims = [dims[x] for x in permute_dims]
151-
return dims
145+
return [dims[x] for x in permute_dims]
152146

153147

154148
# Return the tensor of buffer/parameter op

backends/cadence/aot/fuse_ops.py

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import operator
1515
from collections import deque
1616
from numbers import Number
17-
from typing import cast, Sequence
17+
from typing import Any, Callable, cast
1818

1919
# Import these for the cadence function signatures.
2020
import executorch.backends.cadence.aot.ops_registrations # noqa: F401
@@ -881,9 +881,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
881881

882882

883883
@register_cadence_pass(CadencePassAttribute(opt_level=1))
884-
class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
884+
class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
885885
"""
886-
Fuse transpose op pairs to a single view op.
886+
Fuse transpose or permute op pairs to a single view op.
887+
(transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
887888
"""
888889

889890
# A list of ops that can be bypassed when looking for a
@@ -907,42 +908,17 @@ def can_fuse_for_chain(
907908
if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
908909
return False
909910

910-
def get_dims(node: torch.fx.Node) -> tuple[int, int]:
911-
def canonicalize(dim: int) -> int:
912-
if dim < 0:
913-
dim += len(node.meta["val"].shape)
914-
return dim
915-
916-
return tuple(canonicalize(cast(int, d)) for d in node.args[1:3])
917-
918-
def is_equivalent(
919-
shape: Sequence[int],
920-
transpose0: tuple[int, int],
921-
transpose1: tuple[int, int],
922-
) -> bool:
923-
def permute_order(
924-
order: Sequence[int], dims: tuple[int, int]
925-
) -> Sequence[int]:
926-
new_order = list(order)
927-
new_order[dims[0]], new_order[dims[1]] = (
928-
new_order[dims[1]],
929-
new_order[dims[0]],
930-
)
931-
return new_order
932-
933-
order = permute_order(range(len(shape)), transpose0)
934-
order = permute_order(order, transpose1)
935-
936-
non_unit_dims = [dim for dim in range(len(shape)) if shape[dim] != 1]
937-
non_unit_dims_permuted = [dim for dim in order if shape[dim] != 1]
938-
939-
return non_unit_dims == non_unit_dims_permuted
940-
941-
return is_equivalent(
942-
cast(torch.fx.Node, producer.args[0]).meta["val"].shape,
943-
get_dims(producer),
944-
get_dims(consumer),
945-
)
911+
# checking that permute2(permute1(identity)) == identity
912+
input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
913+
ident_dims = list(range(len(input_shape)))
914+
# this mapping helps to handle both transposes and permutations
915+
f: dict[Any, Callable] = {
916+
exir_ops.edge.aten.transpose_copy.int: get_transposed_dims,
917+
exir_ops.edge.aten.permute_copy.default: get_permuted_dims,
918+
}
919+
in_dims = f[producer.target](producer, ident_dims)
920+
out_dims = f[consumer.target](consumer, in_dims)
921+
return out_dims == ident_dims
946922

947923
def get_fused_node(
948924
self,
@@ -960,11 +936,17 @@ def get_fused_node(
960936
return view
961937

962938
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
963-
# Remove any dequantize op that has only quantize ops as its users.
939+
# Remove any transpose/permutation op pairs that cancel each other.
964940
self.find_and_fuse(
965941
graph_module,
966-
producer_op_packets={exir_ops.edge.aten.transpose_copy},
967-
consumer_op_packets={exir_ops.edge.aten.transpose_copy},
942+
producer_op_packets={
943+
exir_ops.edge.aten.transpose_copy,
944+
exir_ops.edge.aten.permute_copy,
945+
},
946+
consumer_op_packets={
947+
exir_ops.edge.aten.transpose_copy,
948+
exir_ops.edge.aten.permute_copy,
949+
},
968950
bypass_ops=self.bypass_ops,
969951
)
970952
result = super().call(graph_module)
@@ -1028,5 +1010,5 @@ class CadenceFuseOpsInGraph:
10281010
FuseQuantDequantToRequantizePass,
10291011
FuseMulIntoDequantPass,
10301012
FuseFullThenReshapePass,
1031-
FuseTransposeOpPairsPass,
1013+
FuseTransposeOrPermuteOpPairsPass,
10321014
]

backends/cadence/aot/passes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from executorch.backends.cadence.aot.fuse_ops import (
1515
CadenceFuseOpsInGraph,
1616
FuseFullThenReshapePass,
17-
FuseTransposeOpPairsPass,
17+
FuseTransposeOrPermuteOpPairsPass,
1818
)
1919
from executorch.backends.cadence.aot.pass_utils import (
2020
CadencePassAttribute,
@@ -83,7 +83,7 @@ def get_passes_in_default_order() -> List[ExportPass]:
8383
CadenceSimplifyOpsInGraph.passes,
8484
FinalizePipeline,
8585
FuseFullThenReshapePass,
86-
FuseTransposeOpPairsPass,
86+
FuseTransposeOrPermuteOpPairsPass,
8787
RemoveNopSliceOrViewOpPass,
8888
]
8989
return pytree.tree_flatten(passes)[0]

0 commit comments

Comments
 (0)