Arm backend: Add support for TOSA 1.0 serializer #10135

Merged 1 commit on Apr 14, 2025
10 changes: 10 additions & 0 deletions backends/arm/_passes/arm_pass_manager.py
@@ -166,12 +166,22 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul

return self._transform(exported_program.graph_module)

def _tosa_1_0_int_quantized_pipeline(self, exported_program: ExportedProgram):
return self._tosa_080_BI_pipeline(exported_program)

def _tosa_1_0_fp_pipeline(self, exported_program: ExportedProgram):
return self._tosa_080_MI_pipeline(exported_program)

def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
"""Apply passes before transforming program to backend"""
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"):
return self._tosa_080_BI_pipeline(exported_program)
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"):
return self._tosa_080_MI_pipeline(exported_program)
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
return self._tosa_1_0_fp_pipeline(exported_program)
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
return self._tosa_1_0_int_quantized_pipeline(exported_program)
else:
raise NotImplementedError(
f"No pass pipeline implemented for {self.tosa_spec=}"
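
Note (illustration, not part of the diff): the two new TOSA 1.0 pipelines above simply delegate to the existing 0.80 pipelines, so transform_to_backend_pipeline reduces to the mapping sketched below. pipeline_name_for is a hypothetical helper written against the spec strings shown in this hunk, not code from this PR.

from executorch.backends.arm.tosa_specification import TosaSpecification


def pipeline_name_for(spec: TosaSpecification) -> str:
    # Mirrors transform_to_backend_pipeline(): exact spec equality picks the pipeline.
    if spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"):
        return "_tosa_080_BI_pipeline"
    if spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"):
        return "_tosa_080_MI_pipeline"
    if spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
        return "_tosa_1_0_int_quantized_pipeline"  # delegates to the 0.80 BI pipeline
    if spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
        return "_tosa_1_0_fp_pipeline"  # delegates to the 0.80 MI pipeline
    raise NotImplementedError(f"No pass pipeline implemented for {spec=}")


print(pipeline_name_for(TosaSpecification.create_from_string("TOSA-1.0+INT")))
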
2 changes: 2 additions & 0 deletions backends/arm/operator_support/convolution_support.py
@@ -22,6 +22,8 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck):
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
1 change: 1 addition & 0 deletions backends/arm/operator_support/minmax_support.py
@@ -22,6 +22,7 @@ class MinMaxSupported(SupportedTOSAOperatorCheck):
# TODO : "MLETORCH-718 : Quantization of indices in arm_quantizer"
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
4 changes: 4 additions & 0 deletions backends/arm/operator_support/pool_2d_support.py
@@ -41,6 +41,8 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
@@ -94,6 +96,8 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
2 changes: 2 additions & 0 deletions backends/arm/operator_support/reduce_sum_support.py
@@ -21,6 +21,8 @@ class SumSupported(SupportedTOSAOperatorCheck):
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
2 changes: 2 additions & 0 deletions backends/arm/operator_support/right_shift_support.py
@@ -29,6 +29,8 @@ class RightShiftSupported(SupportedTOSAOperatorCheck):
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
2 changes: 2 additions & 0 deletions backends/arm/operator_support/slice_copy_support.py
@@ -25,6 +25,8 @@ class SliceCopySupported(SupportedTOSAOperatorCheck):
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: # type: ignore[override, misc]
2 changes: 2 additions & 0 deletions backends/arm/operator_support/to_copy_support.py
@@ -30,6 +30,8 @@ class ToCopySupported(SupportedTOSAOperatorCheck):
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]
2 changes: 2 additions & 0 deletions backends/arm/operator_support/tosa_supported_operators.py
@@ -66,6 +66,8 @@ def is_node_tosa_supported(
_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = {
TosaSpecification.create_from_string("TOSA-0.80+BI"): [],
TosaSpecification.create_from_string("TOSA-0.80+MI"): [],
TosaSpecification.create_from_string("TOSA-1.0+INT"): [],
TosaSpecification.create_from_string("TOSA-1.0+FP"): [],
}


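
Note (illustration, not part of the diff): every operator-support check touched in this PR follows the same pattern — extend tosa_specs with the two TOSA 1.0 entries so the existing check also applies when targeting the new serializer. A minimal sketch, assuming the register_tosa_support_check decorator and targets attribute used by the existing checks in this directory; DummyOpSupported and its target op are hypothetical.

import torch.fx as fx

from executorch.backends.arm.operator_support.tosa_supported_operators import (
    register_tosa_support_check,
    SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.exir.dialects._ops import ops as exir_ops


@register_tosa_support_check
class DummyOpSupported(SupportedTOSAOperatorCheck):
    # Hypothetical target; real checks list the edge ops they validate.
    targets = [exir_ops.edge.aten.relu.default]

    # Listing all four specs registers the check for both serializer versions.
    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-0.80+BI"),
        TosaSpecification.create_from_string("TOSA-0.80+MI"),
        TosaSpecification.create_from_string("TOSA-1.0+INT"),
        TosaSpecification.create_from_string("TOSA-1.0+FP"),
    ]

    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:
        return True  # a real check inspects dtypes, shapes, arguments, etc.
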
16 changes: 12 additions & 4 deletions backends/arm/operators/node_visitor.py
@@ -5,11 +5,10 @@

# pyre-unsafe

from typing import Dict, List
from typing import Any, Dict, List

import torch

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from torch.export import ExportedProgram
@@ -25,19 +24,26 @@ class NodeVisitor:
# a specific TOSA version.
# When all node_visitors has been refactored to target a specific
# version, this list should be removed.
tosa_specs = [
tosa_specs_1_00 = [
TosaSpecification.create_from_string("TOSA-1.0+INT"),
TosaSpecification.create_from_string("TOSA-1.0+FP"),
]

tosa_specs_0_80 = [
TosaSpecification.create_from_string("TOSA-0.80+BI"),
TosaSpecification.create_from_string("TOSA-0.80+MI"),
]

tosa_specs = tosa_specs_0_80 + tosa_specs_1_00

def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
self._exported_program = exported_program
self.tosa_spec = tosa_spec

def define_node(
self,
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:
@@ -48,6 +54,8 @@ def define_node(
_node_visitor_dicts: Dict[TosaSpecification, Dict] = {
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
TosaSpecification.create_from_string("TOSA-1.0+INT"): {},
TosaSpecification.create_from_string("TOSA-1.0+FP"): {},
}


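
Note (illustration, not part of the diff): with tosa_graph typed as Any and the spec list split into tosa_specs_0_80 and tosa_specs_1_00, a visitor can target only the 1.0 specs and pick the matching serializer bindings when it emits its node. A sketch assuming the existing register_node_visitor decorator and target attribute from the visitor infrastructure; ExampleVisitor is hypothetical and emits nothing.

from typing import Any, List

import torch

from executorch.backends.arm.operators.node_visitor import NodeVisitor, register_node_visitor
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import Tosa_0_80


@register_node_visitor
class ExampleVisitor(NodeVisitor):
    target = "aten.relu.default"  # hypothetical target
    # Serve only the new specs; 0.80-only visitors can keep using tosa_specs_0_80.
    tosa_specs = NodeVisitor.tosa_specs_1_00

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: Any,  # typed Any because the serializer class differs per spec
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        # Pick the serializer bindings that match the spec this visitor was built for.
        if isinstance(self.tosa_spec, Tosa_0_80):
            import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
        else:
            import serializer.tosa_serializer as ts
        # ... emit the operator into tosa_graph via `ts` here ...
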
31 changes: 21 additions & 10 deletions backends/arm/process_node.py
@@ -5,15 +5,18 @@
#

# pyre-unsafe
from typing import cast, Dict
from typing import Any, cast, Dict

import numpy as np
import torch
import torch.fx
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import NodeVisitor
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_specification import (
Tosa_0_80,
Tosa_1_00,
TosaSpecification,
)
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
from torch._export.utils import (
get_buffer,
@@ -28,7 +31,7 @@

def process_call_function(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
node_visitors: Dict[str, NodeVisitor],
tosa_spec: TosaSpecification,
):
@@ -63,7 +66,7 @@ def process_call_function(

def process_inputs(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
tosa_spec: TosaSpecification,
):
"""Serialize an input node"""
@@ -81,6 +84,14 @@ def process_inputs(
f"Failed processing input placeholder: {node.name}. "
"Is the original torch function supported?"
) from e

if isinstance(tosa_spec, Tosa_0_80):
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
elif isinstance(tosa_spec, Tosa_1_00):
import serializer.tosa_serializer as ts
else:
raise ValueError(f"Unsupported TOSA spec: {tosa_spec}")

input_shape = tosa_arg.shape
input_dim_order = tosa_arg.dim_order
tensor = ts.TosaSerializerTensor(
@@ -95,7 +106,7 @@

def process_inputs_to_parameters(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
edge_program: ExportedProgram,
tosa_spec: TosaSpecification,
):
@@ -124,7 +135,7 @@ def process_inputs_to_parameters(

def process_inputs_to_buffers(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
edge_program: ExportedProgram,
):
"""Serialize quantized weights"""
@@ -152,7 +163,7 @@ def process_inputs_to_buffers(

def process_inputs_to_lifted_tensor_constants(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
edge_program: ExportedProgram,
):
try:
@@ -172,7 +183,7 @@ def process_inputs_to_lifted_tensor_constants(

def process_placeholder(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
edge_program: ExportedProgram,
tosa_spec: TosaSpecification,
):
@@ -198,7 +209,7 @@ def process_placeholder(

def process_output(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
):
for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
tosa_graph.addOutputTensor(
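
Note (illustration, not part of the diff): the spec-based import in process_inputs above is the pattern the backend now uses to choose serializer bindings at runtime. If more call sites need it, it could be factored into a small helper along these lines (a sketch, not code from this PR):

from executorch.backends.arm.tosa_specification import Tosa_0_80, Tosa_1_00, TosaSpecification


def import_serializer(tosa_spec: TosaSpecification):
    """Return the tosa_serializer module matching the given TOSA spec."""
    if isinstance(tosa_spec, Tosa_0_80):
        import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
    elif isinstance(tosa_spec, Tosa_1_00):
        import serializer.tosa_serializer as ts
    else:
        raise ValueError(f"Unsupported TOSA spec: {tosa_spec}")
    return ts
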
5 changes: 4 additions & 1 deletion backends/arm/scripts/install_reference_model.sh
@@ -13,7 +13,7 @@ tosa_reference_model_url="https://git.gitlab.arm.com/tosa/tosa-reference-model.g
tosa_reference_model_0_80_branch="v0.80"
tosa_reference_model_0_80_rev="70ed0b40fa831387e36abdb4f7fb9670a3464f5a"
tosa_serialization_lib_0_80_rev="v0.80.1"
tosa_reference_model_1_0_rev="v1.0"
tosa_reference_model_1_0_rev="f9b4ceb850964be03a39e965ad7a0546dc6c57ae"

script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)

@@ -47,6 +47,9 @@ function setup_tosa_reference_model() {
# Vela's flatbuffer requirement is expected to loosen, then remove this. MLETORCH-565
CMAKE_POLICY_VERSION_MINIMUM=3.5 pip install . --no-dependencies flatbuffers
popd

# Install the 1.0 branch from upstream
CMAKE_POLICY_VERSION_MINIMUM=3.5 BUILD_PYBIND=1 pip install "tosa-tools@git+${tosa_reference_model_url}@${tosa_reference_model_1_0_rev}" ml_dtypes==0.5.1 --no-dependencies flatbuffers
}

setup_tosa_reference_model $1
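
Note (illustration, not part of the diff): after this script runs, the 0.80 tooling stays importable under the vendored tosa_tools.v0_80 namespace while the 1.0 serializer and reference model install under their upstream module names. A quick Python sanity check of that layout (an assumption based on the imports used elsewhere in this PR, not part of the script):

import importlib

for module in (
    "tosa_tools.v0_80.serializer.tosa_serializer",  # 0.80, vendored under tosa_tools
    "tosa_tools.v0_80.tosa_reference_model",
    "serializer.tosa_serializer",                   # 1.0, installed from upstream
    "tosa_reference_model",
):
    importlib.import_module(module)
    print(f"ok: {module}")
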
64 changes: 37 additions & 27 deletions backends/arm/test/runner_utils.py
@@ -13,27 +13,24 @@

from pathlib import Path

from typing import cast, Dict, List, Literal, Optional, Tuple
from typing import Any, cast, Dict, List, Literal, Optional, Tuple

import numpy as np
import torch

try:
import tosa_tools.v0_80.tosa_reference_model as tosa_reference_model
except ImportError:
tosa_reference_model = None
from executorch.backends.arm.arm_backend import get_tosa_spec, is_tosa

from executorch.backends.arm.test.conftest import is_option_enabled
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_specification import (
Tosa_0_80,
Tosa_1_00,
TosaSpecification,
)
from executorch.exir import ExecutorchProgramManager, ExportedProgram
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.lowered_backend_module import LoweredBackendModule
from packaging.version import Version
from torch.fx.node import Node

from torch.overrides import TorchFunctionMode
from tosa_tools.v0_80.tosa import TosaGraph

logger = logging.getLogger(__name__)

@@ -566,33 +563,46 @@ def arm_executor_runner_exists(target_board):


def run_tosa_graph(
graph: TosaGraph,
graph: Any,
tosa_version: TosaSpecification,
inputs: list[torch.Tensor],
) -> list[torch.Tensor]:
"""Runs the TOSA reference model with inputs and returns the result."""
inputs_np = [input.numpy() for input in inputs]
transpose_data_format(inputs_np, to="NHWC")

tosa_release = tosa_version.version

if tosa_release > Version("0.80"):
logger.warning("The reference model is only tested for TOSA v0.80")

# tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training.
tosa_profile = 1 if tosa_version.support_float() else 0
debug_mode = "ALL" if logger.level <= logging.DEBUG else None
outputs_np, status = tosa_reference_model.run(
graph,
inputs_np,
verbosity=_tosa_refmodel_loglevel(logger.level),
tosa_profile=tosa_profile,
initialize_variable_tensor_from_numpy=1, # True
debug_mode=debug_mode,
)
if isinstance(tosa_version, Tosa_0_80):
import tosa_tools.v0_80.tosa_reference_model as reference_model

# tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training.
tosa_profile = 1 if tosa_version.support_float() else 0
debug_mode = "ALL" if logger.level <= logging.DEBUG else None
outputs_np, status = reference_model.run(
graph,
inputs_np,
verbosity=_tosa_refmodel_loglevel(logger.level),
tosa_profile=tosa_profile,
initialize_variable_tensor_from_numpy=True,
debug_mode=debug_mode,
)
elif isinstance(tosa_version, Tosa_1_00):
import tosa_reference_model as reference_model

debug_mode = "ALL" if logger.level <= logging.DEBUG else None
outputs_np, status = reference_model.run(
graph,
inputs_np,
verbosity=_tosa_refmodel_loglevel(logger.level),
initialize_variable_tensor_from_numpy=True,
debug_mode=debug_mode,
)
else:
raise ValueError(
f"Unknown TOSA specification: {tosa_version}. No refererence model available to run for this specification version"
)

assert (
status == tosa_reference_model.GraphStatus.TOSA_VALID
status == reference_model.GraphStatus.TOSA_VALID
), "Non-valid TOSA given to reference model."

transpose_data_format(outputs_np, to="NCHW")
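
Note (illustration, not part of the diff): a usage sketch of the updated run_tosa_graph, assuming tosa_graph is the serialized TOSA artifact produced by the Arm backend for a matching spec (how it is produced is not shown here).

import torch

from executorch.backends.arm.test.runner_utils import run_tosa_graph
from executorch.backends.arm.tosa_specification import TosaSpecification


def run_fp_reference(tosa_graph, x: torch.Tensor) -> list[torch.Tensor]:
    """Run a serialized TOSA graph on the reference model for the 1.0 FP profile."""
    spec = TosaSpecification.create_from_string("TOSA-1.0+FP")
    # run_tosa_graph() now dispatches on the spec type: Tosa_0_80 uses the vendored
    # tosa_tools.v0_80 reference model, Tosa_1_00 uses the upstream tosa_reference_model.
    return run_tosa_graph(tosa_graph, spec, [x])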