Skip to content

Commit a12c7b7

Browse files
abeakkasfacebook-github-bot
authored andcommitted
Fix CatFromSliceCopyPass indexing issue. (#10913)
Summary: Pull Request resolved: #10913 Fix indexing bug in RemoveCatFromSliceCopyPass. If the slice input becomes one of the later inputs of the cat the slicing indices should be updated accordingly. Also simplify the pass and generalize arg/kwarg support. Differential Revision: D74765369
1 parent 0a6f622 commit a12c7b7

File tree

3 files changed

+101
-54
lines changed

3 files changed

+101
-54
lines changed

backends/cadence/aot/pass_utils.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,34 @@ def nodes_not_adjacent_in_gm(
157157
if node.next.target == succ_target:
158158
return False
159159
return True
160+
161+
162+
def get_arg(
163+
node: torch.fx.Node,
164+
arg_index: int,
165+
kwarg_name: str,
166+
*,
167+
default: torch.fx.node.Argument = None
168+
) -> torch.fx.node.Argument:
169+
"""
170+
Get the arg at arg_index or kwarg with arg_name of the node. If neither is found
171+
return default.
172+
"""
173+
if arg_index < len(node.args):
174+
return node.args[arg_index]
175+
elif kwarg_name in node.kwargs:
176+
return node.kwargs[kwarg_name]
177+
else:
178+
return default
179+
180+
181+
def set_arg(
182+
node: torch.fx.Node, arg_index: int, kwarg_name: str, value: torch.fx.node.Argument
183+
) -> None:
184+
"""
185+
Set the arg at arg_index if it exists, otherwise set the kwarg.
186+
"""
187+
if arg_index < len(node.args):
188+
node.update_arg(arg_index, value)
189+
else:
190+
node.update_kwarg(kwarg_name, value)

backends/cadence/aot/remove_ops.py

Lines changed: 43 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
import torch.fx
2626
from executorch.backends.cadence.aot.pass_utils import (
2727
CadencePassAttribute,
28+
get_arg,
2829
register_cadence_pass,
30+
set_arg,
2931
)
3032

3133
from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
@@ -37,7 +39,7 @@
3739
from executorch.exir.pass_manager import PassManager, PassType
3840
from executorch.exir.passes import dead_code_elimination_pass
3941
from executorch.exir.passes.spec_prop_pass import SpecPropPass
40-
from torch.fx.node import Argument
42+
from torch.fx.node import Argument, Node
4143

4244

4345
@register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -771,65 +773,52 @@ def remove_branched(
771773

772774

773775
class RemoveCatFromSliceCopyPass(ExportPass):
774-
def _remove_unused_cat( # noqa: C901
775-
self, graph_module: torch.fx.GraphModule
776-
) -> None:
777-
slice_copy_nodes = [
778-
node
779-
for node in graph_module.graph.nodes
780-
if node.target == exir_ops.edge.aten.slice_copy.Tensor
781-
]
782-
for slice_copy_node in slice_copy_nodes:
783-
slice_dim, start_idx, end_idx, step = 0, 0, float("inf"), 1
784-
input_node, *other_args = slice_copy_node.args
785-
if len(other_args) >= 1:
786-
slice_dim = other_args[0]
787-
if len(other_args) >= 2:
788-
start_idx = other_args[1]
789-
if len(other_args) >= 3:
790-
end_idx = other_args[2]
791-
if len(other_args) >= 4:
792-
step = other_args[3]
793-
if step != 1:
794-
continue
795-
slice_copy_dtype = slice_copy_node.meta["val"].dtype
796-
if input_node.target != exir_ops.edge.aten.cat.default:
797-
continue
798-
cat_dtype = input_node.meta["val"].dtype
799-
if slice_copy_dtype != cat_dtype:
776+
"""
777+
Simplifies cat->slice_copy chains where one of the cat inputs can be directly passed
778+
to the slice_copy.
779+
"""
780+
781+
def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
782+
for slice_copy_node in graph_module.graph.find_nodes(
783+
op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
784+
):
785+
cat_node = cast(Node, get_arg(slice_copy_node, 0, "input"))
786+
slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0))
787+
start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None))
788+
end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None))
789+
step = cast(int, get_arg(slice_copy_node, 4, "step", default=1))
790+
791+
if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
800792
continue
801-
cat_dim = input_node.args[1:]
802-
if len(cat_dim) == 0:
803-
cat_dim = 0
793+
794+
# Make sure cat and slice happens on the same dimension.
795+
cat_dim = cast(Node, get_arg(cat_node, 1, "dim", default=0))
804796
if cat_dim != slice_dim:
805797
continue
806-
cat_output_shape = input_node.meta["val"].shape
807-
start_idx = (
808-
cat_output_shape[cat_dim] + start_idx if start_idx < 0 else start_idx
809-
)
810-
end_idx = (
811-
cat_output_shape[cat_dim]
812-
if end_idx > cat_output_shape[cat_dim]
813-
else end_idx
814-
)
815-
base_idx = 0
816-
cat_input_to_keep = None
817-
for cat_input_node in input_node.args[0]:
818-
cat_input_dtype = cat_input_node.meta["val"].dtype
819-
if slice_copy_dtype != cat_input_dtype:
820-
continue
798+
799+
# Canonicalize slice indices.
800+
cat_output_shape = cat_node.meta["val"].shape
801+
if start_idx is None:
802+
start_idx = 0
803+
elif start_idx < 0:
804+
start_idx += cat_output_shape[cat_dim]
805+
if end_idx is None or end_idx > cat_output_shape[cat_dim]:
806+
end_idx = cat_output_shape[cat_dim]
807+
elif end_idx < 0:
808+
end_idx += cat_output_shape[cat_dim]
809+
810+
offset = 0
811+
for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")):
821812
cat_input_shape = cat_input_node.meta["val"].shape
822813

