Skip to content

Commit eaf383a

Browse files
authored
Add pass to convert split to many slice
Differential Revision: D61211922 Pull Request resolved: #4562
1 parent 4c06907 commit eaf383a

File tree

9 files changed

+241
-18
lines changed

9 files changed

+241
-18
lines changed

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
4343
exir_ops.edge.aten.hardtanh.default,
4444
exir_ops.edge.aten.convolution.default,
4545
exir_ops.edge.aten.div.Tensor,
46+
exir_ops.edge.aten.split_with_sizes_copy.default,
4647
exir_ops.edge.aten.full.default,
4748
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
4849
exir_ops.edge.aten.avg_pool2d.default,

backends/arm/operators/op_slice.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def define_node(
4040
shape = input_node.shape
4141
dim = dim.number
4242
end = (shape[dim] + end.number) % shape[dim]
43+
if end == 0:
44+
end = shape[dim]
4345
size = end - start.number
4446
assert size > 0
4547
assert size <= shape[dim]

backends/arm/passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
from executorch.backends.arm.passes.convert_expand_copy_to_repeat import (
1313
ConvertExpandCopyToRepeatPass,
1414
)
15+
from executorch.backends.arm.passes.convert_split_to_slice import (
16+
ConvertSplitToSlicePass,
17+
)
1518
from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
1619
from executorch.exir.backend.compile_spec_schema import CompileSpec
1720
from executorch.exir.pass_manager import PassManager
@@ -28,6 +31,7 @@ def transform_to_backend_pipeline(
2831
"""Apply passes before transforming program to backend"""
2932
self.add_pass(RemoveClonePass())
3033
self.add_pass(ConvertExpandCopyToRepeatPass())
34+
self.add_pass(ConvertSplitToSlicePass())
3135
for spec in compile_spec:
3236
if spec.key == "permute_memory_format":
3337
memory_format = spec.value.decode()
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import torch.fx
8+
from executorch.backends.arm.tosa_mapping import extract_tensor_meta
9+
from executorch.exir.dialects._ops import ops as exir_ops
10+
from executorch.exir.pass_base import ExportPass, PassResult
11+
12+
13+
class ConvertSplitToSlicePass(ExportPass):
    """
    Replace a split operation with many slice operations.

    Every user of a split node (expected to be a ``getitem`` node) is rewired
    to a dedicated ``slice_copy`` node that reads the matching range directly
    from the split's input. After that the split node has no users left, so
    dead code elimination can drop it.
    """

    split_ops = (
        exir_ops.edge.aten.split_with_sizes_copy.default,
        exir_ops.edge.aten.split_copy.Tensor,
    )
    slice = exir_ops.edge.aten.slice_copy.Tensor

    @staticmethod
    def _normalize_split_lengths(split_arg, dim_size):
        """Return the explicit list of output lengths for a split node.

        Args:
            split_arg: Second argument of the split node. Either a sequence of
                lengths (``split_with_sizes_copy``) or a single int chunk size
                (``split_copy.Tensor``).
            dim_size: Size of the input along the split dimension.

        Returns:
            A list with one length per split output.

        Raises:
            ValueError: If ``split_arg`` is an int that is not positive.
        """
        if isinstance(split_arg, int):
            if split_arg <= 0:
                raise ValueError(f"Split size must be positive, got {split_arg}.")
            # Equal chunks of 'split_arg'; the last chunk holds the remainder
            # when the dimension is not evenly divisible.
            return [
                min(split_arg, dim_size - offset)
                for offset in range(0, dim_size, split_arg)
            ]
        return list(split_arg)

    def call(self, graph_module: torch.fx.GraphModule):
        """Rewrite every split node in ``graph_module`` into slice nodes.

        Args:
            graph_module: The graph module to transform in place.

        Returns:
            A ``PassResult`` wrapping the (recompiled) graph module.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target not in self.split_ops:
                continue

            # Get useful variables
            split_node = node
            input_node = split_node.all_input_nodes[0]
            # Copy the users: they are rewired below, which mutates 'users'.
            output_nodes = split_node.users.copy()
            _, shape, _ = extract_tensor_meta(input_node.meta)
            rank = len(shape)
            dim = split_node.args[2] if len(split_node.args) > 2 else 0
            # Map a negative dim onto its non-negative equivalent.
            dim = (dim + rank) % rank
            split_lengths = self._normalize_split_lengths(
                split_node.args[1], shape[dim]
            )

            assert (
                sum(split_lengths) == shape[dim]
            ), "Given split lengths don't sum up to the size of the dimension."

            # Convert split argument 'split_lengths' to slice arguments start and end.
            starts = []
            ends = []
            offset = 0
            for split_length in split_lengths:
                starts.append(offset)
                offset += split_length
                ends.append(offset)

            # Output nodes are of type getitem
            # Create one slice node for each output node with matching arguments.
            with graph.inserting_before(split_node):
                for output_node in output_nodes:
                    index = output_node.args[1]
                    slice_node = graph.create_node(
                        "call_function",
                        self.slice,
                        (input_node, dim, starts[index], ends[index]),
                    )
                    # Start from the split's metadata, then narrow "val" to the
                    # entry belonging to this particular output.
                    slice_node.meta = split_node.meta.copy()
                    slice_node.meta["val"] = slice_node.meta["val"][index]
                    output_node.replace_input_with(split_node, slice_node)
        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)

backends/arm/quantizer/arm_quantizer_utils.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
# Utility functions for ArmQuantizer
1010
#
1111

12+
import operator
1213
from typing import Callable, cast, List
1314

1415
import torch
@@ -141,8 +142,11 @@ def is_share_obs_or_fq_op(op: Callable) -> bool:
141142
torch.ops.aten.view_copy.default,
142143
torch.ops.aten.view.default,
143144
torch.ops.aten.slice.Tensor,
145+
torch.ops.aten.split.Tensor,
146+
torch.ops.aten.split_with_sizes.default,
144147
torch.ops.aten.flatten.using_ints,
145148
torch.ops.aten.dropout.default,
149+
operator.getitem,
146150
]
147151

148152

backends/arm/test/ops/test_slice.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def forward(self, x: torch.Tensor):
3333
elif x.dim() == 3:
3434
return x[0:7, 0:1, 0:8]
3535
elif x.dim() == 4:
36-
return x[:, 2:5, 3:5, 4:5]
36+
return x[:, 2:5, 3:5, 4:10]
3737

3838
def _test_slice_tosa_MI_pipeline(
3939
self, module: torch.nn.Module, test_data: torch.Tensor

backends/arm/test/ops/test_split.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from executorch.backends.arm.quantizer.arm_quantizer import (
11+
ArmQuantizer,
12+
get_symmetric_quantization_config,
13+
)
14+
from executorch.backends.arm.test import common
15+
from executorch.backends.arm.test.tester.arm_tester import ArmTester
16+
from executorch.backends.xnnpack.test.tester.tester import Quantize
17+
from parameterized import parameterized
18+
19+
test_data_t = tuple[torch.Tensor, int | list[int], int]
20+
21+
22+
class TestSimpleSplit(unittest.TestCase):
    """Tests that run small split modules through the Arm tester pipelines."""

    class Split(torch.nn.Module):
        """Module returning every output of ``Tensor.split``."""

        # Each entry is a 1-tuple wrapping (input tensor, split size or list of
        # sections, dim); the test methods receive the inner tuple as their
        # single 'test_data' argument.
        test_data: list[tuple[test_data_t]] = [
            ((torch.rand(10), 2, 0),),
            ((torch.rand(10, 10), 3, 1),),
            ((torch.rand(10, 10), 4, -1),),
            ((torch.rand(10, 15, 10), [2, 2, 11], 1),),
            ((torch.rand(4, 4, 4, 4), 2, 0),),
            ((torch.rand(4, 4, 4, 4), [1, 1, 1, 1], -2),),
        ]

        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            return x.split(split_size=split_size_or_sections, dim=dim)

    class SplitWithSizes(torch.nn.Module):
        """Module returning every output of ``Tensor.split_with_sizes``."""

        def forward(self, x: torch.Tensor, split_sizes: list[int], dim: int):
            return x.split_with_sizes(split_sizes=split_sizes, dim=dim)

    class SplitSingleOut(torch.nn.Module):
        """Module returning only the second output (index 1) of the split."""

        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            return x.split(split_size=split_size_or_sections, dim=dim)[1]

    class SplitTwoOut(torch.nn.Module):
        """Module returning the ``[1:3]`` slice of the split outputs."""

        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            return x.split(split_size=split_size_or_sections, dim=dim)[1:3]

    def _test_split_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):
        """Run ``module`` through the TOSA MI pipeline and compare outputs.

        Checks that the edge graph contains a split_with_sizes_copy op and that
        partitioning yields exactly one delegate call.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .export()
            .to_edge()
            .check(
                [
                    "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default"
                ]
            )
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_split_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):
        """Run ``module`` through the TOSA BI pipeline, which adds a quantize stage.

        Outputs are compared with ``qtol=1``.
        """

        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_split_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):
        """Run ``module`` through the quantized pipeline with the U55 compile spec.

        Only lowering is exercised here; no outputs are run or compared.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_u55_compile_spec(),
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check(["torch.ops.aten.split.Tensor"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(Split.test_data)
    def test_split_tosa_MI(self, test_data: test_data_t):
        self._test_split_tosa_MI_pipeline(self.Split(), test_data)

    # Entries 3 and 5 are the ones whose split argument is a list of sizes.
    @parameterized.expand([Split.test_data[3], Split.test_data[5]])
    def test_split_with_sizes_tosa_MI(self, test_data: test_data_t):
        assert isinstance(test_data[1], list)
        self._test_split_tosa_MI_pipeline(self.SplitWithSizes(), test_data)

    @parameterized.expand(Split.test_data)
    def test_split_n_out_tosa_MI(self, test_data: test_data_t):
        self._test_split_tosa_MI_pipeline(self.SplitSingleOut(), test_data)
        self._test_split_tosa_MI_pipeline(self.SplitTwoOut(), test_data)

    @parameterized.expand(Split.test_data)
    def test_split_tosa_BI(self, test_data: test_data_t):
        self._test_split_tosa_BI_pipeline(self.Split(), test_data)

    # Fails during Vela compilation when trying to use a tuple as a named
    # tuple. Could be a Vela issue; wait until Regor.
    @parameterized.expand(Split.test_data)
    @unittest.expectedFailure
    def test_split_u55_BI(self, test_data: test_data_t):
        self._test_split_u55_BI_pipeline(self.Split(), test_data)

backends/arm/test/runner_utils.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def set_timeout(self, timeout: int):
202202
def run_corstone300(
203203
self,
204204
inputs: Tuple[torch.Tensor],
205-
) -> torch.Tensor:
205+
) -> list[torch.Tensor]:
206206

207207
assert (
208208
self._has_init_run
@@ -268,12 +268,12 @@ def run_corstone300(
268268

269269
tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
270270
tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(inputs[0].shape)
271-
return tosa_ref_output
271+
return [tosa_ref_output]
272272

273273
def run_tosa_ref_model(
274274
self,
275275
inputs: Tuple[torch.Tensor],
276-
) -> torch.Tensor:
276+
) -> list[torch.Tensor]:
277277
"""
278278
Run TOSA reference model using the tosa_reference_model program.
279279
@@ -369,23 +369,26 @@ def run_tosa_ref_model(
369369
# Load desc.json, just to get the name of the output file above
370370
with open(desc_file_path) as f:
371371
desc_json = json.load(f)
372-
ofm_file_npy = os.path.join(self.intermediate_path, desc_json["ofm_file"][0])
373372

374-
# Load the output file (OFM) and return it as a numpy array
375-
tosa_ref_output = np.load(ofm_file_npy)
373+
tosa_ref_outputs = []
374+
for ofm_file in desc_json["ofm_file"]:
375+
ofm_file_npy = os.path.join(self.intermediate_path, ofm_file)
376376

377-
if self.is_quantized:
378-
# Need to dequant back to FP32 for comparison with torch output
379-
quant_param = self.qp_output
380-
assert (
381-
quant_param is not None
382-
), "There are no quantization parameters, check output parameters"
383-
tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale
377+
# Load the output file (OFM) as a numpy array
378+
tosa_ref_output = np.load(ofm_file_npy)
384379

385-
# tosa_output is a numpy array, convert to torch tensor for comparison
386-
tosa_ref_output = torch.from_numpy(tosa_ref_output.astype("float32"))
380+
if self.is_quantized:
381+
# Need to dequant back to FP32 for comparison with torch output
382+
quant_param = self.qp_output
383+
assert (
384+
quant_param is not None
385+
), "There are no quantization parameters, check output parameters"
386+
tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale
387387

388-
return tosa_ref_output
388+
# tosa_output is a numpy array, convert to torch tensor for comparison
389+
tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output.astype("float32")))
390+
391+
return tosa_ref_outputs
389392

390393

391394
def prep_data_for_save(

backends/arm/test/tester/arm_tester.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ def run_method_and_compare_outputs(
260260
print(f"Run {run_iteration} with input shapes: {input_shapes}")
261261

262262
reference_output = reference_stage.run_artifact(reference_input)
263-
test_output = (test_stage.run_artifact(test_input),)
263+
test_output = tuple(test_stage.run_artifact(test_input))
264264
if is_nhwc:
265265
test_output = self.transpose_data_format(test_output, "NCHW")
266266

0 commit comments

Comments
 (0)