@@ -25,7 +25,9 @@
 import torch.fx
 from executorch.backends.cadence.aot.pass_utils import (
     CadencePassAttribute,
+    get_arg,
     register_cadence_pass,
+    set_arg,
 )
 
 from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
@@ -37,7 +39,7 @@
 from executorch.exir.pass_manager import PassManager, PassType
 from executorch.exir.passes import dead_code_elimination_pass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
-from torch.fx.node import Argument
+from torch.fx.node import Argument, Node
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -771,65 +773,52 @@ def remove_branched(
 
 
 class RemoveCatFromSliceCopyPass(ExportPass):
-    def _remove_unused_cat(  # noqa: C901
-        self, graph_module: torch.fx.GraphModule
-    ) -> None:
-        slice_copy_nodes = [
-            node
-            for node in graph_module.graph.nodes
-            if node.target == exir_ops.edge.aten.slice_copy.Tensor
-        ]
-        for slice_copy_node in slice_copy_nodes:
-            slice_dim, start_idx, end_idx, step = 0, 0, float("inf"), 1
-            input_node, *other_args = slice_copy_node.args
-            if len(other_args) >= 1:
-                slice_dim = other_args[0]
-            if len(other_args) >= 2:
-                start_idx = other_args[1]
-            if len(other_args) >= 3:
-                end_idx = other_args[2]
-            if len(other_args) >= 4:
-                step = other_args[3]
-            if step != 1:
-                continue
-            slice_copy_dtype = slice_copy_node.meta["val"].dtype
-            if input_node.target != exir_ops.edge.aten.cat.default:
-                continue
-            cat_dtype = input_node.meta["val"].dtype
-            if slice_copy_dtype != cat_dtype:
+    """
+    Simplifies cat->slice_copy chains where one of the cat inputs can be directly passed
+    to the slice_copy.
+    """
+
+    def _remove_unused_cat(self, graph_module: torch.fx.GraphModule) -> None:
+        for slice_copy_node in graph_module.graph.find_nodes(
+            op="call_function", target=exir_ops.edge.aten.slice_copy.Tensor
+        ):
+            cat_node = cast(Node, get_arg(slice_copy_node, 0, "input"))
+            slice_dim = cast(int, get_arg(slice_copy_node, 1, "dim", default=0))
+            start_idx = cast(int, get_arg(slice_copy_node, 2, "start", default=None))
+            end_idx = cast(int, get_arg(slice_copy_node, 3, "end", default=None))
+            step = cast(int, get_arg(slice_copy_node, 4, "step", default=1))
+
+            if cat_node.target != exir_ops.edge.aten.cat.default or step != 1:
                 continue
-            cat_dim = input_node.args[1:]
-            if len(cat_dim) == 0:
-                cat_dim = 0
+
+            # Make sure cat and slice happen on the same dimension.
+            cat_dim = cast(int, get_arg(cat_node, 1, "dim", default=0))
             if cat_dim != slice_dim:
                 continue
-            cat_output_shape = input_node.meta["val"].shape
-            start_idx = (
-                cat_output_shape[cat_dim] + start_idx if start_idx < 0 else start_idx
-            )
-            end_idx = (
-                cat_output_shape[cat_dim]
-                if end_idx > cat_output_shape[cat_dim]
-                else end_idx
-            )
-            base_idx = 0
-            cat_input_to_keep = None
-            for cat_input_node in input_node.args[0]:
-                cat_input_dtype = cat_input_node.meta["val"].dtype
-                if slice_copy_dtype != cat_input_dtype:
-                    continue
+
+            # Canonicalize slice indices.
+            cat_output_shape = cat_node.meta["val"].shape
+            if start_idx is None:
+                start_idx = 0
+            elif start_idx < 0:
+                start_idx += cat_output_shape[cat_dim]
+            if end_idx is None or end_idx > cat_output_shape[cat_dim]:
+                end_idx = cat_output_shape[cat_dim]
+            elif end_idx < 0:
+                end_idx += cat_output_shape[cat_dim]
+
+            offset = 0
+            for cat_input_node in cast(List[Node], get_arg(cat_node, 0, "tensors")):
                 cat_input_shape = cat_input_node.meta["val"].shape
 
-                # check if the slice range overlaps with the cat range
-                if (
-                    base_idx <= start_idx
-                    and end_idx <= list(cat_input_shape)[cat_dim] + base_idx
-                ):
-                    cat_input_to_keep = cat_input_node
+                # Check if the slice range overlaps with the cat input range.
+                if offset <= start_idx and end_idx <= offset + cat_input_shape[cat_dim]:
+                    slice_copy_node.replace_input_with(cat_node, cat_input_node)
+                    set_arg(slice_copy_node, 2, "start", start_idx - offset)
+                    set_arg(slice_copy_node, 3, "end", end_idx - offset)
                     break
-                base_idx += list(cat_input_shape)[cat_dim]
-            if cat_input_to_keep is not None:
-                slice_copy_node.replace_input_with(input_node, cat_input_to_keep)
+
+                offset += cat_input_shape[cat_dim]
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         self._remove_unused_cat(graph_module)
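
For intuition, here is a minimal tensor-level sketch of the rewrite's semantics. The pass itself operates on edge-dialect FX nodes with shape metadata, not on eager tensors; the shapes, dim, and slice bounds below are made up purely for illustration.

import torch

# Illustrative shapes only: two tensors concatenated along dim 0.
a = torch.randn(4, 8)
b = torch.randn(6, 8)
cat = torch.cat([a, b], dim=0)  # shape (10, 8)

# The slice [5, 9) lies entirely inside b's segment of the cat output
# (offset 4, length 6), so the cat is redundant for this consumer.
before = cat[5:9]

# The pass retargets the slice at b with offset-adjusted bounds:
# start - offset = 1, end - offset = 5.
after = b[1:5]

assert torch.equal(before, after)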
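The get_arg and set_arg helpers are imported from pass_utils and are not shown in this diff. From the call sites alone, they appear to read and write a node argument by position, with a fallback to the keyword name and an optional default. A plausible sketch of such helpers, as an assumption rather than the actual executorch implementation:

from typing import Any

import torch.fx

_MISSING = object()  # hypothetical sentinel so None can be a valid default


def get_arg(node: torch.fx.Node, idx: int, name: str, *, default: Any = _MISSING) -> Any:
    # Prefer the positional argument, fall back to the keyword, then the default.
    if idx < len(node.args):
        return node.args[idx]
    if name in node.kwargs:
        return node.kwargs[name]
    if default is _MISSING:
        raise ValueError(f"{node.format_node()} has no argument {name!r}")
    return default


def set_arg(node: torch.fx.Node, idx: int, name: str, value: Any) -> None:
    # Update in place if the argument was passed positionally, else set the keyword.
    if idx < len(node.args):
        node.update_arg(idx, value)
    else:
        node.kwargs = {**node.kwargs, name: value}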