Permute elimination pass fixes. #10662

Merged · 1 commit · May 3, 2025
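The rewrite in this PR removes pairs of mutually inverse permutes around elementwise ops. A minimal numeric sketch of the identity the pass relies on (shapes and values are illustrative, not taken from the PR):

```python
import torch

# An elementwise op wrapped in a permute and its inverse is equivalent to
# the bare op: permute_inv(permute(x) + permute(y)) == x + y.
to_nhwc = [0, 2, 3, 1]  # NCHW -> NHWC
to_nchw = [0, 3, 1, 2]  # NHWC -> NCHW, the inverse permutation

x = torch.randn(2, 3, 4, 5)
y = torch.randn(2, 3, 4, 5)

wrapped = torch.permute(
    torch.permute(x, to_nhwc) + torch.permute(y, to_nhwc), to_nchw
)
assert torch.equal(wrapped, x + y)
```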
1 change: 1 addition & 0 deletions backends/cadence/aot/TARGETS
@@ -347,6 +347,7 @@ python_unittest(
":compiler",
"//caffe2:torch",
"//executorch/backends/cadence/aot:compiler",
"//executorch/backends/cadence/aot:graph_builder",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:pass_utils",
"//executorch/backends/cadence/aot:remove_ops",
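The new `graph_builder` dependency supports the updated unit tests. A hedged sketch of how such a test might build the smallest pattern the pass should collapse, assuming the `GraphBuilder` surface (`placeholder`, `call_operator`, `output`, `get_graph_module`) used elsewhere in the cadence AOT tests:

```python
import torch
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.exir.dialects._ops import ops as exir_ops

# permute -> add -> inverse permute; the pass should reduce this to a bare
# add. The builder method names are assumptions, not verbatim from this PR.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(1, 3, 4, 4))
nhwc = builder.call_operator(
    op=exir_ops.edge.aten.permute_copy.default, args=(x, [0, 2, 3, 1])
)
added = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(nhwc, nhwc))
nchw = builder.call_operator(
    op=exir_ops.edge.aten.permute_copy.default, args=(added, [0, 3, 1, 2])
)
builder.output([nchw])
graph_module = builder.get_graph_module()
```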
313 changes: 138 additions & 175 deletions backends/cadence/aot/remove_ops.py
@@ -17,10 +17,9 @@
# in a context outside of Jarvis', so exercise caution while invoking this in a
# pass list outside of Jarvis.

import itertools
import logging
from dataclasses import dataclass, field
from typing import Callable, cast, Dict, Iterable, List, Optional, Sequence, Union
from typing import cast, List, Optional, Sequence

import torch
import torch.fx
@@ -538,211 +537,175 @@ def call_operator(
return super().call_operator(op, args, kwargs, meta)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
@register_cadence_pass(CadencePassAttribute(opt_level=2))
class RemovePermutesAroundElementwiseOps(ExportPass):
"""
Looks for subgraphs of elementwise ops sandwiched between permutes and removes those
permutes if possible. This pass is targeted at models where delegated subgraphs
must be in NHWC format, so there's usually a to_NHWC permute before each delegate and
a to_NCHW permute after it. If all the ops between two delegates are elementwise ops
then these permutes can be safely removed.
Allows special handling for certain non-elementwise ops that can be easily updated based on
the permute's parameter, such as mean and cat
permutes if possible.
Allows special handling for certain non-elementwise ops that can be easily updated
based on the permute's parameter, such as mean, cat, and slice.
"""

@dataclass()
class Subgraph:
"""
Keeps track of nodes grouped as a subgraph between two sets of permutes
"""

start_permutes: set[torch.fx.Node] = field(default_factory=set)
end_permutes: set[torch.fx.Node] = field(default_factory=set)
intermediate_nodes: set[torch.fx.Node] = field(default_factory=set)
is_valid: bool = True

elementwise_ops: set[EdgeOpOverload] = {
start_permute: list[int]
end_permute: list[int]
# Nodes in the subgraph; does not include permutes.
nodes: set[torch.fx.Node] = field(default_factory=set)
# Incoming edges to the subgraph from permute nodes.
edges_in: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)
# Outgoing edges of the subgraph to permute nodes.
edges_out: set[tuple[torch.fx.Node, torch.fx.Node]] = field(default_factory=set)

permutable_ops: set[EdgeOpOverload] = {
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
exir_ops.edge.cadence.quantize_per_tensor.default,
exir_ops.edge.cadence.dequantize_per_tensor.default,
# Ops that require special handling.
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.mean.dim,
exir_ops.edge.aten.slice_copy.Tensor,
}

# must be initialized in the constructor
special_handling: Dict[EdgeOpOverload, Callable[[torch.fx.Node], None]] = {}

to_NCHW = [0, 3, 1, 2]
to_NHWC = [0, 2, 3, 1]

def __init__(self) -> None:
super().__init__()
self.visited: set[object] = set()
self.special_handling = {
exir_ops.edge.aten.mean.dim: self.handle_mean_dim,
exir_ops.edge.aten.cat.default: self.handle_cat,
}

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self.visited = set()
subgraphs_found: list[RemovePermutesAroundElementwiseOps.Subgraph] = []
processed_nodes: set[torch.fx.Node] = set()
for node in graph_module.graph.nodes:
sg = self.Subgraph()
self.start_search(node, sg)
if self.is_valid_subgraph(sg):
logging.debug(f"Found valid subgraph: {sg}")
self.handle_subgraph(graph_module, sg)
if node.target != exir_ops.edge.aten.permute_copy.default:
continue

result = super().call(graph_module)
return result
start_permute = self.get_permutation(node)
# Expected end permutation for the subgraph.
end_permute = [start_permute.index(i) for i in range(len(start_permute))]
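# Worked example (illustrative, not from the PR): for start_permute
# [0, 2, 3, 1], the comprehension above yields its inverse [0, 3, 1, 2];
# applying start_permute and then end_permute restores the original order.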

def handle_mean_dim(self, mean_dim: torch.fx.Node) -> None:
assert mean_dim.target == exir_ops.edge.aten.mean.dim
args = list(mean_dim.args)
args[1] = [self.to_NCHW[dim] for dim in cast(list[int], args[1])]
mean_dim.args = tuple(args)
for user in node.users:
if user.target not in self.permutable_ops:
continue
# Create a separate subgraph for each user since there may be cases
# where only a portion of the users are permutable.
subgraph = self.Subgraph(start_permute, end_permute)
if self.visit(user, subgraph, processed_nodes):
subgraphs_found.append(subgraph)
for node in subgraph.nodes:
processed_nodes.add(node)

def handle_cat(self, cat: torch.fx.Node) -> None:
assert cat.target == exir_ops.edge.aten.cat.default
args = list(cat.args)
args[1] = self.to_NCHW[cast(int, args[1])]
cat.args = tuple(args)
for subgraph in subgraphs_found:
self.permute_subgraph(subgraph)

def is_valid_subgraph(self, sg: Subgraph) -> bool:
return (
sg.is_valid
and len(sg.start_permutes) > 0
and len(sg.end_permutes) > 0
and len(sg.intermediate_nodes) > 0
)
graph_module.graph.eliminate_dead_code()
graph_module.recompile()

def handle_subgraph(self, graph_module: torch.fx.GraphModule, sg: Subgraph) -> None:
for permute in itertools.chain(sg.start_permutes, sg.end_permutes):
permute.replace_all_uses_with(permute.args[0]) # pyre-fixme[6]
return super().call(graph_module)

for node in sg.intermediate_nodes:
if node.target in self.special_handling:
self.special_handling[node.target](node)
def visit(
self,
node: torch.fx.Node,
subgraph: Subgraph,
processed_nodes: set[torch.fx.Node],
) -> bool:
if node in subgraph.nodes:
return True
if node in processed_nodes or not self.is_node_permutable(node):
return False
subgraph.nodes.add(node)

# Traverse downstream:
for user in node.users:
# Output should either go to a matching permute or another permutable op.
if user.target == exir_ops.edge.aten.permute_copy.default:
if self.get_permutation(user) != subgraph.end_permute:
return False
subgraph.edges_out.add((node, user))
elif not self.visit(user, subgraph, processed_nodes):
return False

graph_module.recompile()
graph_module.graph.eliminate_dead_code()
# Traverse upstream:
for inp in node.all_input_nodes:
# Input should either come from a matching permute or another permutable op.
if inp.target == exir_ops.edge.aten.permute_copy.default:
if self.get_permutation(inp) != subgraph.start_permute:
return False
subgraph.edges_in.add((inp, node))
elif not self.visit(inp, subgraph, processed_nodes):
return False

def start_search(self, node: torch.fx.Node, sg: Subgraph) -> None:
if node in self.visited:
return
return True
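# visit() grows the subgraph in both directions and fails fast: it returns
# False as soon as any neighbor is neither a matching boundary permute nor
# another permutable op, so partially matched regions are never rewritten.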

if self.is_starting_permute(node):
sg.start_permutes.add(node)
self.visited.add(node)
for user in node.users:
self.search_down(user, sg)

def search_up(self, node: object, sg: Subgraph) -> None:
# non-nodes can be ignored. These would be arguments like integers or lists
# of integers, which don't affect the subgraph validity or inclusion set.
if not isinstance(node, torch.fx.Node):
return

if node.op == "placeholder":
# If we reach a placeholder or other terminal node without encountering
# a start permute, then the subgraph is invalid.
# This could be because in the add(x, y) case where x is permuted and
# y is a graph input, we can't remove the permute on x because it might
# become two different shapes that don't broadcast together.
# TODO: Adding a permute on y could be the more optimal solution,
# but perhaps not in all cases, say if x is small and y is very large.
# This transform prefers to be safe over optimal for now.
sg.is_valid = False
return

if node in self.visited:
return

self.visited.add(node)

if self.is_starting_permute(node):
sg.start_permutes.add(node)
for user in node.users:
self.search_down(user, sg)
else:
self.traverse_intermediate_node(node, sg)

def search_down(self, node: torch.fx.Node, sg: Subgraph) -> None:
if node in self.visited or self.is_starting_permute(node):
return

self.visited.add(node)

if self.is_ending_permute(node):
sg.end_permutes.add(node)
for arg in node.args:
if isinstance(arg, list):
for elem in arg:
self.search_up(elem, sg)
else:
self.search_up(arg, sg)
def is_node_permutable(self, node: torch.fx.Node) -> bool:
if node.target not in self.permutable_ops:
return False
if node.target == exir_ops.edge.aten.mean.dim:
# keepdim should be True.
if len(node.args) >= 3:
if not node.args[2]:
return False
elif "keepdim" in node.kwargs:
if not node.kwargs["keepdim"]:
return False
else:
# Default keepdim is False.
return False
return True
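# Illustration (assumed shapes): for an NHWC tensor of shape (1, 4, 4, 8),
# mean over dims [1, 2] with keepdim=True yields (1, 1, 1, 8), so rank is
# preserved and the surrounding permutes can still be re-indexed; with
# keepdim=False the result drops to rank 2 and the end permute no longer
# applies.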

def permute_subgraph(self, subgraph: Subgraph) -> None:
# Skip incoming permutes.
for inp, out in subgraph.edges_in:
assert inp.target == exir_ops.edge.aten.permute_copy.default
if len(inp.args) >= 1:
out.replace_input_with(inp, cast(torch.fx.Node, inp.args[0]))
else:
out.replace_input_with(inp, cast(torch.fx.Node, inp.kwargs["input"]))

# Skip outgoing permutes.
for inp, out in subgraph.edges_out:
assert out.target == exir_ops.edge.aten.permute_copy.default
out.replace_all_uses_with(inp)
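# Rewiring sketch: dropping an incoming permute makes the consumer read the
# permute's own input; dropping an outgoing permute forwards the producer's
# value straight to the permute's users. Both leave the permute nodes dead,
# to be swept by eliminate_dead_code() in call().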

# Handle dimension related node arguments.
for node in subgraph.nodes:
if node.target == exir_ops.edge.aten.cat.default:
self.update_cat(node, subgraph.start_permute)
elif node.target == exir_ops.edge.aten.mean.dim:
self.update_mean_dim(node, subgraph.start_permute)
elif node.target == exir_ops.edge.aten.slice_copy.Tensor:
self.update_slice_copy(node, subgraph.start_permute)

def update_cat(self, node: torch.fx.Node, start_permute: list[int]) -> None:
if len(node.args) >= 2:
node.update_arg(1, start_permute[cast(int, node.args[1])])
elif "dim" in node.kwargs:
node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])
else:
self.traverse_intermediate_node(node, sg)

def traverse_intermediate_node(self, node: torch.fx.Node, sg: Subgraph) -> None:
if node.target in self.elementwise_ops:
sg.intermediate_nodes.add(node)
for arg in node.args:
if isinstance(arg, list):
for elem in arg:
self.search_up(elem, sg)
else:
self.search_up(arg, sg)

for user in node.users:
self.search_down(user, sg)
# Default cat dim is 0.
node.update_kwarg("dim", start_permute[0])
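# Worked example (illustrative): with start_permute = [0, 2, 3, 1], a cat
# over dim 3 (channels in NHWC) is rewritten to dim start_permute[3] = 1
# (channels in NCHW), matching the layout the op sees once the permutes
# are removed.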

else:
sg.is_valid = False

def is_starting_permute(self, node: torch.fx.Node) -> bool:
return self.is_boundary_permute(node, self.to_NCHW)

def is_ending_permute(self, node: torch.fx.Node) -> bool:
return self.is_boundary_permute(node, self.to_NHWC)

@staticmethod
def is_boundary_permute(node: torch.fx.Node, permute_dims: Iterable[int]) -> bool:
permute_dims = list(permute_dims)
if node.target == exir_ops.edge.aten.permute_copy.default:
return cast(list[int], node.args[1]) == permute_dims
elif node.target == exir_ops.edge.aten.view_copy.default:
# If there's a view node, check if it's swapping two dimensions and
# not splitting any others from the input shape.
inp = node.args[0]
if not isinstance(inp, torch.fx.Node):
return False
input_shape = inp.meta["val"].shape
output_shape = node.args[1]
assert isinstance(output_shape, (tuple, list))
# If the shapes are equal in length, no dimension is being split or
# grouped. Then check if a permute of the input shape results in the output shape.
return (
len(input_shape) == len(output_shape)
and len(input_shape) == len(permute_dims)
and RemovePermutesAroundElementwiseOps.permute_shape(
input_shape, permute_dims
)
== output_shape
def update_mean_dim(self, node: torch.fx.Node, start_permute: list[int]) -> None:
if len(node.args) >= 2:
node.update_arg(
1, [start_permute[dim] for dim in cast(list[int], node.args[1])]
)
else:
return False
node.update_kwarg(
"dim",
[start_permute[dim] for dim in cast(list[int], node.kwargs["dim"])],
)
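# Worked example (illustrative): with start_permute = [0, 2, 3, 1], a mean
# over NHWC dims [1, 2] (H, W) is rewritten to NCHW dims [2, 3], i.e.
# [start_permute[1], start_permute[2]].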

@staticmethod
def permute_shape(
shape: Union[List[int], torch.Size], permute_dims: Iterable[int]
) -> List[int]:
permute_dims = list(permute_dims)
assert len(shape) == len(permute_dims)
return [shape[p] for p in permute_dims]
def update_slice_copy(self, node: torch.fx.Node, start_permute: list[int]) -> None:
if len(node.args) >= 2:
node.update_arg(1, start_permute[cast(int, node.args[1])])
else:
node.update_kwarg("dim", start_permute[cast(int, node.kwargs["dim"])])
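# Illustrative: a slice along NHWC dim 1 (H) becomes NCHW dim
# start_permute[1] = 2 when start_permute = [0, 2, 3, 1].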

def get_permutation(self, permute_node: torch.fx.Node) -> list[int]:
assert permute_node.target == exir_ops.edge.aten.permute_copy.default
if len(permute_node.args) >= 2:
return cast(list[int], permute_node.args[1])
assert "dim" in permute_node.kwargs
return cast(list[int], permute_node.kwargs["dim"])


@register_cadence_pass(CadencePassAttribute(opt_level=1))
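The rewritten pass remains a standard `ExportPass`. A hedged usage sketch, assuming the usual executorch convention that calling a pass instance on a `torch.fx.GraphModule` returns a `PassResult`:

```python
from executorch.backends.cadence.aot.remove_ops import (
    RemovePermutesAroundElementwiseOps,
)

# graph_module: an edge-dialect torch.fx.GraphModule produced by export.
result = RemovePermutesAroundElementwiseOps()(graph_module)
if result is not None and result.graph_module is not None:
    graph_module = result.graph_module
```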