Skip to content

Commit ddc8ea6

Browse files
authored
Tosa specification handling (#6688)
Add TOSA specification details to the Arm Backend. * Mandate the need for a TOSA version in the compile spec list passed to the Arm backend and propagate the information to node visitors for serialization handling. * Add TOSA version string to all TOSA tests * Adds handling of TOSA 0.80 BI and MI profile as separate serialization handlers for ADD as an example. Signed-off-by: Per Åstrand <[email protected]>
1 parent 6d6630e commit ddc8ea6

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+618
-133
lines changed

backends/arm/arm_backend.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
from executorch.backends.arm.operators.node_visitor import get_node_visitors
2121
from executorch.backends.arm.operators.op_output import process_output
2222
from executorch.backends.arm.operators.op_placeholder import process_placeholder
23+
24+
from executorch.backends.arm.tosa_specification import TosaSpecification
2325
from executorch.backends.arm._passes.arm_pass_manager import (
2426
ArmPassManager,
2527
) # usort: skip
@@ -86,16 +88,23 @@ def ethosu_compile_spec(
8688
if extra_flags is not None:
8789
self.compiler_flags.append(extra_flags)
8890

91+
base_tosa_version = "TOSA-0.80.0+BI"
92+
if "U55" in config:
93+
# Add the Ethos-U55 extension marker
94+
base_tosa_version += "+u55"
95+
self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)
96+
8997
return self
9098

91-
def tosa_compile_spec(self) -> "ArmCompileSpecBuilder":
99+
def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder":
92100
"""
93101
Generate compile spec for TOSA flatbuffer output
94102
"""
95103
assert (
96104
self.output_format is None
97105
), f"Output format already set: {self.output_format}"
98106
self.output_format = "tosa"
107+
self.tosa_version = TosaSpecification.create_from_string(tosa_version)
99108
return self
100109

101110
def dump_intermediate_artifacts_to(
@@ -129,6 +138,13 @@ def build(self) -> List[CompileSpec]:
129138
"""
130139
Generate a list of compile spec objects from the builder
131140
"""
141+
assert self.tosa_version
142+
143+
# Always supply a TOSA version
144+
self.compile_spec = [
145+
CompileSpec("tosa_version", str(self.tosa_version).encode())
146+
]
147+
132148
if self.output_format == "vela":
133149
self.compile_spec += [
134150
CompileSpec("output_format", "vela".encode()),
@@ -210,25 +226,32 @@ def preprocess( # noqa: C901
210226
if not output_format:
211227
raise RuntimeError("output format is required")
212228

229+
tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec)
230+
assert (
231+
tosa_spec is not None
232+
), "TOSA backend needs a TOSA version specified in the CompileSpec!"
233+
213234
if output_format == "vela" and len(compile_flags) == 0:
214235
# Not testing for compile_flags correctness here, just that they are
215236
# present. The compiler will give errors if they are not valid.
216237
raise RuntimeError("compile flags are required for vela output format")
217238

239+
logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")
240+
218241
# Converted output for this subgraph, serializer needs path early as it emits
219242
# const data directly. Path created and data written only in debug builds.
220243
tosa_graph = ts.TosaSerializer(artifact_path)
221244
graph_module = ArmPassManager().transform_to_backend_pipeline(
222245
exported_program=edge_program, compile_spec=compile_spec
223246
)
224247

225-
node_visitors = get_node_visitors(edge_program)
248+
node_visitors = get_node_visitors(edge_program, tosa_spec)
226249

227250
for node in graph_module.graph.nodes:
228251
if node.op == "call_function":
229-
process_call_function(node, tosa_graph, node_visitors)
252+
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
230253
elif node.op == "placeholder":
231-
process_placeholder(node, tosa_graph, edge_program)
254+
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
232255
elif node.op == "output":
233256
process_output(node, tosa_graph)
234257
else:

backends/arm/operators/node_visitor.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2024 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -10,6 +10,7 @@
1010
import serializer.tosa_serializer as ts
1111
import torch
1212
from executorch.backends.arm.tosa_mapping import TosaArg
13+
from executorch.backends.arm.tosa_specification import TosaSpecification
1314
from torch.export import ExportedProgram
1415

1516

@@ -18,8 +19,19 @@ class NodeVisitor:
1819
Node Visitor pattern for lowering edge IR to TOSA
1920
"""
2021

21-
def __init__(self, exported_program: ExportedProgram):
22+
# Add the currently supported node_visitor specs as default.
23+
# This should be overriden in the NodeVisitor subclasses to target
24+
# a specific TOSA version.
25+
# When all node_visitors has been refactored to target a specific
26+
# version, this list should be removed.
27+
tosa_specs = [
28+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
29+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
30+
]
31+
32+
def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
2233
self._exported_program = exported_program or None
34+
self.tosa_spec = tosa_spec
2335

2436
def define_node(
2537
self,
@@ -33,16 +45,30 @@ def define_node(
3345

3446

3547
# container for all node visitors
36-
_node_visitor_dict = {}
48+
_node_visitor_dicts = {
49+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {},
50+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {},
51+
}
3752

3853

3954
def register_node_visitor(visitor):
40-
_node_visitor_dict[visitor.target] = visitor
55+
for tosa_spec in visitor.tosa_specs:
56+
_node_visitor_dicts[tosa_spec][visitor.target] = visitor
57+
return visitor
4158

4259

4360
def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
4461
node_visitors = {}
45-
for target, visitor in _node_visitor_dict.items():
62+
tosa_spec = None
63+
for arg in args:
64+
if isinstance(arg, TosaSpecification):
65+
tosa_spec = arg
66+
break
67+
68+
if tosa_spec is None:
69+
raise RuntimeError("No TOSA specification supplied.")
70+
71+
for target, visitor in _node_visitor_dicts[tosa_spec].items():
4672
node_visitors[target] = visitor(*args)
4773

4874
return node_visitors

backends/arm/operators/op_add.py

Lines changed: 60 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,19 +11,25 @@
1111
import executorch.backends.arm.tosa_utils as tutils
1212

1313
import serializer.tosa_serializer as ts
14+
import torch
1415
from executorch.backends.arm.operators.node_visitor import (
1516
NodeVisitor,
1617
register_node_visitor,
1718
)
1819
from executorch.backends.arm.tosa_mapping import TosaArg
20+
from executorch.backends.arm.tosa_specification import TosaSpecification
1921
from serializer.tosa_serializer import TosaOp
2022
from torch.fx import Node
2123

2224

2325
@register_node_visitor
24-
class AddVisitor(NodeVisitor):
26+
class AddVisitor_080_BI(NodeVisitor):
2527
target = "aten.add.Tensor"
2628

29+
tosa_specs = [
30+
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
31+
]
32+
2733
def __init__(self, *args):
2834
super().__init__(*args)
2935

@@ -35,9 +41,22 @@ def define_node(
3541
output: TosaArg,
3642
is_quant_node: bool,
3743
) -> None:
38-
if is_quant_node:
39-
input_nodes = tutils.get_two_inputs(node)
44+
input_nodes = tutils.get_two_inputs(node)
45+
46+
if not is_quant_node and not all(
47+
tensor.meta["val"].dtype in (torch.int8, torch.int32)
48+
for tensor in input_nodes
49+
):
50+
raise RuntimeError(
51+
f"Unexpected non quantized {AddVisitor_080_BI.target} node."
52+
)
4053

54+
needs_rescale = not (
55+
all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
56+
and node.meta["val"].dtype == torch.int32
57+
)
58+
59+
if needs_rescale:
4160
# Rescale inputs to 32 bit
4261
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
4362
input_nodes, tosa_graph
@@ -48,20 +67,48 @@ def define_node(
4867
rescaled_inputs[0].shape, rescaled_inputs[0].shape
4968
)
5069
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
70+
else:
71+
add_output = output
72+
rescaled_inputs = inputs
5173

52-
# Do the INT32 Add
53-
tosa_graph.addOperator(
54-
TosaOp.Op().ADD,
55-
[
56-
rescaled_inputs[0].name,
57-
rescaled_inputs[1].name,
58-
],
59-
[add_output.name],
60-
None,
61-
)
74+
# Do the INT32 Add
75+
tosa_graph.addOperator(
76+
TosaOp.Op().ADD,
77+
[
78+
rescaled_inputs[0].name,
79+
rescaled_inputs[1].name,
80+
],
81+
[add_output.name],
82+
None,
83+
)
6284

85+
if needs_rescale:
6386
# Scale output back to 8 bit
6487
tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
88+
89+
90+
@register_node_visitor
91+
class AddVisitor_080_MI(AddVisitor_080_BI):
92+
# inheriting 'target' from BI class
93+
94+
tosa_specs = [
95+
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
96+
]
97+
98+
def __init__(self, *args):
99+
super().__init__(*args)
100+
101+
def define_node(
102+
self,
103+
node: Node,
104+
tosa_graph: ts.TosaSerializer,
105+
inputs: List[TosaArg],
106+
output: TosaArg,
107+
is_quant_node: bool,
108+
) -> None:
109+
if is_quant_node:
110+
# Call the inherited define_node for handling integers
111+
super().define_node(node, tosa_graph, inputs, output, is_quant_node)
65112
else:
66113
# FP32 Add lowering
67114
tosa_graph.addOperator(

backends/arm/operators/op_placeholder.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
get_quant_node_args,
1515
is_quant_arg,
1616
)
17+
from executorch.backends.arm.tosa_specification import TosaSpecification
1718
from executorch.backends.arm.tosa_utils import (
1819
is_bias_node_for_quantized_addmm,
1920
is_bias_node_for_quantized_conv,
@@ -26,6 +27,7 @@
2627
def process_inputs(
2728
node: torch.fx.Node,
2829
tosa_graph: ts.TosaSerializer,
30+
tosa_spec: TosaSpecification,
2931
):
3032
"""Serialize an input node"""
3133
# inputs need to be in default dim_order (contiguous memory format)
@@ -95,6 +97,7 @@ def process_inputs_to_parameters(
9597
node: torch.fx.Node,
9698
tosa_graph: ts.TosaSerializer,
9799
edge_program: ExportedProgram,
100+
tosa_spec: TosaSpecification,
98101
):
99102
"""Serialize bias and non-quantized weights"""
100103
inputs = [TosaArg(node)]
@@ -106,9 +109,13 @@ def process_inputs_to_parameters(
106109

107110
if is_bias_node_for_quantized_addmm(node) or is_bias_node_for_quantized_conv(node):
108111
# BI bias
112+
assert tosa_spec.support_integer(), f"{tosa_spec} doesnt't support integer"
109113
process_quantized_bias(node, tosa_graph, parameter_values)
110114
else:
111115
# MI weights or bias
116+
if inputs[0].dtype == torch.float32:
117+
assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
118+
112119
parameter_values = np.transpose(parameter_values, inputs[0].dim_order)
113120

114121
tosa_graph.addConst(
@@ -158,15 +165,16 @@ def process_placeholder(
158165
node: torch.fx.Node,
159166
tosa_graph: ts.TosaSerializer,
160167
edge_program: ExportedProgram,
168+
tosa_spec: TosaSpecification,
161169
):
162170
"""Wrapper for processing and serializing all types of placeholders"""
163171
assert node.name == node.target, "Expect placeholder name and target to match"
164172
assert 0 == len(node.args), "Can't handle default input values"
165173

166174
if node.name in edge_program.graph_signature.user_inputs:
167-
process_inputs(node, tosa_graph)
175+
process_inputs(node, tosa_graph, tosa_spec)
168176
elif node.name in edge_program.graph_signature.inputs_to_parameters:
169-
process_inputs_to_parameters(node, tosa_graph, edge_program)
177+
process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
170178
elif node.name in edge_program.graph_signature.inputs_to_buffers:
171179
process_inputs_to_buffers(node, tosa_graph, edge_program)
172180
elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants:

backends/arm/test/common.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,16 +177,18 @@ def maybe_get_tosa_collate_path() -> str | None:
177177

178178

179179
def get_tosa_compile_spec(
180-
permute_memory_to_nhwc=True, custom_path=None
180+
tosa_version: str, permute_memory_to_nhwc=True, custom_path=None
181181
) -> list[CompileSpec]:
182182
"""
183183
Default compile spec for TOSA tests.
184184
"""
185-
return get_tosa_compile_spec_unbuilt(permute_memory_to_nhwc, custom_path).build()
185+
return get_tosa_compile_spec_unbuilt(
186+
tosa_version, permute_memory_to_nhwc, custom_path
187+
).build()
186188

187189

188190
def get_tosa_compile_spec_unbuilt(
189-
permute_memory_to_nhwc=False, custom_path=None
191+
tosa_version: str, permute_memory_to_nhwc=False, custom_path=None
190192
) -> ArmCompileSpecBuilder:
191193
"""Get the ArmCompileSpecBuilder for the default TOSA tests, to modify
192194
the compile spec before calling .build() to finalize it.
@@ -202,7 +204,7 @@ def get_tosa_compile_spec_unbuilt(
202204
os.makedirs(intermediate_path, exist_ok=True)
203205
compile_spec_builder = (
204206
ArmCompileSpecBuilder()
205-
.tosa_compile_spec()
207+
.tosa_compile_spec(tosa_version)
206208
.set_permute_memory_format(permute_memory_to_nhwc)
207209
.dump_intermediate_artifacts_to(intermediate_path)
208210
)

0 commit comments

Comments
 (0)