
Improve data format handling in Arm backend #7588


Merged · 2 commits · Jan 13, 2025
13 changes: 4 additions & 9 deletions backends/arm/_passes/arm_pass_manager.py
@@ -28,6 +28,7 @@
 )
 from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
 from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
+from executorch.backends.arm._passes.decompose_select import DecomposeSelectPass
 from executorch.backends.arm._passes.decompose_softmaxes_pass import (
     DecomposeSoftmaxesPass,
 )
@@ -62,7 +63,6 @@
 )
 from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
 from executorch.exir import ExportedProgram
-from executorch.exir.backend.compile_spec_schema import CompileSpec
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_manager import PassManager
 
@@ -72,9 +72,7 @@ class ArmPassManager(PassManager):
     def _transform(self, graph_module: torch.fx.GraphModule):
         return self(graph_module).graph_module
 
-    def transform_to_backend_pipeline(
-        self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
-    ):
+    def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
         """Apply passes before transforming program to backend"""
         self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(DecomposeLinearPass())
@@ -137,11 +135,8 @@ def transform_to_backend_pipeline(
         self.add_pass(KeepDimsFalseToSqueezePass())
         self.add_pass(Conv1dUnsqueezePass(exported_program))
         self.add_pass(DecomposeSoftmaxesPass())
-        for spec in compile_spec:
-            if spec.key == "permute_memory_format":
-                memory_format = spec.value.decode()
-                if memory_format == "nhwc":
-                    self.add_pass(AnnotateChannelsLastDimOrder())
+        self.add_pass(DecomposeSelectPass())
+        self.add_pass(AnnotateChannelsLastDimOrder())
 
         return self._transform(exported_program.graph_module)

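Note (added for this write-up, not part of the PR): with the compile-spec gate removed, AnnotateChannelsLastDimOrder now always runs, so rank-4 tensors get NHWC ("channels last") dim order unconditionally. A minimal sketch of what that dim order means, using only standard PyTorch APIs:

# Editor's illustration, not PR code: NHWC dim order for a rank-4 tensor.
import torch

x = torch.randn(1, 3, 8, 8)                        # NCHW, contiguous
x_nhwc = x.to(memory_format=torch.channels_last)   # same logical shape, NHWC layout
print(x.dim_order())       # (0, 1, 2, 3)
print(x_nhwc.dim_order())  # (0, 2, 3, 1), the order AnnotateChannelsLastDimOrder targets
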
56 changes: 56 additions & 0 deletions backends/arm/_passes/decompose_select.py
@@ -0,0 +1,56 @@
# Copyright 2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.backends.arm._passes.arm_pass_utils import create_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class DecomposeSelectPass(ExportPass):
    """
    This pass decomposes select into slice + squeeze to ensure that the Aten and TOSA outputs have the same rank (input rank - 1).
    """

    def call(self, graph_module: torch.fx.GraphModule):
        for node in graph_module.graph.nodes:

            if node.op != "call_function":
                continue

            if node.target in (
                exir_ops.edge.aten.select.int,
                exir_ops.edge.aten.select_copy.int,
            ):
                slice_op = exir_ops.edge.aten.slice_copy.Tensor
                squeeze_op = exir_ops.edge.aten.squeeze_copy.dims
            else:
                continue

            input_node, dim, index = node.args

            rank = len(input_node.meta["val"].size())
            dim = dim % rank if dim < 0 else dim
            index = index % rank if index < 0 else index
            dim_list = list(range(rank))

            with graph_module.graph.inserting_before(node):
                slice_node = create_node(
                    graph_module.graph, slice_op, (input_node, dim, index, index + 1)
                )
                squeeze_node = create_node(
                    graph_module.graph, squeeze_op, (slice_node, dim_list)
                )

            node.replace_all_uses_with(squeeze_node)
            graph_module.graph.erase_node(node)

        graph_module.graph.eliminate_dead_code()
        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
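Note (added for this write-up, not part of the PR): the rewrite above relies on the identity select(x, dim, i) == squeeze(slice(x, dim, i, i + 1), dim). A minimal sketch with public torch ops, using narrow as the eager-mode stand-in for slice_copy with start=i, end=i+1:

# Editor's illustration, not PR code.
import torch

x = torch.randn(2, 3, 4)
dim, index = 1, 2

selected = torch.select(x, dim, index)                    # shape (2, 4): rank drops by one
decomposed = torch.narrow(x, dim, index, 1).squeeze(dim)  # one-element slice, then squeeze

assert torch.equal(selected, decomposed)
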
28 changes: 2 additions & 26 deletions backends/arm/arm_backend.py
@@ -1,4 +1,4 @@
-# Copyright 2023-2024 Arm Limited and/or its affiliates.
+# Copyright 2023-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -49,8 +49,6 @@ def __init__(self):
         self.compiler_flags = []
         self.output_format = None
         self.path_for_intermediates = None
-        # TODO MLETORCH-265 Remove permute_nhwc flag
-        self.permute_nhwc = False
         self.quantize_io = False
         self.tosa_version = None
         self.input_order = None
@@ -118,16 +116,6 @@ def dump_intermediate_artifacts_to(
         self.path_for_intermediates = output_path
         return self
 
-    def set_permute_memory_format(
-        self, set_nhwc_permutation: bool = True
-    ) -> "ArmCompileSpecBuilder":
-        """
-        Permute to channel last in compiler and runtime. Compilation and
-        runtime will convert rank 4 inputs to channel last for each sub-graph.
-        """
-        self.permute_nhwc = set_nhwc_permutation
-        return self
-
     def set_quantize_io(self, quantize_io: bool = False) -> "ArmCompileSpecBuilder":
         """
         Quantization of inputs and dequantization of outputs for cases where
@@ -170,11 +158,6 @@ def build(self) -> List[CompileSpec]:
CompileSpec("debug_artifact_path", self.path_for_intermediates.encode())
)

if self.permute_nhwc:
self.compile_spec.append(
CompileSpec("permute_memory_format", "nhwc".encode())
)

if self.input_order:
self.compile_spec.append(
CompileSpec(
@@ -188,13 +171,6 @@ def build(self) -> List[CompileSpec]:
         return self.compile_spec
 
 
-def is_permute_memory(compile_spec: List[CompileSpec]) -> bool:
-    for spec in compile_spec:
-        if spec.key == "permute_memory_format":
-            return spec.value.decode() == "nhwc"
-    return False
-
-
 def is_tosa(compile_spec: List[CompileSpec]) -> bool:
     for spec in compile_spec:
         if spec.key == "output_format":
@@ -264,7 +240,7 @@ def preprocess(  # noqa: C901
     # const data directly. Path created and data written only in debug builds.
     tosa_graph = ts.TosaSerializer(artifact_path)
     graph_module = ArmPassManager().transform_to_backend_pipeline(
-        exported_program=edge_program, compile_spec=compile_spec
+        exported_program=edge_program
     )
 
     node_visitors = get_node_visitors(edge_program, tosa_spec)
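Migration note (added for this write-up, not part of the PR): callers that previously enabled NHWC I/O via the builder can simply drop that call, since dim-order handling is now unconditional. A hedged sketch of the builder usage after this change; ethosu_compile_spec("ethos-u55-128") is an assumed entry point not shown in this diff, and only set_quantize_io, dump_intermediate_artifacts_to, and build appear above:

# Editor's sketch, not PR code; the entry-point name and its argument are assumptions.
from executorch.backends.arm.arm_backend import ArmCompileSpecBuilder

spec = (
    ArmCompileSpecBuilder()
    .ethosu_compile_spec("ethos-u55-128")  # assumed; pick your target config
    # .set_permute_memory_format(True)     # removed by this PR; no replacement needed
    .set_quantize_io(True)
    .build()
)
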
1 change: 0 additions & 1 deletion backends/arm/operators/__init__.py
@@ -30,7 +30,6 @@
     op_repeat,
     op_rshift,
     op_rsqrt,
-    op_select,
     op_sigmoid,
     op_slice,
     op_squeeze,
68 changes: 0 additions & 68 deletions backends/arm/operators/op_select.py

This file was deleted.

30 changes: 3 additions & 27 deletions backends/arm/runtime/ArmBackendEthosU.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright 2023-2024 Arm Limited and/or its affiliates.
+ * Copyright 2023-2025 Arm Limited and/or its affiliates.
  *
  * This source code is licensed under the BSD-style license found in the
  * LICENSE file in the root directory of this source tree.
@@ -76,7 +76,6 @@ namespace arm {

 typedef struct {
   FreeableBuffer* processed;
-  bool permuted_io_flag;
 } ExecutionHandle;
 
 extern "C" {
@@ -125,14 +124,6 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
     ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR(allocator, ExecutionHandle);
     handle->processed = processed;
 
-    handle->permuted_io_flag = false;
-    for (auto& compile_spec : compile_specs) {
-      if (0 == std::strcmp(compile_spec.key, "permute_memory_format") &&
-          0 == std::memcmp(compile_spec.value.buffer, "nhwc", 4)) {
-        handle->permuted_io_flag = true;
-      }
-    }
-
     // Return the same buffer we were passed - this data will be
     // executed directly
     return handle;
@@ -225,11 +216,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
       // which require permutation.
       bool permuted_input_shape;
       ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
-          i,
-          tensor_in,
-          &handles.inputs->io[i],
-          execution_handle->permuted_io_flag,
-          &permuted_input_shape));
+          i, tensor_in, &handles.inputs->io[i], &permuted_input_shape));
       bool both_char = tensor_in.scalar_type() == ScalarType::Char and
           handles.inputs->io[i].elem_size == 1;
       bool both_int = tensor_in.scalar_type() == ScalarType::Int and
@@ -330,11 +317,7 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {

       bool permuted_output_shape;
       ET_CHECK_OK_OR_RETURN_ERROR(check_requires_permute(
-          i,
-          tensor_out,
-          &handles.outputs->io[i],
-          execution_handle->permuted_io_flag,
-          &permuted_output_shape));
+          i, tensor_out, &handles.outputs->io[i], &permuted_output_shape));
       if (tensor_out.scalar_type() == ScalarType::Char and
           permuted_output_shape) {
         EXECUTORCH_PROF_SCOPE(
@@ -395,7 +378,6 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
       int index,
       const executorch::aten::Tensor tensor,
       VelaIO* io,
-      bool permuted_io_flag,
       bool* is_permuted) const {
     bool permuted_shape = false;
 
@@ -409,12 +391,6 @@ class ArmBackend final : public ::executorch::runtime::BackendInterface {
       if (permuted_shape) {
         ET_LOG(Debug, "Tensor input/output %d will be permuted", index);
       }
-      if (permuted_io_flag != permuted_shape) {
-        ET_LOG(
-            Error,
-            "Permute compile flag and permuted input/output don't agree");
-        return Error::InvalidProgram;
-      }
     }
     *is_permuted = permuted_shape;
     return Error::Ok;