Skip to content

Tosa specification handling #6688

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from executorch.backends.arm.operators.node_visitor import get_node_visitors
from executorch.backends.arm.operators.op_output import process_output
from executorch.backends.arm.operators.op_placeholder import process_placeholder

from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm._passes.arm_pass_manager import (
ArmPassManager,
) # usort: skip
Expand Down Expand Up @@ -86,16 +88,23 @@ def ethosu_compile_spec(
if extra_flags is not None:
self.compiler_flags.append(extra_flags)

base_tosa_version = "TOSA-0.80.0+BI"
if "U55" in config:
# Add the Ethos-U55 extension marker
base_tosa_version += "+u55"
self.tosa_version = TosaSpecification.create_from_string(base_tosa_version)

return self

def tosa_compile_spec(self) -> "ArmCompileSpecBuilder":
def tosa_compile_spec(self, tosa_version: str) -> "ArmCompileSpecBuilder":
"""
Generate compile spec for TOSA flatbuffer output
"""
assert (
self.output_format is None
), f"Output format already set: {self.output_format}"
self.output_format = "tosa"
self.tosa_version = TosaSpecification.create_from_string(tosa_version)
return self

def dump_intermediate_artifacts_to(
Expand Down Expand Up @@ -129,6 +138,13 @@ def build(self) -> List[CompileSpec]:
"""
Generate a list of compile spec objects from the builder
"""
assert self.tosa_version

# Always supply a TOSA version
self.compile_spec = [
CompileSpec("tosa_version", str(self.tosa_version).encode())
]

if self.output_format == "vela":
self.compile_spec += [
CompileSpec("output_format", "vela".encode()),
Expand Down Expand Up @@ -210,25 +226,32 @@ def preprocess( # noqa: C901
if not output_format:
raise RuntimeError("output format is required")

tosa_spec = TosaSpecification.create_from_compilespecs(compile_spec)
assert (
tosa_spec is not None
), "TOSA backend needs a TOSA version specified in the CompileSpec!"

if output_format == "vela" and len(compile_flags) == 0:
# Not testing for compile_flags correctness here, just that they are
# present. The compiler will give errors if they are not valid.
raise RuntimeError("compile flags are required for vela output format")

logger.info(f"Converting ExportedProgram to TOSA: {tosa_spec}")

# Converted output for this subgraph, serializer needs path early as it emits
# const data directly. Path created and data written only in debug builds.
tosa_graph = ts.TosaSerializer(artifact_path)
graph_module = ArmPassManager().transform_to_backend_pipeline(
exported_program=edge_program, compile_spec=compile_spec
)

node_visitors = get_node_visitors(edge_program)
node_visitors = get_node_visitors(edge_program, tosa_spec)

for node in graph_module.graph.nodes:
if node.op == "call_function":
process_call_function(node, tosa_graph, node_visitors)
process_call_function(node, tosa_graph, node_visitors, tosa_spec)
elif node.op == "placeholder":
process_placeholder(node, tosa_graph, edge_program)
process_placeholder(node, tosa_graph, edge_program, tosa_spec)
elif node.op == "output":
process_output(node, tosa_graph)
else:
Expand Down
36 changes: 31 additions & 5 deletions backends/arm/operators/node_visitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023 Arm Limited and/or its affiliates.
# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -10,6 +10,7 @@
import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from torch.export import ExportedProgram


Expand All @@ -18,8 +19,19 @@ class NodeVisitor:
Node Visitor pattern for lowering edge IR to TOSA
"""

def __init__(self, exported_program: ExportedProgram):
# Add the currently supported node_visitor specs as default.
# This should be overriden in the NodeVisitor subclasses to target
# a specific TOSA version.
# When all node_visitors has been refactored to target a specific
# version, this list should be removed.
tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

def __init__(self, exported_program: ExportedProgram, tosa_spec: TosaSpecification):
self._exported_program = exported_program or None
self.tosa_spec = tosa_spec

def define_node(
self,
Expand All @@ -33,16 +45,30 @@ def define_node(


# container for all node visitors
_node_visitor_dict = {}
_node_visitor_dicts = {
TosaSpecification.create_from_string("TOSA-0.80.0+BI"): {},
TosaSpecification.create_from_string("TOSA-0.80.0+MI"): {},
}


def register_node_visitor(visitor):
_node_visitor_dict[visitor.target] = visitor
for tosa_spec in visitor.tosa_specs:
_node_visitor_dicts[tosa_spec][visitor.target] = visitor
return visitor


def get_node_visitors(*args) -> Dict[str, NodeVisitor]:
node_visitors = {}
for target, visitor in _node_visitor_dict.items():
tosa_spec = None
for arg in args:
if isinstance(arg, TosaSpecification):
tosa_spec = arg
break

if tosa_spec is None:
raise RuntimeError("No TOSA specification supplied.")

for target, visitor in _node_visitor_dicts[tosa_spec].items():
node_visitors[target] = visitor(*args)

return node_visitors
73 changes: 60 additions & 13 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,25 @@
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
from serializer.tosa_serializer import TosaOp
from torch.fx import Node


@register_node_visitor
class AddVisitor(NodeVisitor):
class AddVisitor_080_BI(NodeVisitor):
target = "aten.add.Tensor"

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+BI"),
]

def __init__(self, *args):
super().__init__(*args)

Expand All @@ -35,9 +41,22 @@ def define_node(
output: TosaArg,
is_quant_node: bool,
) -> None:
if is_quant_node:
input_nodes = tutils.get_two_inputs(node)
input_nodes = tutils.get_two_inputs(node)

if not is_quant_node and not all(
tensor.meta["val"].dtype in (torch.int8, torch.int32)
for tensor in input_nodes
):
raise RuntimeError(
f"Unexpected non quantized {AddVisitor_080_BI.target} node."
)

needs_rescale = not (
all(tensor.meta["val"].dtype == torch.int32 for tensor in input_nodes)
and node.meta["val"].dtype == torch.int32
)

if needs_rescale:
# Rescale inputs to 32 bit
rescaled_inputs, scale = tqutils.rescale_nodes_to_int32(
input_nodes, tosa_graph
Expand All @@ -48,20 +67,48 @@ def define_node(
rescaled_inputs[0].shape, rescaled_inputs[0].shape
)
add_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
else:
add_output = output
rescaled_inputs = inputs

# Do the INT32 Add
tosa_graph.addOperator(
TosaOp.Op().ADD,
[
rescaled_inputs[0].name,
rescaled_inputs[1].name,
],
[add_output.name],
None,
)
# Do the INT32 Add
tosa_graph.addOperator(
TosaOp.Op().ADD,
[
rescaled_inputs[0].name,
rescaled_inputs[1].name,
],
[add_output.name],
None,
)

if needs_rescale:
# Scale output back to 8 bit
tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)


@register_node_visitor
class AddVisitor_080_MI(AddVisitor_080_BI):
# inheriting 'target' from BI class

tosa_specs = [
TosaSpecification.create_from_string("TOSA-0.80.0+MI"),
]

def __init__(self, *args):
super().__init__(*args)

def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
inputs: List[TosaArg],
output: TosaArg,
is_quant_node: bool,
) -> None:
if is_quant_node:
# Call the inherited define_node for handling integers
super().define_node(node, tosa_graph, inputs, output, is_quant_node)
else:
# FP32 Add lowering
tosa_graph.addOperator(
Expand Down
12 changes: 10 additions & 2 deletions backends/arm/operators/op_placeholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
get_quant_node_args,
is_quant_arg,
)
from executorch.backends.arm.tosa_specification import TosaSpecification
from executorch.backends.arm.tosa_utils import (
is_bias_node_for_quantized_addmm,
is_bias_node_for_quantized_conv,
Expand All @@ -26,6 +27,7 @@
def process_inputs(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
tosa_spec: TosaSpecification,
):
"""Serialize an input node"""
# inputs need to be in default dim_order (contiguous memory format)
Expand Down Expand Up @@ -95,6 +97,7 @@ def process_inputs_to_parameters(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
edge_program: ExportedProgram,
tosa_spec: TosaSpecification,
):
"""Serialize bias and non-quantized weights"""
inputs = [TosaArg(node)]
Expand All @@ -106,9 +109,13 @@ def process_inputs_to_parameters(

if is_bias_node_for_quantized_addmm(node) or is_bias_node_for_quantized_conv(node):
# BI bias
assert tosa_spec.support_integer(), f"{tosa_spec} doesnt't support integer"
process_quantized_bias(node, tosa_graph, parameter_values)
else:
# MI weights or bias
if inputs[0].dtype == torch.float32:
assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

parameter_values = np.transpose(parameter_values, inputs[0].dim_order)

tosa_graph.addConst(
Expand Down Expand Up @@ -158,15 +165,16 @@ def process_placeholder(
node: torch.fx.Node,
tosa_graph: ts.TosaSerializer,
edge_program: ExportedProgram,
tosa_spec: TosaSpecification,
):
"""Wrapper for processing and serializing all types of placeholders"""
assert node.name == node.target, "Expect placeholder name and target to match"
assert 0 == len(node.args), "Can't handle default input values"

if node.name in edge_program.graph_signature.user_inputs:
process_inputs(node, tosa_graph)
process_inputs(node, tosa_graph, tosa_spec)
elif node.name in edge_program.graph_signature.inputs_to_parameters:
process_inputs_to_parameters(node, tosa_graph, edge_program)
process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
elif node.name in edge_program.graph_signature.inputs_to_buffers:
process_inputs_to_buffers(node, tosa_graph, edge_program)
elif node.name in edge_program.graph_signature.inputs_to_lifted_tensor_constants:
Expand Down
10 changes: 6 additions & 4 deletions backends/arm/test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,16 +177,18 @@ def maybe_get_tosa_collate_path() -> str | None:


def get_tosa_compile_spec(
permute_memory_to_nhwc=True, custom_path=None
tosa_version: str, permute_memory_to_nhwc=True, custom_path=None
) -> list[CompileSpec]:
"""
Default compile spec for TOSA tests.
"""
return get_tosa_compile_spec_unbuilt(permute_memory_to_nhwc, custom_path).build()
return get_tosa_compile_spec_unbuilt(
tosa_version, permute_memory_to_nhwc, custom_path
).build()


def get_tosa_compile_spec_unbuilt(
permute_memory_to_nhwc=False, custom_path=None
tosa_version: str, permute_memory_to_nhwc=False, custom_path=None
) -> ArmCompileSpecBuilder:
"""Get the ArmCompileSpecBuilder for the default TOSA tests, to modify
the compile spec before calling .build() to finalize it.
Expand All @@ -202,7 +204,7 @@ def get_tosa_compile_spec_unbuilt(
os.makedirs(intermediate_path, exist_ok=True)
compile_spec_builder = (
ArmCompileSpecBuilder()
.tosa_compile_spec()
.tosa_compile_spec(tosa_version)
.set_permute_memory_format(permute_memory_to_nhwc)
.dump_intermediate_artifacts_to(intermediate_path)
)
Expand Down
Loading
Loading