Skip to content

Commit a260704

Browse files
authored
Arm backend: Add support for TOSA 1.0 serializer (#10135)
Adapt serialization and TOSA graph handling to be able to handle 1.0. Also install TOSA pip package for 1.0 alongside 0.80. ### Test plan Validate that 0.80 TOSA version test still work with the 1.0 package installed. Signed-off-by: Per Åstrand <[email protected]>
1 parent e261ed1 commit a260704

15 files changed

+118
-44
lines changed

backends/arm/_passes/arm_pass_manager.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,12 +166,22 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
166166

167167
return self._transform(exported_program.graph_module)
168168

169+
def _tosa_1_0_int_quantized_pipeline(self, exported_program: ExportedProgram):
170+
return self._tosa_080_BI_pipeline(exported_program)
171+
172+
def _tosa_1_0_fp_pipeline(self, exported_program: ExportedProgram):
173+
return self._tosa_080_MI_pipeline(exported_program)
174+
169175
def transform_to_backend_pipeline(self, exported_program: ExportedProgram):
170176
"""Apply passes before transforming program to backend"""
171177
if self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+BI"):
172178
return self._tosa_080_BI_pipeline(exported_program)
173179
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-0.80.0+MI"):
174180
return self._tosa_080_MI_pipeline(exported_program)
181+
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+FP"):
182+
return self._tosa_1_0_fp_pipeline(exported_program)
183+
elif self.tosa_spec == TosaSpecification.create_from_string("TOSA-1.0+INT"):
184+
return self._tosa_1_0_int_quantized_pipeline(exported_program)
175185
else:
176186
raise NotImplementedError(
177187
f"No pass pipeline implemented for {self.tosa_spec=}"

backends/arm/operator_support/convolution_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ class ConvolutionSupported(SupportedTOSAOperatorCheck):
2222
tosa_specs = [
2323
TosaSpecification.create_from_string("TOSA-0.80+BI"),
2424
TosaSpecification.create_from_string("TOSA-0.80+MI"),
25+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
26+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2527
]
2628

2729
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

backends/arm/operator_support/minmax_support.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class MinMaxSupported(SupportedTOSAOperatorCheck):
2222
# TODO : "MLETORCH-718 : Quantization of indices in arm_quantizer"
2323
tosa_specs = [
2424
TosaSpecification.create_from_string("TOSA-0.80+MI"),
25+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2526
]
2627

2728
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

backends/arm/operator_support/pool_2d_support.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ class AvgPool2dSupported(SupportedTOSAOperatorCheck):
4141
tosa_specs = [
4242
TosaSpecification.create_from_string("TOSA-0.80+BI"),
4343
TosaSpecification.create_from_string("TOSA-0.80+MI"),
44+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
45+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
4446
]
4547

4648
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
@@ -94,6 +96,8 @@ class MaxPool2dSupported(SupportedTOSAOperatorCheck):
9496
tosa_specs = [
9597
TosaSpecification.create_from_string("TOSA-0.80+BI"),
9698
TosaSpecification.create_from_string("TOSA-0.80+MI"),
99+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
100+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
97101
]
98102

99103
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

backends/arm/operator_support/reduce_sum_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ class SumSupported(SupportedTOSAOperatorCheck):
2121
tosa_specs = [
2222
TosaSpecification.create_from_string("TOSA-0.80+BI"),
2323
TosaSpecification.create_from_string("TOSA-0.80+MI"),
24+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
25+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2426
]
2527

2628
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

backends/arm/operator_support/right_shift_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class RightShiftSupported(SupportedTOSAOperatorCheck):
2929
tosa_specs = [
3030
TosaSpecification.create_from_string("TOSA-0.80+BI"),
3131
TosaSpecification.create_from_string("TOSA-0.80+MI"),
32+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
33+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3234
]
3335

3436
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

backends/arm/operator_support/slice_copy_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ class SliceCopySupported(SupportedTOSAOperatorCheck):
2525
tosa_specs = [
2626
TosaSpecification.create_from_string("TOSA-0.80+BI"),
2727
TosaSpecification.create_from_string("TOSA-0.80+MI"),
28+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
29+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
2830
]
2931

3032
def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: # type: ignore[override, misc]

backends/arm/operator_support/to_copy_support.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ class ToCopySupported(SupportedTOSAOperatorCheck):
3030
tosa_specs = [
3131
TosaSpecification.create_from_string("TOSA-0.80+BI"),
3232
TosaSpecification.create_from_string("TOSA-0.80+MI"),
33+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
34+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
3335
]
3436

3537
SupportedTypeDict = dict[torch.dtype, list[torch.dtype]]

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ def is_node_tosa_supported(
6666
_tosa_spec_support: dict[TosaSpecification, list[Type[SupportedTOSAOperatorCheck]]] = {
6767
TosaSpecification.create_from_string("TOSA-0.80+BI"): [],
6868
TosaSpecification.create_from_string("TOSA-0.80+MI"): [],
69+
TosaSpecification.create_from_string("TOSA-1.0+INT"): [],
70+
TosaSpecification.create_from_string("TOSA-1.0+FP"): [],
6971
}
7072

7173

backends/arm/operators/node_visitor.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,10 @@
55

66
# pyre-unsafe
77

8-
from typing import Dict, List
8+
from typing import Any, Dict, List
99

1010
import torch
1111

12-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1312
from executorch.backends.arm.tosa_mapping import TosaArg
1413
from executorch.backends.arm.tosa_specification import TosaSpecification
1514
from torch.export import ExportedProgram
@@ -25,19 +24,26 @@ class NodeVisitor:
2524
# a specific TOSA version.
2625
# When all node_visitors has been refactored to target a specific
2726
# version, this list should be removed.
28-
tosa_specs = [
27+
tosa_specs_1_00 = [
28+
TosaSpecification.create_from_string("TOSA-1.0+INT"),
29+
TosaSpecification.create_from_string("TOSA-1.0+FP"),
30+
]
31+
32+
tosa_specs_0_80 = [
2933
TosaSpecification.create_from_string("TOSA-0.80+BI"),
3034
TosaSpecification.create_from_string("TOSA-0.80+MI"),
3135
]
3236

37+
tosa_specs = tosa_specs_0_80 + tosa_specs_1_00
38+
3339
def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
3440
self._exported_program = exported_program
3541
self.tosa_spec = tosa_spec
3642

3743
def define_node(
3844
self,
3945
node: torch.fx.Node,
40-
tosa_graph: ts.TosaSerializer,
46+
tosa_graph: Any,
4147
inputs: List[TosaArg],
4248
output: TosaArg,
4349
) -> None:
@@ -48,6 +54,8 @@ def define_node(
4854
_node_visitor_dicts: Dict[TosaSpecification, Dict] = {
4955
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
5056
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
57+
TosaSpecification.create_from_string("TOSA-1.0+INT"): {},
58+
TosaSpecification.create_from_string("TOSA-1.0+FP"): {},
5159
}
5260

5361

backends/arm/process_node.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,18 @@
55
#
66

77
# pyre-unsafe
8-
from typing import cast, Dict
8+
from typing import Any, cast, Dict
99

1010
import numpy as np
1111
import torch
1212
import torch.fx
13-
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
1413
from executorch.backends.arm.operators.node_visitor import NodeVisitor
1514
from executorch.backends.arm.tosa_mapping import TosaArg
16-
from executorch.backends.arm.tosa_specification import TosaSpecification
15+
from executorch.backends.arm.tosa_specification import (
16+
Tosa_0_80,
17+
Tosa_1_00,
18+
TosaSpecification,
19+
)
1720
from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
1821
from torch._export.utils import (
1922
get_buffer,
@@ -28,7 +31,7 @@
2831

2932
def process_call_function(
3033
node: torch.fx.Node,
31-
tosa_graph: ts.TosaSerializer,
34+
tosa_graph: Any,
3235
node_visitors: Dict[str, NodeVisitor],
3336
tosa_spec: TosaSpecification,
3437
):
@@ -63,7 +66,7 @@ def process_call_function(
6366

6467
def process_inputs(
6568
node: torch.fx.Node,
66-
tosa_graph: ts.TosaSerializer,
69+
tosa_graph: Any,
6770
tosa_spec: TosaSpecification,
6871
):
6972
"""Serialize an input node"""
@@ -81,6 +84,14 @@ def process_inputs(
8184
f"Failed processing input placeholder: {node.name}. "
8285
"Is the original torch function supported?"
8386
) from e
87+
88+
if isinstance(tosa_spec, Tosa_0_80):
89+
import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
90+
elif isinstance(tosa_spec, Tosa_1_00):
91+
import serializer.tosa_serializer as ts
92+
else:
93+
raise ValueError(f"Unsupported TOSA spec: {tosa_spec}")
94+
8495
input_shape = tosa_arg.shape
8596
input_dim_order = tosa_arg.dim_order
8697
tensor = ts.TosaSerializerTensor(
@@ -95,7 +106,7 @@ def process_inputs(
95106

96107
def process_inputs_to_parameters(
97108
node: torch.fx.Node,
98-
tosa_graph: ts.TosaSerializer,
109+
tosa_graph: Any,
99110
edge_program: ExportedProgram,
100111
tosa_spec: TosaSpecification,
101112
):
@@ -124,7 +135,7 @@ def process_inputs_to_parameters(
124135

125136
def process_inputs_to_buffers(
126137
node: torch.fx.Node,
127-
tosa_graph: ts.TosaSerializer,
138+
tosa_graph: Any,
128139
edge_program: ExportedProgram,
129140
):
130141
"""Serialize quantized weights"""
@@ -152,7 +163,7 @@ def process_inputs_to_buffers(
152163

153164
def process_inputs_to_lifted_tensor_constants(
154165
node: torch.fx.Node,
155-
tosa_graph: ts.TosaSerializer,
166+
tosa_graph: Any,
156167
edge_program: ExportedProgram,
157168
):
158169
try:
@@ -172,7 +183,7 @@ def process_inputs_to_lifted_tensor_constants(
172183

173184
def process_placeholder(
174185
node: torch.fx.Node,
175-
tosa_graph: ts.TosaSerializer,
186+
tosa_graph: Any,
176187
edge_program: ExportedProgram,
177188
tosa_spec: TosaSpecification,
178189
):
@@ -198,7 +209,7 @@ def process_placeholder(
198209

199210
def process_output(
200211
node: torch.fx.Node,
201-
tosa_graph: ts.TosaSerializer,
212+
tosa_graph: Any,
202213
):
203214
for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
204215
tosa_graph.addOutputTensor(

backends/arm/scripts/install_reference_model.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ tosa_reference_model_url="https://git.gitlab.arm.com/tosa/tosa-reference-model.g
1313
tosa_reference_model_0_80_branch="v0.80"
1414
tosa_reference_model_0_80_rev="70ed0b40fa831387e36abdb4f7fb9670a3464f5a"
1515
tosa_serialization_lib_0_80_rev="v0.80.1"
16-
tosa_reference_model_1_0_rev="v1.0"
16+
tosa_reference_model_1_0_rev="f9b4ceb850964be03a39e965ad7a0546dc6c57ae"
1717

1818
script_dir=$(cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
1919

@@ -47,6 +47,9 @@ function setup_tosa_reference_model() {
4747
# Vela's flatbuffer requirement is expected to loosen, then remove this. MLETORCH-565
4848
CMAKE_POLICY_VERSION_MINIMUM=3.5 pip install . --no-dependencies flatbuffers
4949
popd
50+
51+
# Install the 1.0 branch from upstream
52+
CMAKE_POLICY_VERSION_MINIMUM=3.5 BUILD_PYBIND=1 pip install "tosa-tools@git+${tosa_reference_model_url}@${tosa_reference_model_1_0_rev}" ml_dtypes==0.5.1 --no-dependencies flatbuffers
5053
}
5154

5255
setup_tosa_reference_model $1

backends/arm/test/runner_utils.py

Lines changed: 37 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,27 +13,24 @@
1313

1414
from pathlib import Path
1515

16-
from typing import cast, Dict, List, Literal, Optional, Tuple
16+
from typing import Any, cast, Dict, List, Literal, Optional, Tuple
1717

1818
import numpy as np
1919
import torch
2020

21-
try:
22-
import tosa_tools.v0_80.tosa_reference_model as tosa_reference_model
23-
except ImportError:
24-
tosa_reference_model = None
2521
from executorch.backends.arm.arm_backend import get_tosa_spec, is_tosa
26-
2722
from executorch.backends.arm.test.conftest import is_option_enabled
28-
from executorch.backends.arm.tosa_specification import TosaSpecification
23+
from executorch.backends.arm.tosa_specification import (
24+
Tosa_0_80,
25+
Tosa_1_00,
26+
TosaSpecification,
27+
)
2928
from executorch.exir import ExecutorchProgramManager, ExportedProgram
3029
from executorch.exir.backend.compile_spec_schema import CompileSpec
3130
from executorch.exir.lowered_backend_module import LoweredBackendModule
32-
from packaging.version import Version
3331
from torch.fx.node import Node
3432

3533
from torch.overrides import TorchFunctionMode
36-
from tosa_tools.v0_80.tosa import TosaGraph
3734

3835
logger = logging.getLogger(__name__)
3936

@@ -566,33 +563,46 @@ def arm_executor_runner_exists(target_board):
566563

567564

568565
def run_tosa_graph(
569-
graph: TosaGraph,
566+
graph: Any,
570567
tosa_version: TosaSpecification,
571568
inputs: list[torch.Tensor],
572569
) -> list[torch.Tensor]:
573570
"""Runs the TOSA reference model with inputs and returns the result."""
574571
inputs_np = [input.numpy() for input in inputs]
575572
transpose_data_format(inputs_np, to="NHWC")
576573

577-
tosa_release = tosa_version.version
578-
579-
if tosa_release > Version("0.80"):
580-
logger.warning("The reference model is only tested for TOSA v0.80")
581-
582-
# tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training.
583-
tosa_profile = 1 if tosa_version.support_float() else 0
584-
debug_mode = "ALL" if logger.level <= logging.DEBUG else None
585-
outputs_np, status = tosa_reference_model.run(
586-
graph,
587-
inputs_np,
588-
verbosity=_tosa_refmodel_loglevel(logger.level),
589-
tosa_profile=tosa_profile,
590-
initialize_variable_tensor_from_numpy=1, # True
591-
debug_mode=debug_mode,
592-
)
574+
if isinstance(tosa_version, Tosa_0_80):
575+
import tosa_tools.v0_80.tosa_reference_model as reference_model
576+
577+
# tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training.
578+
tosa_profile = 1 if tosa_version.support_float() else 0
579+
debug_mode = "ALL" if logger.level <= logging.DEBUG else None
580+
outputs_np, status = reference_model.run(
581+
graph,
582+
inputs_np,
583+
verbosity=_tosa_refmodel_loglevel(logger.level),
584+
tosa_profile=tosa_profile,
585+
initialize_variable_tensor_from_numpy=True,
586+
debug_mode=debug_mode,
587+
)
588+
elif isinstance(tosa_version, Tosa_1_00):
589+
import tosa_reference_model as reference_model
590+
591+
debug_mode = "ALL" if logger.level <= logging.DEBUG else None
592+
outputs_np, status = reference_model.run(
593+
graph,
594+
inputs_np,
595+
verbosity=_tosa_refmodel_loglevel(logger.level),
596+
initialize_variable_tensor_from_numpy=True,
597+
debug_mode=debug_mode,
598+
)
599+
else:
600+
raise ValueError(
601+
f"Unknown TOSA specification: {tosa_version}. No refererence model available to run for this specification version"
602+
)
593603

594604
assert (
595-
status == tosa_reference_model.GraphStatus.TOSA_VALID
605+
status == reference_model.GraphStatus.TOSA_VALID
596606
), "Non-valid TOSA given to reference model."
597607

598608
transpose_data_format(outputs_np, to="NCHW")

0 commit comments

Comments
 (0)