Skip to content

Arm backend: Update Rescale and affected nodes to support TOSA 1.0 #10656

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 2, 2025
Merged
9 changes: 8 additions & 1 deletion backends/arm/operator_support/convolution_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@
register_tosa_support_check,
SupportedTOSAOperatorCheck,
)
from executorch.backends.arm.tosa_specification import Tosa_0_80, TosaSpecification
from executorch.backends.arm.tosa_specification import (
Tosa_0_80,
Tosa_1_00,
TosaSpecification,
)
from executorch.exir.dialects._ops import ops as exir_ops


Expand Down Expand Up @@ -43,6 +47,9 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):

# Hardware specific constraints
if not (isinstance(tosa_spec, Tosa_0_80) and tosa_spec.is_U55_subset):
# TODO remove this once TOSA 1.0 support for u55 is added.
if isinstance(tosa_spec, Tosa_1_00) and "u55" in tosa_spec.extensions:
return False
return True
else:
return self._is_node_supported_u55(node)
Expand Down
134 changes: 129 additions & 5 deletions backends/arm/operators/op_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List
from typing import Any, List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
Expand All @@ -33,10 +32,13 @@ def __init__(self, *args):
def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

# Specification (0.80) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
Expand All @@ -53,7 +55,7 @@ def define_node(
if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
tosa_graph, inputs, node
)
) # type: ignore[possibly-undefined]
else:
# input[0].dtype == ts.DType.INT32
# Non quantized input, natively support by TOSA.abs
Expand Down Expand Up @@ -96,10 +98,13 @@ def __init__(self, *args):
def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

# Specification (0.80) states that input and output types
# should all be the same
if not (inputs[0].dtype == output.dtype):
Expand Down Expand Up @@ -129,3 +134,122 @@ def define_node(
[output.name],
None,
)


@register_node_visitor
class AbsVisitor_INT(NodeVisitor):
    """Node visitor lowering ``aten.abs.default`` for the TOSA-1.0+INT spec.

    INT8 tensors are computed in INT32: the input is rescaled up, ABS is
    emitted on the INT32 values, and the result is rescaled back to INT8.
    INT32 tensors are passed to ABS directly.
    """

    target = "aten.abs.default"

    tosa_specs = [
        TosaSpecification.create_from_string("TOSA-1.0+INT"),
    ]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit the TOSA operators implementing an integer ``abs``.

        Args:
            node: The FX node being lowered; forwarded to the rescale helpers.
            tosa_graph: Serializer graph the operators are added to.
            inputs: Operator arguments; only ``inputs[0]`` is used.
            output: Destination tensor of the node.

        Raises:
            ValueError: If the input and output dtypes differ, or if the
                dtype is neither INT8 nor INT32.
        """
        # Imported locally rather than at module level.
        # NOTE(review): presumably so the 0.80 and 1.0 serializer packages
        # can coexist in the same module - confirm.
        import serializer.tosa_serializer as ts  # type: ignore

        # Specification (1.0) states that input and output types
        # should all be the same
        if inputs[0].dtype != output.dtype:
            raise ValueError(
                "All inputs and outputs need same dtype. "
                f"Got {inputs[0].dtype=}, {output.dtype=}"
            )
        # Handle int8 (quantized) and int32
        if inputs[0].dtype not in [ts.DType.INT8, ts.DType.INT32]:
            raise ValueError(
                f"All inputs need to be INT8 or INT32. Got {inputs[0].dtype=}"
            )

        scale_back = 1.0
        if inputs[0].dtype == ts.DType.INT8:
            rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
                tosa_graph, inputs, node, self.tosa_specs
            )  # type: ignore[possibly-undefined]
        else:
            # inputs[0].dtype == ts.DType.INT32
            # Non quantized input, natively supported by TOSA.abs
            rescaled_inputs = inputs

        if output.dtype == ts.DType.INT8:
            # ABS writes to an INT32 intermediate that is rescaled afterwards.
            broadcasted_shape = tutils.tosa_shape(output.shape, output.dim_order)
            abs_output = tosa_graph.addIntermediate(broadcasted_shape, ts.DType.INT32)
        else:
            # output.dtype == ts.DType.INT32
            abs_output = output

        # Do the INT32 Abs
        tosa_graph.addOperator(
            ts.TosaOp.Op().ABS,
            [
                rescaled_inputs[0].name,
            ],
            [abs_output.name],
            None,
        )

        if output.dtype == ts.DType.INT8:
            # Scale output back to 8 bit
            # pyre-ignore
            tqutils.insert_rescale_op_to_int8(
                tosa_graph, abs_output, scale_back, node, self.tosa_specs
            )  # type: ignore[possibly-undefined]


@register_node_visitor
class AbsVisitor_FP(AbsVisitor_INT):
    """Node visitor lowering ``aten.abs.default`` for the TOSA-1.0+FP spec.

    Integer tensors (INT8/INT32) are delegated to ``AbsVisitor_INT``; FP32
    tensors are lowered to a single ABS operator.
    """

    # inheriting 'target' from INT class

    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit the TOSA operators implementing ``abs``.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer graph the operators are added to.
            inputs: Operator arguments; only ``inputs[0]`` is used.
            output: Destination tensor of the node.

        Raises:
            ValueError: If the input and output dtypes differ, or if the
                dtype is not INT8, INT32 or FP32.
        """
        # Imported locally rather than at module level.
        # NOTE(review): presumably so the 0.80 and 1.0 serializer packages
        # can coexist in the same module - confirm.
        import serializer.tosa_serializer as ts  # type: ignore

        # Specification (1.0) states that input and output types
        # should all be the same
        if inputs[0].dtype != output.dtype:
            raise ValueError(
                "All inputs and output need same dtype. "
                f"Got {inputs[0].dtype=}, {output.dtype=}"
            )

        if inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]:
            # Call the inherited define_node for handling integers
            super().define_node(node, tosa_graph, inputs, output)
            return

        # FP32 Abs lowering
        if inputs[0].dtype != ts.DType.FP32:
            raise ValueError(f"All inputs need to be FP32. Got {inputs[0].dtype=}")

        # No separate output check: the output dtype was verified equal to the
        # input dtype above, so it is FP32 here as well.

        # FP lowering
        tosa_graph.addOperator(
            ts.TosaOp.Op().ABS,
            [inputs[0].name],
            [output.name],
            None,
        )
