Skip to content

Make ArmPassManager aware of TosaSpecification #7668

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 63 additions & 52 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

# pyre-unsafe

import torch
from executorch.backends.arm._passes.annotate_channels_last_dim_order_pass import (
AnnotateChannelsLastDimOrder,
)
Expand Down Expand Up @@ -47,7 +46,7 @@
)
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
ConvertMeanDimToAveragePool,
ConvertMeanDimToAveragePoolPass,
)
from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
Expand All @@ -61,86 +60,98 @@
from executorch.backends.arm._passes.unsqueeze_scalar_placeholders_pass import (
UnsqueezeScalarPlaceholdersPass,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.xnnpack._passes.remove_getitem_op import RemoveGetItemPass
from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_manager import PassManager
from torch.fx import GraphModule


class ArmPassManager(PassManager):

def _transform(self, graph_module: torch.fx.GraphModule):
def __init__(self, tosa_spec: TosaSpecification) -> None:
    """Create a pass manager bound to a TOSA specification.

    Args:
        tosa_spec: Specification stored on the instance as ``self.tosa_spec``.
    """
    # The spec is recorded before delegating to the base-class initializer.
    self.tosa_spec = tosa_spec
    super().__init__()

def _transform(self, graph_module: GraphModule):
    """Invoke this pass manager on ``graph_module``.

    Returns:
        The ``graph_module`` attribute of the result produced by calling
        the manager.
    """
    pass_result = self(graph_module)
    return pass_result.graph_module

def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
"""Apply passes before transforming program to backend"""
def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
    """Register the TOSA-0.80.0+BI pass sequence and apply it.

    Args:
        exported_program: Program whose ``graph_module`` is transformed; it
            is also handed to the passes that take it as a constructor
            argument.

    Returns:
        The result of ``self._transform`` on ``exported_program.graph_module``.
    """
    # Passes are listed in the exact order in which they are registered.
    ordered_passes = (
        FuseQuantizedActivationPass(),
        RemoveGetItemPass(),
        ConvertSplitToSlicePass(),
        ConvertMmToBmmPass(),
        DecomposeLinearPass(),
        ConvertMeanDimToAveragePoolPass(),
        AnnotateDecomposedMatmulPass(),
        QuantizeFullArgument(),
        FoldAndAnnotateQParamsPass(),
        RetraceFoldedDtypesPass(),
        InsertTableOpsPass(exported_program),
        RemoveClonePass(),
        SizeAdjustConv2DPass(),
        ConvertExpandCopyToRepeatPass(),
        UnsqueezeBeforeRepeatPass(),
        UnsqueezeScalarPlaceholdersPass(exported_program),
        CastInt64ToInt32Pass(exported_program),
        MatchArgRanksPass(exported_program),
        KeepDimsFalseToSqueezePass(),
        Conv1dUnsqueezePass(exported_program),
        DecomposeSelectPass(),
        AnnotateChannelsLastDimOrder(),
    )
    for arm_pass in ordered_passes:
        self.add_pass(arm_pass)

    return self._transform(exported_program.graph_module)

def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModule:

self.add_pass(FuseQuantizedActivationPass())
self.add_pass(RemoveGetItemPass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(ConvertMmToBmmPass())
self.add_pass(DecomposeLinearPass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeVarPass())
self.add_pass(ConvertMeanDimToAveragePool())
self.add_pass(DecomposeMeanDimPass())
self.add_pass(ConvertSplitToSlicePass())
self.add_pass(ConvertMmToBmmPass())
# TODO MLETORCH-558
self.add_pass(ConvertMeanDimToAveragePoolPass())
self.add_pass(DecomposeDivPass())
self.add_pass(DecomposeSoftmaxesPass())

self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeFullArgument())
self.add_pass(
FoldAndAnnotateQParamsPass(
[
exir_ops.edge.aten.minimum.default,
exir_ops.edge.aten.maximum.default,
exir_ops.edge.aten.add.Tensor,
exir_ops.edge.aten.avg_pool2d.default,
exir_ops.edge.aten.bmm.default,
exir_ops.edge.aten.cat.default,
exir_ops.edge.aten.convolution.default,
exir_ops.edge.aten.clone.default,
exir_ops.edge.aten.exp.default,
exir_ops.edge.aten.expand_copy.default,
exir_ops.edge.aten.full.default,
exir_ops.edge.aten.hardtanh.default,
exir_ops.edge.aten.log.default,
exir_ops.edge.aten.max_pool2d.default,
exir_ops.edge.aten.mul.Tensor,
exir_ops.edge.aten.permute_copy.default,
exir_ops.edge.aten.reciprocal.default,
exir_ops.edge.aten.relu.default,
exir_ops.edge.aten.repeat.default,
exir_ops.edge.aten.rsqrt.default,
exir_ops.edge.aten.select_copy.int,
exir_ops.edge.aten.sigmoid.default,
exir_ops.edge.aten.slice_copy.Tensor,
exir_ops.edge.aten.squeeze_copy.dims,
exir_ops.edge.aten.sub.Tensor,
exir_ops.edge.aten.sum.dim_IntList,
exir_ops.edge.aten.tanh.default,
exir_ops.edge.aten.unsqueeze_copy.default,
exir_ops.edge.aten.upsample_nearest2d.vec,
exir_ops.edge.aten.view_copy.default,
]
)
)
self.add_pass(FoldAndAnnotateQParamsPass())
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(InsertTableOpsPass(exported_program))

self.add_pass(RemoveClonePass())
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(ConvertExpandCopyToRepeatPass())
self.add_pass(UnsqueezeBeforeRepeatPass())
self.add_pass(CastInt64ToInt32Pass(exported_program))
self.add_pass(UnsqueezeScalarPlaceholdersPass(exported_program))
self.add_pass(SizeAdjustConv2DPass())
self.add_pass(RemoveClonePass())
self.add_pass(CastInt64ToInt32Pass(exported_program))
self.add_pass(MatchArgRanksPass(exported_program))
self.add_pass(DecomposeDivPass())
self.add_pass(KeepDimsFalseToSqueezePass())
self.add_pass(Conv1dUnsqueezePass(exported_program))
self.add_pass(DecomposeSoftmaxesPass())
self.add_pass(DecomposeSelectPass())

self.add_pass(AnnotateChannelsLastDimOrder())

return self._transform(exported_program.graph_module)

def transform_for_annotation_pipeline(self, graph_module: torch.fx.GraphModule):
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
    """Apply the backend pass pipeline that matches ``self.tosa_spec``.

    Args:
        exported_program: Program forwarded to the selected pipeline.

    Returns:
        Whatever the selected ``_tosa_080_*_pipeline`` method returns.

    Raises:
        NotImplementedError: If ``self.tosa_spec`` equals neither the
            "TOSA-0.80.0+BI" nor the "TOSA-0.80.0+MI" specification.
    """
    bi_spec = TosaSpecification.create_from_string("TOSA-0.80.0+BI")
    if self.tosa_spec == bi_spec:
        return self._tosa_080_BI_pipeline(exported_program)

    # The MI spec is only built when the BI comparison did not match.
    mi_spec = TosaSpecification.create_from_string("TOSA-0.80.0+MI")
    if self.tosa_spec == mi_spec:
        return self._tosa_080_MI_pipeline(exported_program)

    raise NotImplementedError(
        f"No pass pipeline implemented for {self.tosa_spec=}"
    )

def transform_for_annotation_pipeline(self, graph_module: GraphModule):
self.add_pass(ScalarsToAttributePass())
self.add_pass(DecomposeLayerNormPass())
self.add_pass(DecomposeVarPass())
Expand Down
6 changes: 5 additions & 1 deletion backends/arm/_passes/cast_int64_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -17,6 +17,10 @@


class CastInt64ToInt32Pass(ExportPass):
"""
Cast int64 buffers to int32 if the int64 data is in int32 range.
"""

def __init__(self, exported_program: torch.export.ExportedProgram):
super(CastInt64ToInt32Pass, self).__init__()
self.exported_program = exported_program
Expand Down
13 changes: 6 additions & 7 deletions backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy

from typing import cast, Dict, Iterable, Set, Tuple
from typing import cast, Dict, Set, Tuple

from executorch.backends.arm.tosa_quant_utils import QuantArgs

Expand Down Expand Up @@ -55,7 +55,7 @@ def get_output_qparams(node: Node) -> dict[int, QuantArgs]:
class FoldAndAnnotateQParamsPass(ExportPass):
"""
A pass that walks the graph and removes any DQ and Q nodes before and after the target
node in the supplied list of operators.
node.
The quantization parameters from the DQ/Q nodes are stored as meta values to be
accessible for later lowering and serialization passes.
The assumption is that the quantization annotation adds DQ nodes for all tensor
Expand All @@ -82,9 +82,8 @@ class FoldAndAnnotateQParamsPass(ExportPass):

"""

def __init__(self, targeted_ops: Iterable[EdgeOpOverload]) -> None:
def __init__(self) -> None:
super().__init__()
self.targeted_ops = targeted_ops

def fold_and_annotate_arg(
self, graph_module: GraphModule, node: Node, arg_list: list[Node], i: int
Expand Down Expand Up @@ -131,7 +130,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
# Loop over the graph nodes and process every call_function node.
for n in graph_module.graph.nodes:
n = cast(Node, n)
if n.op != "call_function" or n.target not in self.targeted_ops:
if n.op != "call_function":
continue

# Make sure we haven't already set qparams meta information on the node
Expand Down Expand Up @@ -180,7 +179,7 @@ class QuantizeFullArgument(ExportPass):

def call(self, graph_module: GraphModule) -> PassResult:
modified = False
# Loop over the graph nodes and find any node in the 'targeted_ops' list.
# Loop over the graph nodes and find full.default nodes.
for n in graph_module.graph.nodes:
n = cast(Node, n)
if n.target != exir_ops.edge.aten.full.default:
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/meandim_to_averagepool_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -16,7 +16,7 @@
Argument = Any


class ConvertMeanDimToAveragePool(ExportPass):
class ConvertMeanDimToAveragePoolPass(ExportPass):
"""
Replace a mean operation with dim = [-1, -2] and keep_dim = True with an average pool operation.
"""
Expand Down
3 changes: 2 additions & 1 deletion backends/arm/_passes/remove_clone_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -11,6 +11,7 @@


class RemoveClonePass(ExportPass):
"""Remove all clones from graph_module"""

def call_operator(self, op, args, kwargs, meta):
if op != exir_ops.edge.aten.clone.default:
Expand Down
23 changes: 14 additions & 9 deletions backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(self):
self.output_format = None
self.path_for_intermediates = None
self.quantize_io = False
self.tosa_version = None
self.tosa_spec = None
self.input_order = None

def ethosu_compile_spec(
Expand Down Expand Up @@ -92,19 +92,26 @@ def ethosu_compile_spec(
if "u55" in config:
# Add the Ethos-U55 extension marker
base_tosa_version += "+u55"
self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)
self.tosa_spec = TosaSpecification.create_from_string(base_tosa_version)

return self

def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder":
def tosa_compile_spec(
self, tosa_spec: str | TosaSpecification
) -> "ArmCompileSpecBuilder":
"""
Generate compile spec for TOSA flatbuffer output
"""
assert (
self.output_format is None
), f"Output format already set: {self.output_format}"
self.output_format = "tosa"
self.tosa_version = TosaSpecification.create_from_string(tosa_version)
if isinstance(tosa_spec, TosaSpecification):
self.tosa_spec = tosa_spec
elif isinstance(tosa_spec, str):
self.tosa_spec = TosaSpecification.create_from_string(tosa_spec)
else:
raise RuntimeError(f"Invalid type for {tosa_spec}!")
return self

def dump_intermediate_artifacts_to(
Expand Down Expand Up @@ -138,12 +145,10 @@ def build(self) -> List[CompileSpec]:
"""
Generate a list of compile spec objects from the builder
"""
assert self.tosa_version
assert self.tosa_spec

# Always supply a TOSA version
self.compile_spec = [
CompileSpec("tosa_version", str(self.tosa_version).encode())
]
self.compile_spec = [CompileSpec("tosa_version", str(self.tosa_spec).encode())]

if self.output_format == "vela":
self.compile_spec += [
Expand Down Expand Up @@ -253,7 +258,7 @@ def preprocess( # noqa: C901
# Converted output for this subgraph, serializer needs path early as it emits
# const data directly. Path created and data written only in debug builds.
tosa_graph = ts.TosaSerializer(artifact_path)
graph_module = ArmPassManager().transform_to_backend_pipeline(
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline(
exported_program=edge_program
)

Expand Down
11 changes: 8 additions & 3 deletions backends/arm/quantizer/arm_quantizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -24,6 +24,7 @@
from executorch.backends.arm.quantizer.quantization_annotator import annotate_graph

from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from executorch.backends.arm.tosa_specification import TosaSpecification
from torch.ao.quantization.fake_quantize import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
Expand Down Expand Up @@ -205,8 +206,10 @@ def not_module_type_or_name_filter(n: Node) -> bool:


class ArmQuantizer(Quantizer):
def __init__(self) -> None:

def __init__(self, tosa_spec: TosaSpecification) -> None:
super().__init__()
self.tosa_spec = tosa_spec
self.global_config: Optional[QuantizationConfig] = None
self.io_config: Optional[QuantizationConfig] = None
self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {}
Expand Down Expand Up @@ -250,7 +253,9 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule:
Currently transforms scalar values to tensor attributes.
"""

return ArmPassManager().transform_for_annotation_pipeline(graph_module=model)
return ArmPassManager(self.tosa_spec).transform_for_annotation_pipeline(
graph_module=model
)

def annotate(self, model: GraphModule) -> GraphModule:
"""Performs the quantization annotation on the graph.
Expand Down
Loading
Loading