Skip to content

Commit 79312e0

Browse files
committed
Use node target rather than get_source_partition
This is a better way to find all aten split nodes, since get_source_partion misses split nodes added during lowering. Signed-off-by: Erik Lundell <[email protected]> Change-Id: Iff850850341240d22e1012511de23b3efdabf1bb
1 parent bfeab38 commit 79312e0

File tree

1 file changed

+42
-44
lines changed

1 file changed

+42
-44
lines changed

backends/arm/passes/convert_split_to_slice.py

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -8,65 +8,63 @@
88
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
99
from executorch.exir.dialects._ops import ops as exir_ops
1010
from executorch.exir.pass_base import ExportPass, PassResult
11-
from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
1211

1312

1413
class ConvertSplitToSlicePass(ExportPass):
1514
"""
1615
Replace a split operation with many slice operations.
1716
"""
1817

19-
split_copy = exir_ops.edge.aten.split_with_sizes_copy.default
18+
split_ops = (
19+
exir_ops.edge.aten.split_with_sizes_copy.default,
20+
exir_ops.edge.aten.split_copy.Tensor,
21+
)
2022
slice = exir_ops.edge.aten.slice_copy.Tensor
21-
patterns = [{split_copy: 1}]
2223

2324
def call(self, graph_module: torch.fx.GraphModule):
2425
graph = graph_module.graph
25-
partitions = get_source_partitions(
26-
graph,
27-
[torch.split_with_sizes, torch.Tensor.split, "split", "split_with_sizes"],
28-
)
29-
for _, src_partitions in partitions.items():
30-
for src_partition in src_partitions:
26+
for node in graph.nodes:
27+
if node.target not in self.split_ops:
28+
continue
3129

32-
# Get useful variables
33-
split_node = src_partition.nodes[0]
34-
input_node = split_node.all_input_nodes[0]
35-
output_nodes = split_node.users.copy()
36-
_, shape, _ = extract_tensor_meta(input_node.meta)
37-
rank = len(shape)
38-
split_lengths = split_node.args[1]
39-
dim = split_node.args[2] if len(split_node.args) > 2 else 0
40-
dim = (dim + rank) % rank
30+
# Get useful variables
31+
split_node = node
32+
input_node = split_node.all_input_nodes[0]
33+
output_nodes = split_node.users.copy()
34+
_, shape, _ = extract_tensor_meta(input_node.meta)
35+
rank = len(shape)
36+
split_lengths = split_node.args[1]
37+
dim = split_node.args[2] if len(split_node.args) > 2 else 0
38+
dim = (dim + rank) % rank
4139

42-
assert (
43-
sum(split_lengths) == shape[dim]
44-
), "Given split lengths don't sum up to the size of the dimension."
40+
assert (
41+
sum(split_lengths) == shape[dim]
42+
), "Given split lengths don't sum up to the size of the dimension."
4543

46-
# Convert split argument 'split_lengths' to slice arguments start and end.
47-
starts = [0] * len(split_lengths)
48-
ends = [0] * len(split_lengths)
49-
start = 0
50-
end = 0
51-
for i, split_length in enumerate(split_lengths):
52-
end = start + split_length
53-
starts[i] = start
54-
ends[i] = end
55-
start = end
44+
# Convert split argument 'split_lengths' to slice arguments start and end.
45+
starts = [0] * len(split_lengths)
46+
ends = [0] * len(split_lengths)
47+
start = 0
48+
end = 0
49+
for i, split_length in enumerate(split_lengths):
50+
end = start + split_length
51+
starts[i] = start
52+
ends[i] = end
53+
start = end
5654

57-
# Output nodes are of type getitem
58-
# Create one slice node for each output node with matching argumetns.
59-
with graph_module.graph.inserting_before(split_node):
60-
for output_node in output_nodes:
61-
index = output_node.args[1]
62-
slice_node = graph.create_node(
63-
"call_function",
64-
self.slice,
65-
(input_node, dim, starts[index], ends[index]),
66-
)
67-
slice_node.meta = split_node.meta.copy()
68-
slice_node.meta["val"] = slice_node.meta["val"][index]
69-
output_node.replace_input_with(split_node, slice_node)
55+
# Output nodes are of type getitem
56+
# Create one slice node for each output node with matching argumetns.
57+
with graph_module.graph.inserting_before(split_node):
58+
for output_node in output_nodes:
59+
index = output_node.args[1]
60+
slice_node = graph.create_node(
61+
"call_function",
62+
self.slice,
63+
(input_node, dim, starts[index], ends[index]),
64+
)
65+
slice_node.meta = split_node.meta.copy()
66+
slice_node.meta["val"] = slice_node.meta["val"][index]
67+
output_node.replace_input_with(split_node, slice_node)
7068
graph.eliminate_dead_code()
7169
graph_module.recompile()
7270
return PassResult(graph_module, True)

0 commit comments

Comments
 (0)