823-
# check if the slice range overlaps with the cat range
824-
if (
825-
base_idx <= start_idx
826-
and end_idx <= list(cat_input_shape)[cat_dim] + base_idx
827-
):
828-
cat_input_to_keep = cat_input_node
814+
# Check if the slice range overlaps with the cat input range.
815+
if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
816+
slice_copy_node.replace_input_with(cat_node, cat_input_node)
817+
set_arg(slice_copy_node, 2, "start", start_idx - offset)
818+
set_arg(slice_copy_node, 3, "end", end_idx - offset)
829819
break
830-
base_idx += list(cat_input_shape)[cat_dim]
831-
if cat_input_to_keep is not None:
832-
slice_copy_node.replace_input_with(input_node, cat_input_to_keep)
820+
821+
offset += cat_input_shape[cat_dim]
833822

834823
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
835824
self._remove_unused_cat(graph_module)

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,3 +864,30 @@ def forward(self, x, y):
864864

865865
# Ensure both cat nodes were removed
866866
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)
867+
868+
def test_remove_cat_from_slice_copy_second_input(self) -> None:
869+
builder = GraphBuilder()
870+
x = builder.placeholder("x", torch.randn(2, 4))
871+
y = builder.placeholder("y", torch.randn(2, 4))
872+
cat = builder.call_operator(
873+
op=exir_ops.edge.aten.cat.default,
874+
args=((x, y), 1),
875+
)
876+
slice_copy = builder.call_operator(
877+
op=exir_ops.edge.aten.slice_copy.Tensor,
878+
args=(cat, 1, 5, 7, 1),
879+
)
880+
builder.output([slice_copy])
881+
graph_module = builder.get_graph_module()
882+
883+
inputs = (torch.randn(2, 4), torch.randn(2, 4))
884+
expected_outputs = graph_module(*inputs)[0]
885+
886+
p = RemoveCatFromSliceCopyPass()
887+
graph_module = cast(PassResult, p(graph_module)).graph_module
888+
889+
# Cat should be removed.
890+
self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)
891+
892+
# Output should remain the same.
893+
self.assertTrue(torch.equal(graph_module(*inputs)[0], expected_outputs))

0 commit comments

Comments
 (0)