140 changes: 133 additions & 7 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@

# pyre-unsafe

from typing import List
from typing import Any, List

import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
Expand All @@ -34,10 +33,13 @@ def __init__(self, *args):
def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

# Specification (0.80) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
Expand All @@ -58,7 +60,7 @@ def define_node(
if len(inputs[0].shape) > len(inputs[1].shape)
else inputs[1].dim_order
)

scale_back = 1.0
if inputs[0].dtype == ts.DType.INT8:
rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
tosa_graph, inputs, node
Expand Down Expand Up @@ -90,7 +92,9 @@ def define_node(
if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
# pyre-ignore
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) # type: ignore[possibly-undefined]
tqutils.insert_rescale_op_to_int8(
tosa_graph, add_output, scale_back, node
) # type: ignore[possibly-undefined]


@register_node_visitor
Expand All @@ -107,10 +111,13 @@ def __init__(self, *args):
def define_node(
self,
node: Node,
tosa_graph: ts.TosaSerializer,
tosa_graph: Any,
inputs: List[TosaArg],
output: TosaArg,
) -> None:

import tosa_tools.v0_80.serializer.tosa_serializer as ts # type: ignore

# Specification (0.80) states that input and output types
# should all be the same
if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
Expand All @@ -130,7 +137,7 @@ def define_node(
f"Expected IO data type to be FP32, got {inputs[0].dtype}"
)

input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)
input1, input2 = inputs

