|
8 | 8 | from executorch.backends.arm.tosa_mapping import extract_tensor_meta
|
9 | 9 | from executorch.exir.dialects._ops import ops as exir_ops
|
10 | 10 | from executorch.exir.pass_base import ExportPass, PassResult
|
11 |
| -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions |
12 | 11 |
|
13 | 12 |
|
14 | 13 | class ConvertSplitToSlicePass(ExportPass):
|
15 | 14 | """
|
16 | 15 | Replace a split operation with many slice operations.
|
17 | 16 | """
|
18 | 17 |
|
19 |
| - split_copy = exir_ops.edge.aten.split_with_sizes_copy.default |
| 18 | + split_ops = ( |
| 19 | + exir_ops.edge.aten.split_with_sizes_copy.default, |
| 20 | + exir_ops.edge.aten.split_copy.Tensor, |
| 21 | + ) |
20 | 22 | slice = exir_ops.edge.aten.slice_copy.Tensor
|
21 |
| - patterns = [{split_copy: 1}] |
22 | 23 |
|
23 | 24 | def call(self, graph_module: torch.fx.GraphModule):
|
24 | 25 | graph = graph_module.graph
|
25 |
| - partitions = get_source_partitions( |
26 |
| - graph, |
27 |
| - [torch.split_with_sizes, torch.Tensor.split, "split", "split_with_sizes"], |
28 |
| - ) |
29 |
| - for _, src_partitions in partitions.items(): |
30 |
| - for src_partition in src_partitions: |
| 26 | + for node in graph.nodes: |
| 27 | + if node.target not in self.split_ops: |
| 28 | + continue |
31 | 29 |
|
32 |
| - # Get useful variables |
33 |
| - split_node = src_partition.nodes[0] |
34 |
| - input_node = split_node.all_input_nodes[0] |
35 |
| - output_nodes = split_node.users.copy() |
36 |
| - _, shape, _ = extract_tensor_meta(input_node.meta) |
37 |
| - rank = len(shape) |
38 |
| - split_lengths = split_node.args[1] |
39 |
| - dim = split_node.args[2] if len(split_node.args) > 2 else 0 |
40 |
| - dim = (dim + rank) % rank |
| 30 | + # Get useful variables |
| 31 | + split_node = node |
| 32 | + input_node = split_node.all_input_nodes[0] |
| 33 | + output_nodes = split_node.users.copy() |
| 34 | + _, shape, _ = extract_tensor_meta(input_node.meta) |
| 35 | + rank = len(shape) |
| 36 | + split_lengths = split_node.args[1] |
| 37 | + dim = split_node.args[2] if len(split_node.args) > 2 else 0 |
| 38 | + dim = (dim + rank) % rank |
41 | 39 |
|
42 |
| - assert ( |
43 |
| - sum(split_lengths) == shape[dim] |
44 |
| - ), "Given split lengths don't sum up to the size of the dimension." |
| 40 | + assert ( |
| 41 | + sum(split_lengths) == shape[dim] |
| 42 | + ), "Given split lengths don't sum up to the size of the dimension." |
45 | 43 |
|
46 |
| - # Convert split argument 'split_lengths' to slice arguments start and end. |
47 |
| - starts = [0] * len(split_lengths) |
48 |
| - ends = [0] * len(split_lengths) |
49 |
| - start = 0 |
50 |
| - end = 0 |
51 |
| - for i, split_length in enumerate(split_lengths): |
52 |
| - end = start + split_length |
53 |
| - starts[i] = start |
54 |
| - ends[i] = end |
55 |
| - start = end |
| 44 | + # Convert split argument 'split_lengths' to slice arguments start and end. |
| 45 | + starts = [0] * len(split_lengths) |
| 46 | + ends = [0] * len(split_lengths) |
| 47 | + start = 0 |
| 48 | + end = 0 |
| 49 | + for i, split_length in enumerate(split_lengths): |
| 50 | + end = start + split_length |
| 51 | + starts[i] = start |
| 52 | + ends[i] = end |
| 53 | + start = end |
56 | 54 |
|
57 |
| - # Output nodes are of type getitem |
58 |
| - # Create one slice node for each output node with matching argumetns. |
59 |
| - with graph_module.graph.inserting_before(split_node): |
60 |
| - for output_node in output_nodes: |
61 |
| - index = output_node.args[1] |
62 |
| - slice_node = graph.create_node( |
63 |
| - "call_function", |
64 |
| - self.slice, |
65 |
| - (input_node, dim, starts[index], ends[index]), |
66 |
| - ) |
67 |
| - slice_node.meta = split_node.meta.copy() |
68 |
| - slice_node.meta["val"] = slice_node.meta["val"][index] |
69 |
| - output_node.replace_input_with(split_node, slice_node) |
| 55 | + # Output nodes are of type getitem |
| 56 | + # Create one slice node for each output node with matching argumetns. |
| 57 | + with graph_module.graph.inserting_before(split_node): |
| 58 | + for output_node in output_nodes: |
| 59 | + index = output_node.args[1] |
| 60 | + slice_node = graph.create_node( |
| 61 | + "call_function", |
| 62 | + self.slice, |
| 63 | + (input_node, dim, starts[index], ends[index]), |
| 64 | + ) |
| 65 | + slice_node.meta = split_node.meta.copy() |
| 66 | + slice_node.meta["val"] = slice_node.meta["val"][index] |
| 67 | + output_node.replace_input_with(split_node, slice_node) |
70 | 68 | graph.eliminate_dead_code()
|
71 | 69 | graph_module.recompile()
|
72 | 70 | return PassResult(graph_module, True)
|
0 commit comments