Skip to content

Commit 3374ce1

Browse files
committed
Lintrunner: Enable mypy testing on backends/arm
Migration from pyre to mypy in the lintrunner by enabling mypy for backends/arm. But, choosing to ignore the directory backends/arm/test. Adding ignores all over the place. These needs to be fixed properly in the future, but now we will start to catch new things trying to sneak in. Change-Id: Ie7f73d5688aaec3b32dca9f0cd042da94c06f487
1 parent debafbe commit 3374ce1

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+184
-172
lines changed

.lintrunner.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,7 @@ code = 'MYPY'
291291
include_patterns = [
292292
# TODO(https://github.com/pytorch/executorch/issues/7441): Gradually start enabling all folders.
293293
# 'backends/**/*.py',
294+
'backends/arm/**/*.py',
294295
'build/**/*.py',
295296
'codegen/**/*.py',
296297
# 'devtools/**/*.py',
@@ -312,6 +313,7 @@ exclude_patterns = [
312313
'**/third-party/**',
313314
'scripts/check_binary_dependencies.py',
314315
'profiler/test/test_profiler_e2e.py',
316+
'backends/arm/test/**',
315317
]
316318
command = [
317319
'python',

backends/arm/_passes/annotate_channels_last_dim_order_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def call(self, graph_module: torch.fx.GraphModule):
209209
# dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
210210
dim_order = self.HWCM_order
211211
else:
212-
dim_order = tuple(range(node_data.dim()))
212+
dim_order = tuple(range(node_data.dim())) # type: ignore[assignment]
213213
node.meta["tosa_dim_order"] = dim_order
214214
# Take care of cases when:
215215
# 4D (NHWC) -> >4D (NCH)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from executorch.backends.arm._passes.convert_split_to_slice import (
2222
ConvertSplitToSlicePass,
2323
)
24-
from executorch.backends.arm._passes.convert_squeezes_to_view import (
24+
from executorch.backends.arm._passes.convert_squeezes_to_view import ( # type: ignore[import-not-found]
2525
ConvertSqueezesToViewPass,
2626
)
2727
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
@@ -30,7 +30,9 @@
3030
)
3131
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
3232
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
33-
from executorch.backends.arm._passes.decompose_select import DecomposeSelectPass
33+
from executorch.backends.arm._passes.decompose_select import ( # type: ignore[import-not-found]
34+
DecomposeSelectPass,
35+
)
3436
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
3537
DecomposeSoftmaxesPass,
3638
)
@@ -40,18 +42,20 @@
4042
QuantizeFullArgument,
4143
RetraceFoldedDtypesPass,
4244
)
43-
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
45+
from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found]
4446
FuseQuantizedActivationPass,
4547
)
4648
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
4749
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
4850
KeepDimsFalseToSqueezePass,
4951
)
5052
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
51-
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
53+
from executorch.backends.arm._passes.meandim_to_averagepool_pass import ( # type: ignore[attr-defined]
5254
ConvertMeanDimToAveragePoolPass,
5355
)
54-
from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass
56+
from executorch.backends.arm._passes.mm_to_bmm_pass import ( # type: ignore[import-not-found]
57+
ConvertMmToBmmPass,
58+
)
5559
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
5660
from executorch.backends.arm._passes.scalars_to_attribute_pass import (
5761
ScalarsToAttributePass,
@@ -89,7 +93,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
8993

9094
self.add_pass(AnnotateDecomposedMatmulPass())
9195
self.add_pass(QuantizeFullArgument())
92-
self.add_pass(FoldAndAnnotateQParamsPass())
96+
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
9397
self.add_pass(RetraceFoldedDtypesPass())
9498
self.add_pass(InsertTableOpsPass(exported_program))
9599

@@ -125,7 +129,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul
125129

126130
self.add_pass(AnnotateDecomposedMatmulPass())
127131
self.add_pass(QuantizeFullArgument())
128-
self.add_pass(FoldAndAnnotateQParamsPass())
132+
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
129133
self.add_pass(RetraceFoldedDtypesPass())
130134
self.add_pass(InsertTableOpsPass(exported_program))
131135

backends/arm/_passes/arm_pass_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024 Arm Limited and/or its affiliates.
2+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
33
# All rights reserved.
44
#
55
# This source code is licensed under the BSD-style license found in the
@@ -58,9 +58,9 @@ def get_param_tensor(
5858
elif is_get_attr_node(node):
5959
# This is a hack to support both lifted and unlifted graph
6060
try:
61-
return getattr(node.graph.owning_module, node.target)
61+
return getattr(node.graph.owning_module, node.target) # type: ignore[arg-type]
6262
except AttributeError:
63-
return getattr(exp_prog.graph_module, node.target)
63+
return getattr(exp_prog.graph_module, node.target) # type: ignore[arg-type]
6464
raise RuntimeError(f"unsupported param type, {node.op}.")
6565

6666

@@ -156,7 +156,7 @@ def get_node_arg(args: list | dict, key: int | str | type, default_value=None):
156156
f"Out of bounds index {key} for getting value in args (of size {len(args)})"
157157
)
158158
elif isinstance(key, str):
159-
return args.get(key, default_value) # pyre-ignore[16]
159+
return args.get(key, default_value) # type: ignore[union-attr] # pyre-ignore[16]
160160
elif isclass(key):
161161
for arg in args:
162162
if isinstance(arg, key):

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ def fold_and_annotate_arg(
134134
node.meta["input_qparams"][i] = input_qparams
135135
for n in nodes_to_remove:
136136
assert n.target == dq_op
137-
n.replace_all_uses_with(n.args[0])
137+
n.replace_all_uses_with(n.args[0]) # type: ignore[arg-type]
138138
graph_module.graph.erase_node(n)
139139

140140
def call(self, graph_module: GraphModule) -> PassResult:

backends/arm/_passes/keep_dims_false_to_squeeze_pass.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -66,15 +66,15 @@ def call(self, graph_module: torch.fx.GraphModule):
6666
sum_node = cast(torch.fx.Node, node)
6767
keep_dim = get_node_arg(
6868
# pyre-ignore[6]
69-
sum_node.args,
69+
sum_node.args, # type: ignore[arg-type]
7070
keep_dim_index,
7171
False,
7272
)
7373

7474
if keep_dim:
7575
continue
7676

77-
dim_list = get_node_arg(sum_node.args, 1, [0]) # pyre-ignore[6]
77+
dim_list = get_node_arg(sum_node.args, 1, [0]) # type: ignore[arg-type] # pyre-ignore[6]
7878

7979
# Add keep_dim = True arg to sum node.
8080
set_node_arg(sum_node, 2, True)

backends/arm/_passes/scalars_to_attribute_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -54,7 +54,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
5454
if isinstance(arg, int) and not torch.is_floating_point(
5555
get_first_fake_tensor(n)
5656
):
57-
new_args.append(arg)
57+
new_args.append(arg) # type: ignore[arg-type]
5858
continue
5959

6060
prefix = "_tensor_constant_"

backends/arm/arm_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import os
1616
from typing import cast, final, List, Optional
1717

18-
import serializer.tosa_serializer as ts
18+
import serializer.tosa_serializer as ts # type: ignore
1919
from executorch.backends.arm.arm_vela import vela_compile
2020
from executorch.backends.arm.operators.node_visitor import get_node_visitors
2121

@@ -230,7 +230,7 @@ def preprocess( # noqa: C901
230230
# Converted output for this subgraph, serializer needs path early as it emits
231231
# const data directly. Path created and data written only in debug builds.
232232
tosa_graph = ts.TosaSerializer(artifact_path)
233-
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline(
233+
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore
234234
exported_program=edge_program
235235
)
236236

backends/arm/arm_partitioner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from typing import Callable, final, List, Optional, Tuple
1111

1212
import torch
13-
from executorch.backends.arm.arm_backend import (
13+
from executorch.backends.arm.arm_backend import ( # type: ignore[attr-defined]
1414
ArmBackend,
1515
) # usort: skip
1616
from executorch.backends.arm.operator_support.tosa_supported_operators import (

backends/arm/arm_vela.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -12,7 +12,7 @@
1212
from typing import List
1313

1414
import numpy as np
15-
from ethosu.vela import vela
15+
from ethosu.vela import vela # type: ignore
1616

1717

1818
# Pack either input or output tensor block, compose the related arrays into
@@ -96,13 +96,13 @@ def vela_compile(tosa_graph, args: List[str], shape_order=None):
9696
block_name = block_name + b"\x00" * (16 - len(block_name))
9797

9898
# We need the acual unpadded block lengths for hw setup
99-
block_length = struct.pack("<iiii", len(bin_blocks[key]), 0, 0, 0)
99+
block_length = struct.pack("<iiii", len(bin_blocks[key]), 0, 0, 0) # type: ignore[assignment]
100100

101101
# Pad block data to multiple of 16 bytes
102102
block_data = bin_blocks[key]
103103
block_data = block_data + b"\x00" * (15 - (len(block_data) - 1) % 16)
104104

105-
block = block_name + block_length + block_data
105+
block = block_name + block_length + block_data # type: ignore[operator]
106106
blocks = blocks + block
107107

108108
return blocks

backends/arm/operator_support/to_copy_support.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
9797
if input_dtype not in supported_dtypes:
9898
logger.info(
9999
f"Input dtype {input_val.dtype} is not supported in "
100-
f"{node.target.name()}." # pyre-ignore[16]
100+
f"{node.target.name()}." # type: ignore[union-attr] # pyre-ignore[16]
101101
)
102102
return False
103103

@@ -107,7 +107,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
107107
if output_val.dtype not in supported_dtypes[input_dtype]:
108108
logger.info(
109109
f"Output dtype {output_val.dtype} is not supported in "
110-
f"{node.target.name()} for input dtype {input_dtype}. " # pyre-ignore[16]
110+
f"{node.target.name()} for input dtype {input_dtype}. " # type: ignore[union-attr] # pyre-ignore[16]
111111
f"Supported output types: "
112112
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}"
113113
)
@@ -118,17 +118,17 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
118118
if node.kwargs["memory_format"] in (torch.preserve_format,):
119119
logger.info(
120120
f"Argument 'memory_format' is not supported for "
121-
f"{node.target.name()} right now." # pyre-ignore[16]
121+
f"{node.target.name()} right now." # type: ignore[union-attr] # pyre-ignore[16]
122122
)
123123
return False
124124

125125
# Check dim_order (to_dim_order_copy)
126126
if "dim_order" in node.kwargs:
127127
dim_order = node.kwargs["dim_order"]
128-
if dim_order != list(range(len(dim_order))):
128+
if dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
129129
logger.info(
130130
f"Argument {dim_order=} is not supported for "
131-
f"{node.target.name()} right now." # pyre-ignore[16]
131+
f"{node.target.name()} right now." # type: ignore[union-attr] # pyre-ignore[16]
132132
)
133133
return False
134134

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,5 +137,5 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
137137
def is_node_supported_custom(self, node: fx.Node) -> bool:
138138
tosa_checks = get_registered_tosa_support_checks(self.tosa_spec)
139139
if node.target in tosa_checks.keys():
140-
return tosa_checks[node.target].is_node_supported(node, self.tosa_spec)
140+
return tosa_checks[node.target].is_node_supported(node, self.tosa_spec) # type: ignore[index]
141141
return False

backends/arm/operators/node_visitor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -7,7 +7,7 @@
77

88
from typing import Dict, List
99

10-
import serializer.tosa_serializer as ts
10+
import serializer.tosa_serializer as ts # type: ignore
1111
import torch
1212
from executorch.backends.arm.tosa_mapping import TosaArg
1313
from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -44,7 +44,7 @@ def define_node(
4444

4545

4646
# container for all node visitors
47-
_node_visitor_dicts = {
47+
_node_visitor_dicts = { # type: ignore[var-annotated]
4848
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
4949
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
5050
}

backends/arm/operators/op_add.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -10,7 +10,7 @@
1010
import executorch.backends.arm.tosa_quant_utils as tqutils
1111
import executorch.backends.arm.tosa_utils as tutils
1212

13-
import serializer.tosa_serializer as ts
13+
import serializer.tosa_serializer as ts # type: ignore
1414
from executorch.backends.arm.operators.node_visitor import (
1515
NodeVisitor,
1616
register_node_visitor,
@@ -75,7 +75,7 @@ def define_node(
7575
if output.dtype == ts.DType.INT8:
7676
# Scale output back to 8 bit
7777
# pyre-ignore
78-
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node)
78+
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) # type: ignore[possibly-undefined]
7979

8080

8181
@register_node_visitor

backends/arm/operators/op_avg_pool2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
77
from typing import List
88

9-
import serializer.tosa_serializer as ts
9+
import serializer.tosa_serializer as ts # type: ignore
1010
import torch
1111

1212
# pyre-fixme[21]: ' Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`

backends/arm/operators/op_batch_norm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
1-
# Copyright 2023-2024 Arm Limited and/or its affiliates.
1+
# Copyright 2023-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

66
# pyre-unsafe
77
from typing import List
88

9-
import serializer.tosa_serializer as ts
9+
import serializer.tosa_serializer as ts # type: ignore
1010
import torch
1111
from executorch.backends.arm.operators.node_visitor import (
1212
NodeVisitor,

backends/arm/operators/op_bmm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
# All rights reserved.
33
#
44
# This source code is licensed under the BSD-style license found in the
@@ -7,7 +7,7 @@
77
# pyre-unsafe
88
from typing import List
99

10-
import serializer.tosa_serializer as ts
10+
import serializer.tosa_serializer as ts # type: ignore
1111
import torch
1212

1313
# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
@@ -75,14 +75,14 @@ def define_node(
7575
if output.dtype == ts.DType.INT8:
7676
output_qparams = get_output_qparams(node)[0] # pyre-ignore[16]
7777
final_output_scale = (
78-
input_qparams[0].scale * input_qparams[1].scale # pyre-ignore[61]
78+
input_qparams[0].scale * input_qparams[1].scale # type: ignore[possibly-undefined] # pyre-ignore[61]
7979
) / output_qparams.scale
8080

8181
build_rescale(
8282
tosa_fb=tosa_graph,
8383
scale=final_output_scale,
8484
# pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
85-
input_node=bmm_result,
85+
input_node=bmm_result, # type: ignore[possibly-undefined]
8686
output_name=output.name,
8787
output_type=ts.DType.INT8,
8888
output_shape=bmm_result.shape,

backends/arm/operators/op_cat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2024 Arm Limited and/or its affiliates.
1+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
@@ -7,7 +7,7 @@
77

88
from typing import List
99

10-
import serializer.tosa_serializer as ts
10+
import serializer.tosa_serializer as ts # type: ignore
1111
from executorch.backends.arm.operators.node_visitor import (
1212
NodeVisitor,
1313
register_node_visitor,

0 commit comments

Comments
 (0)