
Commit 6ceb26c

Arm backend: Annotate types
Annotate types for files in arm/operators/* and arm/tosa_*. This was done with the tool monkeytype in combination with pytest.

Change-Id: I295e29e7adf7d7621ba9154c1cdc445d09b8c926
1 parent 13b5605 commit 6ceb26c
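The commit message only names the tooling, so here is a minimal sketch of how runtime types are typically collected with MonkeyType while running pytest. It is not taken from this commit; the test path and module name below are illustrative assumptions.

import monkeytype
import pytest

# Run the test suite under MonkeyType's tracer so that the argument and return
# types observed at runtime are recorded in its trace store.
with monkeytype.trace():
    pytest.main(["backends/arm/test"])  # hypothetical test directory

# The recorded traces can then be turned into annotations from the CLI, e.g.:
#   monkeytype apply executorch.backends.arm.tosa_quant_utils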

6 files changed, +58 -47 lines changed

backends/arm/operators/op_avg_pool2d.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def _build_generic_avgpool2d(
         output: TosaArg,
         input_zp: int,
         output_zp: int,
-        accumulator_type,
+        accumulator_type: ts.DType,
     ) -> None:
         input_tensor = inputs[0]

backends/arm/operators/op_conv2d.py

Lines changed: 8 additions & 6 deletions
@@ -22,8 +22,6 @@
 from executorch.backends.arm.tosa_quant_utils import build_rescale_conv_output
 from executorch.backends.arm.tosa_utils import build_reshape, tosa_shape
 
-from serializer.tosa_serializer import TosaOp
-
 
 @register_node_visitor
 class Conv2dVisitor(NodeVisitor):
@@ -36,8 +34,12 @@ def __init__(self, *args):
     # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride`
     # must be an integer, but tosa currently strictly require this property.
     # This function adjusts the pad value to meet the requirement.
-    def adjust_pad_if_needed(self, input, weight, stride, pad, dilation):
-        mod_remainder = (input + 2 * pad - dilation * (weight - 1) - 1) % stride
+    def adjust_pad_if_needed(
+        self, input_size: int, input_weight: int, stride: int, pad: int, dilation: int
+    ) -> int:
+        mod_remainder = (
+            input_size + 2 * pad - dilation * (input_weight - 1) - 1
+        ) % stride
 
         # No need to adjust
         if mod_remainder == 0:
@@ -143,11 +145,11 @@ def define_node(
             build_reshape(
                 tosa_graph, weight.name, weight_post_shape, weight_reshaped.name
             )
-            tosa_op = TosaOp.Op().DEPTHWISE_CONV2D
+            tosa_op = ts.TosaOp.Op().DEPTHWISE_CONV2D
             weight_name = weight_reshaped.name
         else:
             """Regular convolution case"""
-            tosa_op = TosaOp.Op().CONV2D
+            tosa_op = ts.TosaOp.Op().CONV2D
             weight_name = weight.name
 
         tosa_graph.addOperator(
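The comments above adjust_pad_if_needed state that (input + 2 * pad - dilation * (weight - 1) - 1) must be divisible by the stride for TOSA. Below is a small standalone sketch of that divisibility check with made-up numbers; it is not the ExecuTorch implementation of the pad adjustment itself.

def output_extent_is_integral(
    input_size: int, weight: int, stride: int, pad: int, dilation: int
) -> bool:
    # True when the numerator of the convolution output-size formula divides
    # evenly by the stride, which is the property TOSA requires.
    return (input_size + 2 * pad - dilation * (weight - 1) - 1) % stride == 0

# A 7-wide input with a 3-wide kernel, stride 3, dilation 1 fails the check
# with pad 0 (remainder 1) and passes with pad 1 (remainder 0), which is the
# kind of case the visitor's pad adjustment exists to repair.
assert not output_extent_is_integral(7, 3, 3, pad=0, dilation=1)
assert output_extent_is_integral(7, 3, 3, pad=1, dilation=1)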

backends/arm/tosa_mapping.py

Lines changed: 3 additions & 3 deletions
@@ -11,7 +11,7 @@
 # the standardised TOSA representation.
 #
 
-from typing import Sequence
+from typing import Any, Sequence
 
 import serializer.tosa_serializer as ts # type: ignore
 import torch
@@ -44,7 +44,7 @@
 }
 
 
-def map_dtype(data_type):
+def map_dtype(data_type: torch.dtype) -> ts.DType:
     if data_type in UNSUPPORTED_DTYPES:
         raise ValueError(f"Unsupported type: {data_type}")
     if data_type not in DTYPE_MAP:
@@ -88,7 +88,7 @@ def __process_list(self, argument):
     def __process_number(self, argument: float | int):
         self.number = argument
 
-    def __init__(self, argument) -> None:
+    def __init__(self, argument: Any) -> None:
        self.name = None  # type: ignore[assignment]
        self.dtype = None
        self.shape = None
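With the new signature, map_dtype is a plain torch.dtype to ts.DType translation that raises ValueError for unsupported types. A hypothetical usage sketch follows; the concrete contents of DTYPE_MAP and UNSUPPORTED_DTYPES are not part of this diff, so the dtypes used below are assumptions.

import torch

from executorch.backends.arm.tosa_mapping import map_dtype

tosa_dtype = map_dtype(torch.float32)  # assumed to resolve to the matching ts.DType entry (e.g. FP32)

try:
    map_dtype(torch.float64)  # float64 may be one of the unsupported dtypes
except ValueError as unsupported:
    print(unsupported)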

backends/arm/tosa_quant_utils.py

Lines changed: 43 additions & 34 deletions
@@ -8,14 +8,18 @@
 # Utiliy functions for TOSA quantized lowerings
 
 import math
-from typing import cast, NamedTuple
+from typing import cast, List, NamedTuple, Tuple
+
+import executorch.backends.arm.tosa_mapping
 
 import serializer.tosa_serializer as ts # type: ignore
 import torch.fx
+import torch.fx.node
 import tosa.Op as TosaOp # type: ignore
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.exir.dialects._ops import ops as exir_ops
-from serializer.tosa_serializer import TosaSerializerTensor
+from serializer.tosa_serializer import TosaSerializer, TosaSerializerTensor
+from torch import Tensor
 from torch.fx import Node
 
 
@@ -116,7 +120,7 @@ class QuantArgs(NamedTuple):
     qmax: int
     dtype: torch.dtype
 
-    def quantize_value(self, x):
+    def quantize_value(self, x: torch.Tensor | float) -> Tensor:
         if not isinstance(x, torch.Tensor):
             x = torch.Tensor([x])
         return torch.clip(
@@ -144,15 +148,15 @@ def from_operator(cls, op, args):
 
 
 # Check if scale32 mode is used for given output element type
-def is_scale32(type):
+def is_scale32(type: int) -> ts.DType:
     return type == ts.DType.INT8
 
 
 # TOSA uses the RESCALE operation to scale between values with differing precision.
 # The RESCALE operator is defined using an integer multiply, add, and shift.
 # This utility function is for calculating the multier and shift given a scale.
 # Ref: https://www.mlplatform.org/tosa/tosa_spec.html#_precision_scaling
-def compute_multiplier_and_shift(scale, scaleWidth=32):
+def compute_multiplier_and_shift(scale: float, scaleWidth: int = 32) -> Tuple[int, int]:
     if scaleWidth == 16:
         offset = 15
     elif scaleWidth == 32:
@@ -166,12 +170,12 @@ def compute_multiplier_and_shift(scale, scaleWidth=32):
     shift = exponent
 
     const_2_power_15_or_31 = 1 << offset
-    shifted_mantissa = round(mantissa * const_2_power_15_or_31)
+    shifted_mantissa = int(round(mantissa * const_2_power_15_or_31))
 
     assert shifted_mantissa <= const_2_power_15_or_31
 
     if shifted_mantissa == const_2_power_15_or_31:
-        shifted_mantissa = shifted_mantissa / 2
+        shifted_mantissa = int(shifted_mantissa / 2)
         shift += 1
 
     # TOSA expects right shift to be positive, and embed (1 << offset) into right shift bits.
@@ -189,15 +193,15 @@ def compute_multiplier_and_shift(scale, scaleWidth=32):
 
 
 def build_rescale(
-    tosa_fb,
-    scale,
-    input_node,
-    output_name,
-    output_type,
-    output_shape,
-    input_zp,
-    output_zp,
-    is_double_round=False,
+    tosa_fb: TosaSerializer,
+    scale: float,
+    input_node: TosaSerializerTensor,
+    output_name: str,
+    output_type: ts.DType,
+    output_shape: List[int],
+    input_zp: int,
+    output_zp: int,
+    is_double_round: bool = False,
 ):
     scale_width = 32 if is_scale32(output_type) else 16
     multiplier, shift = compute_multiplier_and_shift(scale, scale_width)
@@ -223,7 +227,12 @@ def build_rescale_to_int32(
 
 
 def build_rescale_to_int32(
-    tosa_fb, input, input_zp, rescale_scale, is_scale32=True, is_double_round=False
+    tosa_fb: TosaSerializer,
+    input_arg: executorch.backends.arm.tosa_mapping.TosaArg,
+    input_zp: int,
+    rescale_scale: float,
+    is_scale32: bool = True,
+    is_double_round: bool = False,
 ) -> TosaSerializerTensor:
     multiplier, shift = compute_multiplier_and_shift(rescale_scale)
     attr_rescale = ts.TosaSerializerAttribute()
@@ -238,10 +247,10 @@ def build_rescale_to_int32(
         input_unsigned=False,
         output_unsigned=False,
     )
-    input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input.shape, ts.DType.INT32)
+    input_A_rescaled_to_int32 = tosa_fb.addIntermediate(input_arg.shape, ts.DType.INT32)
     tosa_fb.addOperator(
         TosaOp.Op().RESCALE,
-        [input.name],
+        [input_arg.name],
         [input_A_rescaled_to_int32.name],
         attr_rescale,
     )
@@ -250,13 +259,13 @@ def build_rescale_to_int32(
 
 
 def build_rescale_from_int32(
-    tosa_fb,
-    input_name,
-    output_name,
-    output_zp,
-    rescale_scale,
-    is_scale32=True,
-    is_double_round=False,
+    tosa_fb: TosaSerializer,
+    input_name: str,
+    output_name: str,
+    output_zp: int,
+    rescale_scale: float,
+    is_scale32: bool = True,
+    is_double_round: bool = False,
 ) -> None:
     multiplier, shift = compute_multiplier_and_shift(rescale_scale)
     attr_rescale_output = ts.TosaSerializerAttribute()
@@ -283,14 +292,14 @@ def build_rescale_from_int32(
 
 
 def build_rescale_conv_output(
-    tosa_fb,
-    op,
-    output_name,
-    output_type,
-    input_scale,
-    weight_scale,
-    output_scale,
-    output_zp,
+    tosa_fb: TosaSerializer,
+    op: TosaSerializerTensor,
+    output_name: str,
+    output_type: ts.DType,
+    input_scale: float,
+    weight_scale: float,
+    output_scale: float,
+    output_zp: int,
 ):
     # TODO add check to verify if this is a Per-channel quantization.
     post_conv2d_scale = (input_scale * weight_scale) / output_scale
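The comments above compute_multiplier_and_shift describe turning a floating-point scale into the integer multiplier and shift used by the TOSA RESCALE operator, so that multiplier / 2**shift approximates the scale. The sketch below is an independent illustration of that decomposition, not a copy of the ExecuTorch helper; in particular, shift here is the final right-shift amount, whereas the variable of the same name in the diff starts out as the frexp exponent.

import math
from typing import Tuple

def scale_to_multiplier_and_shift(scale: float, offset: int = 31) -> Tuple[int, int]:
    # scale = mantissa * 2**exponent with 0.5 <= mantissa < 1 (for positive scales).
    mantissa, exponent = math.frexp(scale)
    multiplier = int(round(mantissa * (1 << offset)))
    shift = offset - exponent  # multiplier / 2**shift ~= mantissa * 2**exponent = scale
    if multiplier == (1 << offset):  # rounding pushed the mantissa up to 1.0
        multiplier //= 2
        shift -= 1
    return multiplier, shift

multiplier, shift = scale_to_multiplier_and_shift(0.0157)
assert abs(multiplier / (1 << shift) - 0.0157) < 1e-6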

backends/arm/tosa_specification.py

Lines changed: 1 addition & 1 deletion
@@ -112,7 +112,7 @@ def __init__(self, version: Version, extras: List[str]):
         if len(extras) > 0:
             raise ValueError(f"Unhandled extras found: {extras}")
 
-    def __repr__(self):
+    def __repr__(self) -> str:
        extensions = ""
        if self.level_8k:
            extensions += "+8k"

backends/arm/tosa_utils.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 
 import logging
 import os
-from typing import Any
+from typing import Any, Tuple
 
 import serializer.tosa_serializer as ts # type: ignore
 import torch
@@ -153,7 +153,7 @@ def get_resize_parameters(
     output_size: torch.Tensor,
     resize_mode: int,
     align_corners: bool,
-):
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     """Get the tosa.resize parameters based on the input and output size.
 
     Args:
