
Commit 7175ca4

Fix CatFromSliceCopyPass indexing issue: when the slice is rewired to read from a single cat input, rebase its start/end indices by that input's offset along the cat dimension instead of reusing the indices from the concatenated tensor.

Differential Revision: D74765369
Pull Request resolved: #10913
1 parent 78fe7ee commit 7175ca4

3 files changed: +101 −54 lines changed

backends/cadence/aot/pass_utils.py

Lines changed: 31 additions & 0 deletions
@@ -157,3 +157,34 @@ def nodes_not_adjacent_in_gm(
         if node.next.target == succ_target:
             return False
     return True
+
+
+def get_arg(
+    node: torch.fx.Node,
+    arg_index: int,
+    kwarg_name: str,
+    *,
+    default: torch.fx.node.Argument = None,
+) -> torch.fx.node.Argument:
+    """
+    Get the arg at arg_index or the kwarg named kwarg_name of the node.
+    If neither is found, return default.
+    """
+    if arg_index < len(node.args):
+        return node.args[arg_index]
+    elif kwarg_name in node.kwargs:
+        return node.kwargs[kwarg_name]
+    else:
+        return default
+
+
+def set_arg(
+    node: torch.fx.Node, arg_index: int, kwarg_name: str, value: torch.fx.node.Argument
+) -> None:
+    """
+    Set the arg at arg_index if it exists, otherwise set the kwarg.
+    """
+    if arg_index < len(node.args):
+        node.update_arg(arg_index, value)
+    else:
+        node.update_kwarg(kwarg_name, value)
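
For readers outside the ExecuTorch tree, here is a minimal standalone sketch of how these two helpers behave on a traced FX graph. The helper bodies are copied from the diff above; the torch.flatten toy module is an invented example, not part of this commit.

import torch
import torch.fx


def get_arg(node, arg_index, kwarg_name, *, default=None):
    # Prefer the positional arg if it exists, then the keyword, then the default.
    if arg_index < len(node.args):
        return node.args[arg_index]
    elif kwarg_name in node.kwargs:
        return node.kwargs[kwarg_name]
    else:
        return default


def set_arg(node, arg_index, kwarg_name, value):
    # Update in place: positionally if the arg exists, otherwise as a keyword.
    if arg_index < len(node.args):
        node.update_arg(arg_index, value)
    else:
        node.update_kwarg(kwarg_name, value)


class M(torch.nn.Module):
    def forward(self, x):
        return torch.flatten(x, end_dim=1)  # end_dim passed as a keyword


gm = torch.fx.symbolic_trace(M())
node = next(n for n in gm.graph.nodes if n.target is torch.flatten)

assert get_arg(node, 1, "start_dim", default=0) == 0  # never passed -> default
assert get_arg(node, 2, "end_dim") == 1               # passed as a kwarg

set_arg(node, 2, "end_dim", 2)  # args has no index 2, so the kwarg is updated
assert node.kwargs["end_dim"] == 2
gm.recompile()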

backends/cadence/aot/remove_ops.py

Lines changed: 43 additions & 54 deletions
@@ -25,7 +25,9 @@
 import torch.fx
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
+    get_arg,
     register_cadence_pass,
+    set_arg,
 )

 from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
@@ -37,7 +39,7 @@
 from executorch.exir.pass_manager import PassManager, PassType
 from executorch.exir.passes import dead_code_elimination_pass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
-from torch.fx.node import Argument
+from torch.fx.node import Argument, Node


 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -771,65 +773,52 @@ def remove_branched(


 class RemoveCatFromSliceCopyPass(ExportPass):
-    def _remove_unused_cat(  # noqa: C901
-        self, graph_module: torch.fx.GraphModule
-    ) -> None:
-        slice_copy_nodes = [
-            node
-            for node in graph_module.graph.nodes
-            if node.target == exir_ops.edge.aten.slice_copy.Tensor
-        ]
-        for slice_copy_node in slice_copy_nodes:
-            slice_dim, start_idx, end_idx, step = 0, 0, float("inf"), 1
-            input_node, *other_args = slice_copy_node.args
-            if len(other_args) >= 1:
-                slice_dim = other_args[0]
-            if len(other_args) >= 2:
-                start_idx = other_args[1]
-            if len(other_args) >= 3:
-                end_idx = other_args[2]
-            if len(other_args) >= 4:
-                step = other_args[3]
-            if step != 1:
-                continue
-            slice_copy_dtype = slice_copy_node.meta["val"].dtype
-            if input_node.target != exir_ops.edge.aten.cat.default:
-                continue
-            cat_dtype = input_node.meta["val"].dtype
-            if slice_copy_dtype != cat_dtype:
+    """
+    Simplifies cat->slice_copy chains where one of the cat inputs can be directly passed
+    to the slice_copy.
+    """
+
+    def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
+        for slice_copy_node in graph_module.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        ):
+            cat_node = cast(Node, get_arg(slice_copy_node, 0, "input"))
+            slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0))
+            start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None))
+            end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None))
+            step = cast(int, get_arg(slice_copy_node, 4, "step", default=1))
+
+            if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
                 continue
-            cat_dim = input_node.args[1:]
-            if len(cat_dim) == 0:
-                cat_dim = 0
+
+            # Make sure cat and slice happen on the same dimension.
+            cat_dim = cast(int, get_arg(cat_node, 1, "dim", default=0))
             if cat_dim != slice_dim:
                 continue
-            cat_output_shape = input_node.meta["val"].shape
-            start_idx = (
-                cat_output_shape[cat_dim] + start_idx if start_idx < 0 else start_idx
-            )
-            end_idx = (
-                cat_output_shape[cat_dim]
-                if end_idx > cat_output_shape[cat_dim]
-                else end_idx
-            )
-            base_idx = 0
-            cat_input_to_keep = None
-            for cat_input_node in input_node.args[0]:
-                cat_input_dtype = cat_input_node.meta["val"].dtype
-                if slice_copy_dtype != cat_input_dtype:
-                    continue
+
+            # Canonicalize slice indices.
+            cat_output_shape = cat_node.meta["val"].shape
+            if start_idx is None:
+                start_idx = 0
+            elif start_idx < 0:
+                start_idx += cat_output_shape[cat_dim]
+            if end_idx is None or end_idx > cat_output_shape[cat_dim]:
+                end_idx = cat_output_shape[cat_dim]
+            elif end_idx < 0:
+                end_idx += cat_output_shape[cat_dim]
+
+            offset = 0
+            for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")):
                 cat_input_shape = cat_input_node.meta["val"].shape

-                # check if the slice range overlaps with the cat range
-                if (
-                    base_idx <= start_idx
-                    and end_idx <= list(cat_input_shape)[cat_dim] + base_idx
-                ):
-                    cat_input_to_keep = cat_input_node
+                # Check if the slice range overlaps with the cat input range.
+                if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
+                    slice_copy_node.replace_input_with(cat_node, cat_input_node)
+                    set_arg(slice_copy_node, 2, "start", start_idx - offset)
+                    set_arg(slice_copy_node, 3, "end", end_idx - offset)
                     break
-                base_idx += list(cat_input_shape)[cat_dim]
-            if cat_input_to_keep is not None:
-                slice_copy_node.replace_input_with(input_node, cat_input_to_keep)
+
+                offset += cat_input_shape[cat_dim]

     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         self._remove_unused_cat(graph_module)
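
The heart of the fix is the two set_arg calls: once the slice_copy is rewired to read from a single cat input, its start and end indices must be rebased into that input's coordinate frame by subtracting the offset at which the input begins along the cat dimension. A quick numeric sanity check of that identity with plain tensors (shapes mirror the new test below; the snippet is illustrative, not part of the commit):

import torch

x, y = torch.randn(2, 4), torch.randn(2, 4)
start, end = 5, 7    # slice indices in the concatenation's coordinate frame
offset = x.shape[1]  # y begins at column 4 of cat((x, y), dim=1)

# Slicing the concatenation equals slicing y directly with rebased indices.
sliced_cat = torch.cat((x, y), dim=1)[:, start:end]
sliced_y = y[:, start - offset : end - offset]  # y[:, 1:3]
assert torch.equal(sliced_cat, sliced_y)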

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 27 additions & 0 deletions
@@ -864,3 +864,30 @@ def forward(self, x, y):

         # Ensure both cat nodes were removed
         self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)
+
+    def test_remove_cat_from_slice_copy_second_input(self) -> None:
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 4))
+        y = builder.placeholder("y", torch.randn(2, 4))
+        cat = builder.call_operator(
+            op=exir_ops.edge.aten.cat.default,
+            args=((x, y), 1),
+        )
+        slice_copy = builder.call_operator(
+            op=exir_ops.edge.aten.slice_copy.Tensor,
+            args=(cat, 1, 5, 7, 1),
+        )
+        builder.output([slice_copy])
+        graph_module = builder.get_graph_module()
+
+        inputs = (torch.randn(2, 4), torch.randn(2, 4))
+        expected_outputs = graph_module(*inputs)[0]
+
+        p = RemoveCatFromSliceCopyPass()
+        graph_module = cast(PassResult, p(graph_module)).graph_module
+
+        # Cat should be removed.
+        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)
+
+        # Output should remain the same.
+        self.assertTrue(torch.equal(graph_module(*inputs)[0], expected_outputs))
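
To see what this regression test guards against: the pre-fix code rewired the slice input to y but left start=5 and end=7 in the concatenation's coordinate frame, and because PyTorch clamps out-of-range slice bounds rather than raising, the pass silently produced an empty tensor. A small illustration, reconstructed from the old code above rather than taken from the commit:

import torch

y = torch.randn(2, 4)

# Old behavior: indices never rebased, so the slice falls off the end of y
# and PyTorch clamps it to an empty result.
print(y[:, 5:7].shape)  # torch.Size([2, 0])

# Fixed behavior: indices shifted left by y's offset (4) in the cat dimension.
print(y[:, 1:3].shape)  # torch.Size([2, 2])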
