Skip to content

Fix CatFromSliceCopyPass indexing issue. #10913

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions backends/cadence/aot/pass_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,34 @@ def nodes_not_adjacent_in_gm(
if node.next.target == succ_target:
return False
return True


def get_arg(
    node: torch.fx.Node,
    arg_index: int,
    kwarg_name: str,
    *,
    default: torch.fx.node.Argument = None,
) -> torch.fx.node.Argument:
    """Look up a node argument positionally first, then by keyword.

    Returns the positional arg at ``arg_index`` when the node has one,
    otherwise the keyword arg named ``kwarg_name``, otherwise ``default``.
    """
    if arg_index < len(node.args):
        return node.args[arg_index]
    # Fall back to kwargs; dict.get supplies the default in one step.
    return node.kwargs.get(kwarg_name, default)


def set_arg(
    node: torch.fx.Node, arg_index: int, kwarg_name: str, value: torch.fx.node.Argument
) -> None:
    """Set the arg at arg_index if it exists, otherwise set the kwarg."""
    # Guard clause: no positional slot at this index -> write the kwarg.
    if arg_index >= len(node.args):
        node.update_kwarg(kwarg_name, value)
        return
    node.update_arg(arg_index, value)
97 changes: 43 additions & 54 deletions backends/cadence/aot/remove_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
import torch.fx
from executorch.backends.cadence.aot.pass_utils import (
CadencePassAttribute,
get_arg,
register_cadence_pass,
set_arg,
)

from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
Expand All @@ -37,7 +39,7 @@
from executorch.exir.pass_manager import PassManager, PassType
from executorch.exir.passes import dead_code_elimination_pass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch.fx.node import Argument
from torch.fx.node import Argument, Node


@register_cadence_pass(CadencePassAttribute(opt_level=0))
Expand Down Expand Up @@ -771,65 +773,52 @@ def remove_branched(


class RemoveCatFromSliceCopyPass(ExportPass):
def _remove_unused_cat( # noqa: C901
self, graph_module: torch.fx.GraphModule
) -> None:
slice_copy_nodes = [
node
for node in graph_module.graph.nodes
if node.target == exir_ops.edge.aten.slice_copy.Tensor
]
for slice_copy_node in slice_copy_nodes:
slice_dim, start_idx, end_idx, step = 0, 0, float("inf"), 1
input_node, *other_args = slice_copy_node.args
if len(other_args) >= 1:
slice_dim = other_args[0]
if len(other_args) >= 2:
start_idx = other_args[1]
if len(other_args) >= 3:
end_idx = other_args[2]
if len(other_args) >= 4:
step = other_args[3]
if step != 1:
continue
slice_copy_dtype = slice_copy_node.meta["val"].dtype
if input_node.target != exir_ops.edge.aten.cat.default:
continue
cat_dtype = input_node.meta["val"].dtype
if slice_copy_dtype != cat_dtype:
"""
Simplifies cat->slice_copy chains where one of the cat inputs can be directly passed
to the slice_copy.
"""

def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
for slice_copy_node in graph_module.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
):
cat_node = cast(Node, get_arg(slice_copy_node, 0, "input"))
slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0))
start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None))
end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None))
step = cast(int, get_arg(slice_copy_node, 4, "step", default=1))

if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
continue
cat_dim = input_node.args[1:]
if len(cat_dim) == 0:
cat_dim = 0

# Make sure cat and slice happens on the same dimension.
cat_dim = cast(Node, get_arg(cat_node, 1, "dim", default=0))
if cat_dim != slice_dim:
continue
cat_output_shape = input_node.meta["val"].shape
start_idx = (
cat_output_shape[cat_dim] + start_idx if start_idx < 0 else start_idx
)
end_idx = (
cat_output_shape[cat_dim]
if end_idx > cat_output_shape[cat_dim]
else end_idx
)
base_idx = 0
cat_input_to_keep = None
for cat_input_node in input_node.args[0]:
cat_input_dtype = cat_input_node.meta["val"].dtype
if slice_copy_dtype != cat_input_dtype:
continue

# Canonicalize slice indices.
cat_output_shape = cat_node.meta["val"].shape
if start_idx is None:
start_idx = 0
elif start_idx < 0:
start_idx += cat_output_shape[cat_dim]
if end_idx is None or end_idx > cat_output_shape[cat_dim]:
end_idx = cat_output_shape[cat_dim]
elif end_idx < 0:
end_idx += cat_output_shape[cat_dim]

offset = 0
for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")):
cat_input_shape = cat_input_node.meta["val"].shape

# check if the slice range overlaps with the cat range
if (
base_idx <= start_idx
and end_idx <= list(cat_input_shape)[cat_dim] + base_idx
):
cat_input_to_keep = cat_input_node
# Check if the slice range overlaps with the cat input range.
if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
slice_copy_node.replace_input_with(cat_node, cat_input_node)
set_arg(slice_copy_node, 2, "start", start_idx - offset)
set_arg(slice_copy_node, 3, "end", end_idx - offset)
break
base_idx += list(cat_input_shape)[cat_dim]
if cat_input_to_keep is not None:
slice_copy_node.replace_input_with(input_node, cat_input_to_keep)

offset += cat_input_shape[cat_dim]

def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
self._remove_unused_cat(graph_module)
Expand Down
27 changes: 27 additions & 0 deletions backends/cadence/aot/tests/test_remove_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -864,3 +864,30 @@ def forward(self, x, y):

# Ensure both cat nodes were removed
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)

def test_remove_cat_from_slice_copy_second_input(self) -> None:
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(2, 4))
y = builder.placeholder("y", torch.randn(2, 4))
cat = builder.call_operator(
op=exir_ops.edge.aten.cat.default,
args=((x, y), 1),
)
slice_copy = builder.call_operator(
op=exir_ops.edge.aten.slice_copy.Tensor,
args=(cat, 1, 5, 7, 1),
)
builder.output([slice_copy])
graph_module = builder.get_graph_module()

inputs = (torch.randn(2, 4), torch.randn(2, 4))
expected_outputs = graph_module(*inputs)[0]

p = RemoveCatFromSliceCopyPass()
graph_module = cast(PassResult, p(graph_module)).graph_module

# Cat should be removed.
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)

# Output should remain the same.
self.assertTrue(torch.equal(graph_module(*inputs)[0], expected_outputs))
Loading