
Commit 2a4256b

Improve data format handling in Arm backend (#7588)
- Updates the Vela version to support TRANSPOSE for EthosU55.
- Adds a pass that decomposes SELECT into SLICE + SQUEEZE so that transposes can be inserted properly.
- Updates the affected unit tests.

These changes remove the need for the permute_memory_to_nhwc flag and eliminate many expected failures.
1 parent 3f9324c commit 2a4256b

37 files changed (+208, −410 lines)
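The decomposition described in the commit message rests on a simple identity. Below is a minimal sketch in plain PyTorch (tensor-level ops standing in for the edge-dialect operators changed in this commit) of the SELECT → SLICE + SQUEEZE rewrite:

```python
# Minimal sketch: selecting index `index` along `dim` is equivalent to slicing out a
# width-1 window along `dim` and then squeezing that unit dimension away.
import torch

x = torch.randn(2, 3, 4)
dim, index = 1, 2

selected = x.select(dim, index)                    # shape (2, 4)
decomposed = x.narrow(dim, index, 1).squeeze(dim)  # slice to (2, 1, 4), then drop the unit dim

assert torch.equal(selected, decomposed)
```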

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 9 deletions
@@ -28,6 +28,7 @@
 )
 from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
 from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
+from executorch.backends.arm._passes.decompose_select import DecomposeSelectPass
 from executorch.backends.arm._passes.decompose_softmaxes_pass import (
     DecomposeSoftmaxesPass,
 )
@@ -62,7 +63,6 @@
 )
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
-from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_manager import PassManager
 
@@ -72,9 +72,7 @@ class ArmPassManager(PassManager):
     def _transform(self, graph_module: torch.fx.GraphModule):
         return self(graph_module).graph_module
 
-    def transform_to_backend_pipeline(
-        self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
-    ):
+    def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         """Apply passes before transforming program to backend"""
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(DecomposeLinearPass())
@@ -137,11 +135,8 @@ def transform_to_backend_pipeline(
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
-        for spec in compile_spec:
-            if spec.key == "permute_memory_format":
-                memory_format = spec.value.decode()
-                if memory_format == "nhwc":
-                    self.add_pass(AnnotateChannelsLastDimOrder())
+        self.add_pass(DecomposeSelectPass())
+        self.add_pass(AnnotateChannelsLastDimOrder())
 
         return self._transform(exported_program.graph_module)
 
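For orientation, here is a hedged sketch of how the simplified entry point is called after this change. Only the names visible in the diff above come from the commit; the edge_program argument is assumed to be an ExportedProgram produced by the usual export flow:

```python
# Hedged sketch: transform_to_backend_pipeline no longer takes a compile_spec argument,
# since NHWC annotation now always runs instead of being gated on a compile spec entry.
from executorch.backends.arm._passes.arm_pass_manager import ArmPassManager

def run_arm_backend_passes(edge_program):
    # Previously: transform_to_backend_pipeline(exported_program=..., compile_spec=...)
    return ArmPassManager().transform_to_backend_pipeline(exported_program=edge_program)
```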

backends/arm/_passes/decompose_select.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+import torch
+from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class DecomposeSelectPass(ExportPass):
+    """
+    This pass decomposes select into slice + squeeze to ensure that Aten and TOSA outputs has the same rank (input rank -1)
+    """
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        for node in graph_module.graph.nodes:
+
+            if node.op != "call_function":
+                continue
+
+            if node.target in (
+                exir_ops.edge.aten.select.int,
+                exir_ops.edge.aten.select_copy.int,
+            ):
+                slice_op = exir_ops.edge.aten.slice_copy.Tensor
+                squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
+            else:
+                continue
+
+            input_node, dim, index = node.args
+
+            rank = len(input_node.meta["val"].size())
+            dim = dim % rank if dim < 0 else dim
+            index = index % rank if index < 0 else index
+            dim_list = list(range(rank))
+
+            with graph_module.graph.inserting_before(node):
+                slice_node = create_node(
+                    graph_module.graph, slice_op, (input_node, dim, index, index + 1)
+                )
+                squeeze_node = create_node(
+                    graph_module.graph, squeeze_op, (slice_node, dim_list)
+                )
+
+            node.replace_all_uses_with(squeeze_node)
+            graph_module.graph.erase_node(node)
+
+        graph_module.graph.eliminate_dead_code()
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return PassResult(graph_module, True)
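The docstring's rank argument can be checked with ordinary tensors. A small sketch of the shape bookkeeping, independent of the pass machinery (plain PyTorch stands in for the edge-dialect slice_copy/squeeze_copy ops used above):

```python
# Rank bookkeeping behind DecomposeSelectPass: the slice keeps the input rank and leaves
# a size-1 dimension; the squeeze removes it, matching select's output rank of (input rank - 1).
import torch

x = torch.randn(2, 3, 4)      # rank 3

sliced = x[:, 2:3, :]         # like slice_copy(x, dim=1, start=2, end=3): shape (2, 1, 4), rank 3
squeezed = sliced.squeeze(1)  # like squeeze_copy over the unit dim:       shape (2, 4),    rank 2
selected = x.select(1, 2)     # like aten.select.int:                      shape (2, 4),    rank 2

assert squeezed.shape == selected.shape == (2, 4)
```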

backends/arm/arm_backend.py

Lines changed: 2 additions & 26 deletions
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -49,8 +49,6 @@ def __init__(self):
         self.compiler_flags = []
         self.output_format = None
         self.path_for_intermediates = None
-        # TODO MLETORCH-265 Remove permute_nhwc flag
-        self.permute_nhwc = False
         self.quantize_io = False
         self.tosa_version = None
         self.input_order = None
@@ -118,16 +116,6 @@ def dump_intermediate_artifacts_to(
         self.path_for_intermediates = output_path
         return self
 
-    def set_permute_memory_format(
-        self, set_nhwc_permutation: bool = True
-    ) -> "ArmCompileSpecBuilder":
-        """
-        Permute to channel last in compiler and runtime. Compilation and
-        runtime will convert rank 4 inputs to channel last for each sub-graph.
-        """
-        self.permute_nhwc = set_nhwc_permutation
-        return self
-
     def set_quantize_io(self, quantize_io: bool = False) -> "ArmCompileSpecBuilder":
         """
         Quantization of inputs and dequantization of outputs for cases where
@@ -170,11 +158,6 @@ def build(self) -> List[CompileSpec]:
                 CompileSpec("debug_artifact_path", self.path_for_intermediates.encode())
             )
 
-        if self.permute_nhwc:
-            self.compile_spec.append(
-                CompileSpec("permute_memory_format", "nhwc".encode())
-            )
-
         if self.input_order:
             self.compile_spec.append(
                 CompileSpec(
@@ -188,13 +171,6 @@ def build(self) -> List[CompileSpec]:
         return self.compile_spec
 
 
-def is_permute_memory(compile_spec: List[CompileSpec]) -> bool:
-    for spec in compile_spec:
-        if spec.key == "permute_memory_format":
-            return spec.value.decode() == "nhwc"
-    return False
-
-
 def is_tosa(compile_spec: List[CompileSpec]) -> bool:
     for spec in compile_spec:
         if spec.key == "output_format":
@@ -264,7 +240,7 @@ def preprocess(  # noqa: C901
     # const data directly. Path created and data written only in debug builds.
     tosa_graph = ts.TosaSerializer(artifact_path)
     graph_module = ArmPassManager().transform_to_backend_pipeline(
-        exported_program=edge_program, compile_spec=compile_spec
+        exported_program=edge_program
     )
 
     node_visitors = get_node_visitors(edge_program, tosa_spec)
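Because set_permute_memory_format() and the permute_memory_format spec are removed, downstream code only needs to drop that builder call. A hedged before/after sketch, using only builder methods that appear in this diff (target and output-format configuration is assumed and elided):

```python
# Hedged migration sketch: NHWC handling is now unconditional, so the builder no longer
# exposes set_permute_memory_format(); the rest of the builder configuration is elided.
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder

builder = ArmCompileSpecBuilder()            # ...target/output-format setup elided...
# builder.set_permute_memory_format(True)    # removed by this commit
compile_spec = builder.set_quantize_io(True).build()
```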

backends/arm/operators/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@
     op_repeat,
     op_rshift,
     op_rsqrt,
-    op_select,
     op_sigmoid,
     op_slice,
     op_squeeze,

backends/arm/operators/op_select.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

backends/arm/runtime/ArmBackendEthosU.cpp

Lines changed: 3 additions & 27 deletions
@@ -1,5 +1,5 @@
 /*
- * Copyright 2023-2024 Arm Limited and/or its affiliates.
+ * Copyright 2023-2025 Arm Limited and/or its affiliates.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
@@ -76,7 +76,6 @@ namespace arm {
 
 typedef struct {
   FreeableBuffer* processed;
-  bool permuted_io_flag;
 } ExecutionHandle;
 
 extern "C" {
@@ -125,14 +124,6 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
     ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(allocator, ExecutionHandle);
     handle->processed = processed;
 
-    handle->permuted_io_flag = false;
-    for (auto& compile_spec : compile_specs) {
-      if (0 == std::strcmp(compile_spec.key, "permute_memory_format") &&
-          0 == std::memcmp(compile_spec.value.buffer, "nhwc", 4)) {
-        handle->permuted_io_flag = true;
-      }
-    }
-
     // Return the same buffer we were passed - this data will be
     // executed directly
     return handle;
@@ -225,11 +216,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
       // which require permutation.
       bool permuted_input_shape;
       ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
-          i,
-          tensor_in,
-          &handles.inputs->io[i],
-          execution_handle->permuted_io_flag,
-          &permuted_input_shape));
+          i, tensor_in, &handles.inputs->io[i], &permuted_input_shape));
       bool both_char = tensor_in.scalar_type() == ScalarType::Char and
           handles.inputs->io[i].elem_size == 1;
       bool both_int = tensor_in.scalar_type() == ScalarType::Int and
@@ -330,11 +317,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
 
       bool permuted_output_shape;
       ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
-          i,
-          tensor_out,
-          &handles.outputs->io[i],
-          execution_handle->permuted_io_flag,
-          &permuted_output_shape));
+          i, tensor_out, &handles.outputs->io[i], &permuted_output_shape));
       if (tensor_out.scalar_type() == ScalarType::Char and
           permuted_output_shape) {
         EXECUTORCH_PROF_SCOPE(
@@ -395,7 +378,6 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
       int index,
       const executorch::aten::Tensor tensor,
       VelaIO* io,
-      bool permuted_io_flag,
       bool* is_permuted) const {
     bool permuted_shape = false;
 
@@ -409,12 +391,6 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
       if (permuted_shape) {
        ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
      }
-      if (permuted_io_flag != permuted_shape) {
-        ET_LOG(
-            Error,
-            "Permute compile flag and permuted input/output don't agree");
-        return Error::InvalidProgram;
-      }
    }
    *is_permuted = permuted_shape;
    return Error::Ok;
