Skip to content

Commit 2058d8b

Browse files
dulinrileyfacebook-github-bot
authored andcommitted
Update RemovePermutesAroundElementwiseOps to work with view as well (#7407)
Summary: The RemovePermutesAroundElementwiseOps pass was working well for permutes, but sometimes permutes get optimized into `view_copy` if the dimension being moved doesn't change the byte-level arrangement of the Tensor. Handle this case so we can remove more functions in these chains. Reviewed By: zonglinpeng Differential Revision: D67471456
1 parent 34e0570 commit 2058d8b

File tree

2 files changed

+69
-9
lines changed

2 files changed

+69
-9
lines changed

backends/cadence/aot/remove_ops.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import itertools
1717
import logging
1818
from dataclasses import dataclass, field
19-
from typing import Callable, cast, Dict, List, Optional, Sequence
19+
from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Union
2020

2121
import torch
2222
import torch.fx
@@ -698,16 +698,45 @@ def traverse_intermediate_node(self, node: torch.fx.Node, sg: Subgraph) -> None:
698698
sg.is_valid = False
699699

700700
def is_starting_permute(self, node: torch.fx.Node) -> bool:
701-
return (
702-
node.target == exir_ops.edge.aten.permute_copy.default
703-
and cast(list[int], node.args[1]) == self.to_NCHW
704-
)
701+
return self.is_boundary_permute(node, self.to_NCHW)
705702

706703
def is_ending_permute(self, node: torch.fx.Node) -> bool:
707-
return (
708-
node.target == exir_ops.edge.aten.permute_copy.default
709-
and cast(list[int], node.args[1]) == self.to_NHWC
710-
)
704+
return self.is_boundary_permute(node, self.to_NHWC)
705+
706+
@staticmethod
707+
def is_boundary_permute(node: torch.fx.Node, permute_dims: Iterable[int]) -> bool:
708+
permute_dims = list(permute_dims)
709+
if node.target == exir_ops.edge.aten.permute_copy.default:
710+
return cast(list[int], node.args[1]) == permute_dims
711+
elif node.target == exir_ops.edge.aten.view_copy.default:
712+
# If there's a view node, check if it's swapping two dimensions and
713+
# not splitting any others from the input shape.
714+
inp = node.args[0]
715+
if not isinstance(inp, torch.fx.Node):
716+
return False
717+
input_shape = inp.meta["val"].shape
718+
output_shape = node.args[1]
719+
assert isinstance(output_shape, (tuple, list))
720+
# If the shapes are equal in length, no dimension is being split or
721+
# grouped. Then check if a permute of the input shape results in the output shape.
722+
return (
723+
len(input_shape) == len(output_shape)
724+
and len(input_shape) == len(permute_dims)
725+
and RemovePermutesAroundElementwiseOps.permute_shape(
726+
input_shape, permute_dims
727+
)
728+
== output_shape
729+
)
730+
else:
731+
return False
732+
733+
@staticmethod
734+
def permute_shape(
735+
shape: Union[List[int], torch.Size], permute_dims: Iterable[int]
736+
) -> List[int]:
737+
permute_dims = list(permute_dims)
738+
assert len(shape) == len(permute_dims)
739+
return [shape[p] for p in permute_dims]
711740

712741

713742
# The following class consolidates functions to remove ops that are redundant

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,37 @@ def forward(self, x, y):
649649
][0]
650650
self.assertEqual(cat.args[1], 3)
651651

652+
def test_remove_permutes_around_concat_with_views(self) -> None:
653+
class M(torch.nn.Module):
654+
def forward(self, x, y):
655+
# Mix and match views that are permutes and actual permutes. Both
656+
# should be removed.
657+
x = x.view(1, 1, 4, 4)
658+
y = torch.permute(y, [0, 3, 1, 2])
659+
z = torch.cat((x, y), 1)
660+
return z.view(1, 4, 4, 8)
661+
662+
inputs = (torch.randn(1, 4, 4, 1), torch.randn(1, 4, 4, 7))
663+
graph_module = export_to_edge(M(), inputs).exported_program().graph_module
664+
p = RemovePermutesAroundElementwiseOps()
665+
graph_module = cast(PassResult, p(graph_module)).graph_module
666+
667+
# Expect 0 permutes and views to remain.
668+
self.assertEqual(
669+
count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0
670+
)
671+
self.assertEqual(
672+
count_node(graph_module, exir_ops.edge.aten.view_copy.default), 0
673+
)
674+
675+
# verify that cat was updated correctly
676+
cat = [
677+
n
678+
for n in graph_module.graph.nodes
679+
if n.target == exir_ops.edge.aten.cat.default
680+
][0]
681+
self.assertEqual(cat.args[1], 3)
682+
652683
def test_remove_permutes_around_elemwise_ops_noop(self) -> None:
653684
class M(torch.nn.Module):
654685
def __init__(self):

0 commit comments

Comments
 (0)