# MI lowering
tosa_graph.addOperator(
Expand All @@ -139,3 +146,122 @@ def define_node(
[output.name],
None,
)


@register_node_visitor
class AddVisitor_INT(NodeVisitor):
    """Node visitor lowering ``aten.add.Tensor`` for the TOSA-1.0+INT spec.

    INT8 operands are rescaled to INT32, added in INT32, and the sum is
    rescaled back to INT8. INT32 operands are added directly.
    """

    target = "aten.add.Tensor"

    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+INT")]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit the TOSA operators implementing an integer ``add``.

        Args:
            node: The FX node being lowered; forwarded to the rescale helpers.
            tosa_graph: Serializer graph the operators are added to.
            inputs: The two operands of the addition.
            output: Destination tensor of the node.

        Raises:
            TypeError: If the operands and output do not share one dtype, or
                if that dtype is neither INT8 nor INT32.
        """
        import serializer.tosa_serializer as ts  # type: ignore

        # The 1.0 specification requires one common dtype for all IO.
        lhs_dtype = inputs[0].dtype
        if lhs_dtype != inputs[1].dtype or lhs_dtype != output.dtype:
            raise TypeError(
                f"All IO needs to have the same data type, got input 1: "
                f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
                f"{output.dtype}"
            )

        # Only int8 (quantized) and int32 are handled here.
        int_dtypes = [ts.DType.INT8, ts.DType.INT32]
        if lhs_dtype not in int_dtypes:
            raise TypeError(
                f'IO data type needs to be {int_dtypes}, got "{inputs[0].dtype}"'
            )

        if lhs_dtype == ts.DType.INT8:
            # Quantized operands: bring both to INT32 and remember the scale
            # needed to go back to INT8 afterwards.
            operands, rescale_factor = tqutils.insert_rescale_ops_to_int32(
                tosa_graph, inputs, node, self.tosa_specs
            )
        else:
            # INT32 operands are natively supported by TOSA.ADD.
            operands, rescale_factor = inputs, 1.0

        requantize = output.dtype == ts.DType.INT8
        if requantize:
            # The sum lands in an INT32 intermediate that is rescaled later.
            acc_shape = tutils.tosa_shape(output.shape, output.dim_order)
            accumulator = tosa_graph.addIntermediate(acc_shape, ts.DType.INT32)
        else:
            accumulator = output

        lhs, rhs = operands

        # The INT32 addition itself.
        tosa_graph.addOperator(
            ts.TosaOp.Op().ADD,
            [lhs.name, rhs.name],
            [accumulator.name],
            None,
        )

        if requantize:
            # Bring the INT32 sum back down to 8 bit.
            # pyre-ignore
            tqutils.insert_rescale_op_to_int8(
                tosa_graph, accumulator, rescale_factor, node, self.tosa_specs
            )  # type: ignore[possibly-undefined]


@register_node_visitor
class AddVisitor_FP(AddVisitor_INT):
    """Node visitor lowering ``aten.add.Tensor`` for the TOSA-1.0+FP spec.

    INT8/INT32 tensors are handed to ``AddVisitor_INT``; FP32 tensors are
    lowered to a single ADD operator.
    """

    # 'target' is inherited from the INT class

    tosa_specs = [TosaSpecification.create_from_string("TOSA-1.0+FP")]

    def __init__(self, *args):
        super().__init__(*args)

    def define_node(
        self,
        node: Node,
        tosa_graph: Any,
        inputs: List[TosaArg],
        output: TosaArg,
    ) -> None:
        """Emit the TOSA operators implementing ``add``.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer graph the operators are added to.
            inputs: The two operands of the addition.
            output: Destination tensor of the node.

        Raises:
            TypeError: If the operands and output do not share one dtype, or
                if that dtype is not INT8, INT32 or FP32.
        """
        import serializer.tosa_serializer as ts  # type: ignore

        # The 1.0 specification requires one common dtype for all IO.
        io_dtype = inputs[0].dtype
        if io_dtype != inputs[1].dtype or io_dtype != output.dtype:
            raise TypeError(
                f"All IO needs to have the same data type, got input 1: "
                f"{inputs[0].dtype}, input 2: {inputs[1].dtype} and output: "
                f"{output.dtype}"
            )

        if io_dtype in [ts.DType.INT8, ts.DType.INT32]:
            # Integer tensors are handled by the inherited implementation.
            super().define_node(node, tosa_graph, inputs, output)
            return

        # Anything that reaches this point must be FP32.
        if io_dtype != ts.DType.FP32:
            raise TypeError(
                f"Expected IO data type to be FP32, got {inputs[0].dtype}"
            )

        lhs, rhs = inputs

        # Floating-point addition.
        tosa_graph.addOperator(
            ts.TosaOp.Op().ADD,
            [lhs.name, rhs.name],
            [output.name],
            None,
        )
Loading
Loading