import operator
from collections import deque
from numbers import Number
-from typing import cast, Sequence
+from typing import Any, Callable, cast

# Import these for the cadence function signatures.
import executorch.backends.cadence.aot.ops_registrations  # noqa: F401
@@ -881,9 +881,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
@register_cadence_pass(CadencePassAttribute(opt_level=1))
-class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
+class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
    """
-    Fuse transpose op pairs to a single view op.
+    Fuse transpose or permute op pairs to a single view op.
+    (transpose or permute) -> (quant or dequant) -> (transpose or permute)
    """

    # A list of ops that can be bypassed when looking for a
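Why a canceling pair can become a view: two transpose/permute ops fuse only when their composition is the identity permutation, i.e. the pair is a layout no-op. A minimal standalone sketch of that property, assuming only stock PyTorch:

import torch

x = torch.randn(2, 3, 4)
# permute followed by its inverse permutation composes to the identity,
# so the pair leaves the logical layout unchanged and can collapse to a view
y = x.permute(1, 2, 0).permute(2, 0, 1)
assert torch.equal(y, x)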
@@ -907,42 +908,17 @@ def can_fuse_for_chain(
        if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
            return False

-        def get_dims(node: torch.fx.Node) -> tuple[int, int]:
-            def canonicalize(dim: int) -> int:
-                if dim < 0:
-                    dim += len(node.meta["val"].shape)
-                return dim
-
-            return tuple(canonicalize(cast(int, d)) for d in node.args[1:3])
-
-        def is_equivalent(
-            shape: Sequence[int],
-            transpose0: tuple[int, int],
-            transpose1: tuple[int, int],
-        ) -> bool:
-            def permute_order(
-                order: Sequence[int], dims: tuple[int, int]
-            ) -> Sequence[int]:
-                new_order = list(order)
-                new_order[dims[0]], new_order[dims[1]] = (
-                    new_order[dims[1]],
-                    new_order[dims[0]],
-                )
-                return new_order
-
-            order = permute_order(range(len(shape)), transpose0)
-            order = permute_order(order, transpose1)
-
-            non_unit_dims = [dim for dim in range(len(shape)) if shape[dim] != 1]
-            non_unit_dims_permuted = [dim for dim in order if shape[dim] != 1]
-
-            return non_unit_dims == non_unit_dims_permuted
-
-        return is_equivalent(
-            cast(torch.fx.Node, producer.args[0]).meta["val"].shape,
-            get_dims(producer),
-            get_dims(consumer),
-        )
+        # Check that permute2(permute1(identity)) == identity.
+        input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
+        ident_dims = list(range(len(input_shape)))
+        # This mapping handles both transpose and permute ops uniformly.
+        f: dict[Any, Callable] = {
+            exir_ops.edge.aten.transpose_copy.int: get_transposed_dims,
+            exir_ops.edge.aten.permute_copy.default: get_permuted_dims,
+        }
+        in_dims = f[producer.target](producer, ident_dims)
+        out_dims = f[consumer.target](consumer, in_dims)
+        return out_dims == ident_dims

    def get_fused_node(
        self,
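The new check threads a dim order through both ops and verifies that the composition is the identity. `get_transposed_dims` and `get_permuted_dims` are helpers whose bodies are not shown in this hunk; a minimal sketch of their assumed behavior, with hypothetical bodies inferred only from the call sites `f[node.target](node, dims)`:

from typing import cast

import torch


def get_transposed_dims(node: torch.fx.Node, dims: list[int]) -> list[int]:
    # transpose_copy swaps exactly two dims; mirror that swap on `dims`
    rank = len(dims)
    dim0 = cast(int, node.args[1]) % rank  # canonicalize negative dims
    dim1 = cast(int, node.args[2]) % rank
    new_dims = list(dims)
    new_dims[dim0], new_dims[dim1] = dims[dim1], dims[dim0]
    return new_dims


def get_permuted_dims(node: torch.fx.Node, dims: list[int]) -> list[int]:
    # permute_copy reorders all dims by its permutation argument
    permutation = cast(list[int], node.args[1])
    return [dims[p] for p in permutation]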
@@ -960,11 +936,17 @@ def get_fused_node(
        return view

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        # Remove any dequantize op that has only quantize ops as its users.
+        # Remove any transpose/permute op pairs that cancel each other.
        self.find_and_fuse(
            graph_module,
-            producer_op_packets={exir_ops.edge.aten.transpose_copy},
-            consumer_op_packets={exir_ops.edge.aten.transpose_copy},
+            producer_op_packets={
+                exir_ops.edge.aten.transpose_copy,
+                exir_ops.edge.aten.permute_copy,
+            },
+            consumer_op_packets={
+                exir_ops.edge.aten.transpose_copy,
+                exir_ops.edge.aten.permute_copy,
+            },
            bypass_ops=self.bypass_ops,
        )
        result = super().call(graph_module)
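Both op packets now include transpose_copy and permute_copy because a canceling pair may mix the two op kinds. A quick standalone illustration, assuming only stock PyTorch:

import torch

x = torch.randn(2, 3, 4)
# transpose(0, 1) is the permutation (1, 0, 2); composing it with
# permute(1, 0, 2) yields the identity, so mixed pairs also cancel
y = x.transpose(0, 1).permute(1, 0, 2)
assert torch.equal(y, x)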
@@ -1028,5 +1010,5 @@ class CadenceFuseOpsInGraph:
        FuseQuantDequantToRequantizePass,
        FuseMulIntoDequantPass,
        FuseFullThenReshapePass,
-        FuseTransposeOpPairsPass,
+        FuseTransposeOrPermuteOpPairsPass,
    ]