Skip to content

fix transpose / permutations fusion pass #10780

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions backends/cadence/aot/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,9 @@ class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
"""
Fuse transpose or permute op pairs to a single view op.
(transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
This happens when op2(op1) == identity, modulo unitary dimensions.
Example of 'unitary dimensions': a tensor of shape [1, 5, 30] is equivalent (in memory) to one of shape [5, 1, 30],
so transpose(1, 2) followed by transpose(0, 2) is a pseudo-identity and should be fused.
"""

# A list of ops that can be bypassed when looking for a
Expand All @@ -908,7 +911,7 @@ def can_fuse_for_chain(
if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
return False

# checking that permut2(permut1(identify)) == identity
# checking that permut2(permut1(identity)) == identity, modulo unitary dimensions
input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
ident_dims = list(range(len(input_shape)))
# this mapping helps to handle both transpose and permutations
Expand All @@ -918,14 +921,20 @@ def can_fuse_for_chain(
}
in_dims = f[producer.target](producer, ident_dims)
out_dims = f[consumer.target](consumer, in_dims)
return out_dims == ident_dims
# Filtering out unitary dimensions
non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1]
non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1]
return non_unit_out_dims == non_unit_ident_dims

def get_fused_node(
self,
producer: torch.fx.Node,
consumer: torch.fx.Node,
graph_module: torch.fx.GraphModule,
) -> torch.fx.Node:
# This step is important because we can fuse transpositions that are not exact
# inverses of one another but can still be fused thanks to unitary dimensions.
# The fused operation must have the same output shape as the consumer.
output_shape = consumer.meta["val"].shape
with graph_module.graph.inserting_after(consumer):
view = graph_module.graph.call_function(
Expand Down
44 changes: 44 additions & 0 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,6 +584,28 @@ def _create_operator(
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
False,
),
# transpose -> quant -> transpose is not the reverse, BUT there is a UNITARY dimension
# so it ends up being the same in memory => fuse
(
True,
[0, 1],
True,
[0, 2],
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
True,
[5, 40, 1],
),
# transpose -> quant -> transpose is not the reverse, and unitary dimensions
# don't help => don't fuse
(
True,
[0, 1],
True,
[1, 3],
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
False,
[5, 40, 1, 4],
),
# permutation -> quant -> opposite permutation => fuse
(
False,
Expand Down Expand Up @@ -622,6 +644,28 @@ def _create_operator(
False,
[4, 4, 4],
),
# permutation -> quant -> a non-reverse permutation, BUT there is a UNITARY dimension
# so it ends up being the same in memory => fuse
(
False,
[1, 3, 2, 0],
False,
[3, 2, 1, 0],
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
True,
[3, 1, 8, 10],
),
# permutation -> quant -> a non-reverse permutation, and unitary dimensions
# don't help => don't fuse
(
False,
[1, 3, 2, 0],
False,
[3, 1, 2, 0],
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
False,
[3, 1, 8, 10],
),
# transpose -> quant -> transpose as a permutation => fuse
(
True,
Expand Down
Loading