Skip to content

Arm backend: Annotate types #8723

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _build_generic_avgpool2d(
output: TosaArg,
input_zp: int,
output_zp: int,
accumulator_type,
accumulator_type: ts.DType,
) -> None:
input_tensor = inputs[0]

Expand Down
14 changes: 8 additions & 6 deletions backends/arm/operators/op_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output
from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape

from serializer.tosa_serializer import TosaOp


@register_node_visitor
class Conv2dVisitor(NodeVisitor):
Expand All @@ -36,8 +34,12 @@ def __init__(self, *args):
# `(input + 2 * pad - dilation * (weight - 1) - 1) / stride`
# must be an integer, but tosa currently strictly require this property.
# This function adjusts the pad value to meet the requirement.
def adjust_pad_if_needed(self, input, weight, stride, pad, dilation):
mod_remainder = (input + 2 * pad - dilation * (weight - 1) - 1) % stride
def adjust_pad_if_needed(
self, input_size: int, input_weight: int, stride: int, pad: int, dilation: int
) -> int:
mod_remainder = (
input_size + 2 * pad - dilation * (input_weight - 1) - 1
) % stride

# No need to adjust
if mod_remainder == 0:
Expand Down Expand Up @@ -143,11 +145,11 @@ def define_node(
build_reshape(
tosa_graph, weight.name, weight_post_shape, weight_reshaped.name
)
tosa_op = TosaOp.Op().DEPTHWISE_CONV2D
tosa_op = ts.TosaOp.Op().DEPTHWISE_CONV2D
weight_name = weight_reshaped.name
else:
"""Regular convolution case"""
tosa_op = TosaOp.Op().CONV2D
tosa_op = ts.TosaOp.Op().CONV2D
weight_name = weight.name

tosa_graph.addOperator(
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/tosa_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# the standardised TOSA representation.
#

from typing import Sequence
from typing import Any, Sequence

import serializer.tosa_serializer as ts # type: ignore
import torch
Expand Down Expand Up @@ -44,7 +44,7 @@
}


def map_dtype(data_type):
def map_dtype(data_type: torch.dtype) -> ts.DType:
if data_type in UNSUPPORTED_DTYPES:
raise ValueError(f"Unsupported type: {data_type}")
if data_type not in DTYPE_MAP:
Expand Down Expand Up @@ -88,7 +88,7 @@ def __process_list(self, argument):
def __process_number(self, argument: float | int):
self.number = argument

def __init__(self, argument) -> None:
def __init__(self, argument: Any) -> None:
self.name = None # type: ignore[assignment]
self.dtype = None
self.shape = None
Expand Down
77 changes: 43 additions & 34 deletions backends/arm/tosa_quant_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,18 @@
# Utiliy functions for TOSA quantized lowerings

import math
from typing import cast, NamedTuple
from typing import cast, List, NamedTuple, Tuple

import executorch.backends.arm.tosa_mapping

import serializer.tosa_serializer as ts # type: ignore
import torch.fx
import torch.fx.node
import tosa.Op as TosaOp # type: ignore
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.exir.dialects._ops import ops as exir_ops
from serializer.tosa_serializer import TosaSerializerTensor
from serializer.tosa_serializer import TosaSerializer, TosaSerializerTensor
from torch import Tensor
from torch.fx import Node


Expand Down Expand Up @@ -116,7 +120,7 @@ class QuantArgs(NamedTuple):
qmax: int
dtype: torch.dtype

def quantize_value(self, x):
def quantize_value(self, x: torch.Tensor | float) -> Tensor:
if not isinstance(x, torch.Tensor):
x = torch.Tensor([x])
return torch.clip(
Expand Down Expand Up @@ -144,15 +148,15 @@ def from_operator(cls, op, args):


# Check if scale32 mode is used for given output element type
def is_scale32(type):
def is_scale32(type: int) -> ts.DType:
return type == ts.DType.INT8


# TOSA uses the RESCALE operation to scale between values with differing precision.
# The RESCALE operator is defined using an integer multiply, add, and shift.
# This utility function is for calculating the multier and shift given a scale.
# Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
def compute_multiplier_and_shift(scale, scaleWidth=32):
def compute_multiplier_and_shift(scale: float, scaleWidth: int = 32) -> Tuple[int, int]:
if scaleWidth == 16:
offset = 15
elif scaleWidth == 32:
Expand All @@ -166,12 +170,12 @@ def compute_multiplier_and_shift(scale, scaleWidth=32):
shift = exponent

const_2_power_15_or_31 = 1 << offset
shifted_mantissa = round(mantissa * const_2_power_15_or_31)
shifted_mantissa = int(round(mantissa * const_2_power_15_or_31))

assert shifted_mantissa <= const_2_power_15_or_31

if shifted_mantissa == const_2_power_15_or_31:
shifted_mantissa = shifted_mantissa / 2
shifted_mantissa = int(shifted_mantissa / 2)
shift += 1

# TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
Expand All @@ -189,15 +193,15 @@ def compute_multiplier_and_shift(scale, scaleWidth=32):


def build_rescale(
tosa_fb,
scale,
input_node,
output_name,
output_type,
output_shape,
input_zp,
output_zp,
is_double_round=False,
tosa_fb: TosaSerializer,
scale: float,
input_node: TosaSerializerTensor,
output_name: str,
output_type: ts.DType,
output_shape: List[int],
input_zp: int,
output_zp: int,
is_double_round: bool = False,
):
scale_width = 32 if is_scale32(output_type) else 16
multiplier, shift = compute_multiplier_and_shift(scale, scale_width)
Expand All @@ -223,7 +227,12 @@ def build_rescale(


def build_rescale_to_int32(
tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=False
tosa_fb: TosaSerializer,
input_arg: executorch.backends.arm.tosa_mapping.TosaArg,
input_zp: int,
rescale_scale: float,
is_scale32: bool = True,
is_double_round: bool = False,
) -> TosaSerializerTensor:
multiplier, shift = compute_multiplier_and_shift(rescale_scale)
attr_rescale = ts.TosaSerializerAttribute()
Expand All @@ -238,10 +247,10 @@ def build_rescale_to_int32(
input_unsigned=False,
output_unsigned=False,
)
input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input.shape, ts.DType.INT32)
input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input_arg.shape, ts.DType.INT32)
tosa_fb.addOperator(
TosaOp.Op().RESCALE,
[input.name],
[input_arg.name],
[input_A_rescaled_to_int32.name],
attr_rescale,
)
Expand All @@ -250,13 +259,13 @@ def build_rescale_to_int32(


def build_rescale_from_int32(
tosa_fb,
input_name,
output_name,
output_zp,
rescale_scale,
is_scale32=True,
is_double_round=False,
tosa_fb: TosaSerializer,
input_name: str,
output_name: str,
output_zp: int,
rescale_scale: float,
is_scale32: bool = True,
is_double_round: bool = False,
) -> None:
multiplier, shift = compute_multiplier_and_shift(rescale_scale)
attr_rescale_output = ts.TosaSerializerAttribute()
Expand All @@ -283,14 +292,14 @@ def build_rescale_from_int32(


def build_rescale_conv_output(
tosa_fb,
op,
output_name,
output_type,
input_scale,
weight_scale,
output_scale,
output_zp,
tosa_fb: TosaSerializer,
op: TosaSerializerTensor,
output_name: str,
output_type: ts.DType,
input_scale: float,
weight_scale: float,
output_scale: float,
output_zp: int,
):
# TODO add check to verify if this is a Per-channel quantization.
post_conv2d_scale = (input_scale * weight_scale) / output_scale
Expand Down
2 changes: 1 addition & 1 deletion backends/arm/tosa_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def __init__(self, version: Version, extras: List[str]):
if len(extras) > 0:
raise ValueError(f"Unhandled extras found: {extras}")

def __repr__(self):
def __repr__(self) -> str:
extensions = ""
if self.level_8k:
extensions += "+8k"
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/tosa_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import logging
import os
from typing import Any
from typing import Any, Tuple

import serializer.tosa_serializer as ts # type: ignore
import torch
Expand Down Expand Up @@ -153,7 +153,7 @@ def get_resize_parameters(
output_size: torch.Tensor,
resize_mode: int,
align_corners: bool,
):
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Get the tosa.resize parameters based on the input and output size.

Args:
Expand Down
Loading