Skip to content

Lintrunner: Enable mypy testing on backends/arm #7776

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .lintrunner.toml
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ code = 'MYPY'
include_patterns = [
# TODO(https://github.com/pytorch/executorch/issues/7441): Gradually start enabling all folders.
# 'backends/**/*.py',
'backends/arm/**/*.py',
'build/**/*.py',
'codegen/**/*.py',
# 'devtools/**/*.py',
Expand All @@ -312,6 +313,7 @@ exclude_patterns = [
'**/third-party/**',
'scripts/check_binary_dependencies.py',
'profiler/test/test_profiler_e2e.py',
'backends/arm/test/**',
]
command = [
'python',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def call(self, graph_module: torch.fx.GraphModule):
# dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
dim_order = self.HWCM_order
else:
dim_order = tuple(range(node_data.dim()))
dim_order = tuple(range(node_data.dim())) # type: ignore[assignment]
node.meta["tosa_dim_order"] = dim_order
# Take care of cases when:
# 4D (NHWC) -> >4D (NCH)
Expand Down
18 changes: 11 additions & 7 deletions backends/arm/_passes/arm_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from executorch.backends.arm._passes.convert_split_to_slice import (
ConvertSplitToSlicePass,
)
from executorch.backends.arm._passes.convert_squeezes_to_view import (
from executorch.backends.arm._passes.convert_squeezes_to_view import ( # type: ignore[import-not-found]
ConvertSqueezesToViewPass,
)
from executorch.backends.arm._passes.decompose_div_pass import DecomposeDivPass
Expand All @@ -30,7 +30,9 @@
)
from executorch.backends.arm._passes.decompose_linear_pass import DecomposeLinearPass
from executorch.backends.arm._passes.decompose_meandim_pass import DecomposeMeanDimPass
from executorch.backends.arm._passes.decompose_select import DecomposeSelectPass
from executorch.backends.arm._passes.decompose_select import ( # type: ignore[import-not-found]
DecomposeSelectPass,
)
from executorch.backends.arm._passes.decompose_softmaxes_pass import (
DecomposeSoftmaxesPass,
)
Expand All @@ -40,18 +42,20 @@
QuantizeFullArgument,
RetraceFoldedDtypesPass,
)
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
from executorch.backends.arm._passes.fuse_quantized_activation_pass import ( # type: ignore[import-not-found]
FuseQuantizedActivationPass,
)
from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
KeepDimsFalseToSqueezePass,
)
from executorch.backends.arm._passes.match_arg_ranks_pass import MatchArgRanksPass
from executorch.backends.arm._passes.meandim_to_averagepool_pass import (
from executorch.backends.arm._passes.meandim_to_averagepool_pass import ( # type: ignore[attr-defined]
ConvertMeanDimToAveragePoolPass,
)
from executorch.backends.arm._passes.mm_to_bmm_pass import ConvertMmToBmmPass
from executorch.backends.arm._passes.mm_to_bmm_pass import ( # type: ignore[import-not-found]
ConvertMmToBmmPass,
)
from executorch.backends.arm._passes.remove_clone_pass import RemoveClonePass
from executorch.backends.arm._passes.scalars_to_attribute_pass import (
ScalarsToAttributePass,
Expand Down Expand Up @@ -89,7 +93,7 @@ def _tosa_080_BI_pipeline(self, exported_program: ExportedProgram) -> GraphModul

self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeFullArgument())
self.add_pass(FoldAndAnnotateQParamsPass())
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(InsertTableOpsPass(exported_program))

Expand Down Expand Up @@ -125,7 +129,7 @@ def _tosa_080_MI_pipeline(self, exported_program: ExportedProgram) -> GraphModul

self.add_pass(AnnotateDecomposedMatmulPass())
self.add_pass(QuantizeFullArgument())
self.add_pass(FoldAndAnnotateQParamsPass())
self.add_pass(FoldAndAnnotateQParamsPass()) # type: ignore[call-arg]
self.add_pass(RetraceFoldedDtypesPass())
self.add_pass(InsertTableOpsPass(exported_program))

Expand Down
8 changes: 4 additions & 4 deletions backends/arm/_passes/arm_pass_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -58,9 +58,9 @@ def get_param_tensor(
elif is_get_attr_node(node):
# This is a hack to support both lifted and unlifted graph
try:
return getattr(node.graph.owning_module, node.target)
return getattr(node.graph.owning_module, node.target) # type: ignore[arg-type]
except AttributeError:
return getattr(exp_prog.graph_module, node.target)
return getattr(exp_prog.graph_module, node.target) # type: ignore[arg-type]
raise RuntimeError(f"unsupported param type, {node.op}.")


Expand Down Expand Up @@ -156,7 +156,7 @@ def get_node_arg(args: list | dict, key: int | str | type, default_value=None):
f"Out of bounds index {key} for getting value in args (of size {len(args)})"
)
elif isinstance(key, str):
return args.get(key, default_value) # pyre-ignore[16]
return args.get(key, default_value) # type: ignore[union-attr] # pyre-ignore[16]
elif isclass(key):
for arg in args:
if isinstance(arg, key):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def fold_and_annotate_arg(
node.meta["input_qparams"][i] = input_qparams
for n in nodes_to_remove:
assert n.target == dq_op
n.replace_all_uses_with(n.args[0])
n.replace_all_uses_with(n.args[0]) # type: ignore[arg-type]
graph_module.graph.erase_node(n)

def call(self, graph_module: GraphModule) -> PassResult:
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/_passes/keep_dims_false_to_squeeze_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -66,15 +66,15 @@ def call(self, graph_module: torch.fx.GraphModule):
sum_node = cast(torch.fx.Node, node)
keep_dim = get_node_arg(
# pyre-ignore[6]
sum_node.args,
sum_node.args, # type: ignore[arg-type]
keep_dim_index,
False,
)

if keep_dim:
continue

dim_list = get_node_arg(sum_node.args, 1, [0]) # pyre-ignore[6]
dim_list = get_node_arg(sum_node.args, 1, [0]) # type: ignore[arg-type] # pyre-ignore[6]

# Add keep_dim = True arg to sum node.
set_node_arg(sum_node, 2, True)
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/_passes/scalars_to_attribute_pass.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand Down Expand Up @@ -54,7 +54,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
if isinstance(arg, int) and not torch.is_floating_point(
get_first_fake_tensor(n)
):
new_args.append(arg)
new_args.append(arg) # type: ignore[arg-type]
continue

prefix = "_tensor_constant_"
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/arm_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import os
from typing import cast, final, List, Optional

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.arm_vela import vela_compile
from executorch.backends.arm.operators.node_visitor import get_node_visitors

Expand Down Expand Up @@ -230,7 +230,7 @@ def preprocess( # noqa: C901
# Converted output for this subgraph, serializer needs path early as it emits
# const data directly. Path created and data written only in debug builds.
tosa_graph = ts.TosaSerializer(artifact_path)
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline(
graph_module = ArmPassManager(tosa_spec).transform_to_backend_pipeline( # type: ignore
exported_program=edge_program
)

Expand Down
2 changes: 1 addition & 1 deletion backends/arm/arm_partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Callable, final, List, Optional, Tuple

import torch
from executorch.backends.arm.arm_backend import (
from executorch.backends.arm.arm_backend import ( # type: ignore[attr-defined]
ArmBackend,
) # usort: skip
from executorch.backends.arm.operator_support.tosa_supported_operators import (
Expand Down
8 changes: 4 additions & 4 deletions backends/arm/arm_vela.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -12,7 +12,7 @@
from typing import List

import numpy as np
from ethosu.vela import vela
from ethosu.vela import vela # type: ignore


# Pack either input or output tensor block, compose the related arrays into
Expand Down Expand Up @@ -96,13 +96,13 @@ def vela_compile(tosa_graph, args: List[str], shape_order=None):
block_name = block_name + b"\x00" * (16 - len(block_name))

# We need the acual unpadded block lengths for hw setup
block_length = struct.pack("<iiii", len(bin_blocks[key]), 0, 0, 0)
block_length = struct.pack("<iiii", len(bin_blocks[key]), 0, 0, 0) # type: ignore[assignment]

# Pad block data to multiple of 16 bytes
block_data = bin_blocks[key]
block_data = block_data + b"\x00" * (15 - (len(block_data) - 1) % 16)

block = block_name + block_length + block_data
block = block_name + block_length + block_data # type: ignore[operator]
blocks = blocks + block

return blocks
10 changes: 5 additions & 5 deletions backends/arm/operator_support/to_copy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
if input_dtype not in supported_dtypes:
logger.info(
f"Input dtype {input_val.dtype} is not supported in "
f"{node.target.name()}." # pyre-ignore[16]
f"{node.target.name()}." # type: ignore[union-attr] # pyre-ignore[16]
)
return False

Expand All @@ -107,7 +107,7 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
if output_val.dtype not in supported_dtypes[input_dtype]:
logger.info(
f"Output dtype {output_val.dtype} is not supported in "
f"{node.target.name()} for input dtype {input_dtype}. " # pyre-ignore[16]
f"{node.target.name()} for input dtype {input_dtype}. " # type: ignore[union-attr] # pyre-ignore[16]
f"Supported output types: "
f"{''.join(str(t) for t in supported_dtypes[input_dtype])}"
)
Expand All @@ -118,18 +118,18 @@ def is_node_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool
if node.kwargs["memory_format"] in (torch.preserve_format,):
logger.info(
f"Argument 'memory_format' is not supported for "
f"{node.target.name()} right now." # pyre-ignore[16]
f"{node.target.name()} right now." # type: ignore[union-attr] # pyre-ignore[16]
)
return False

# Check dim_order (to_dim_order_copy)
if "dim_order" in node.kwargs:
dim_order = node.kwargs["dim_order"]
# pyre-ignore[6]
if dim_order != list(range(len(dim_order))):
if dim_order != list(range(len(dim_order))): # type: ignore[arg-type]
logger.info(
f"Argument {dim_order=} is not supported for "
f"{node.target.name()} right now." # pyre-ignore[16]
f"{node.target.name()} right now." # type: ignore[union-attr] # pyre-ignore[16]
)
return False

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,5 +137,5 @@ def is_node_supported(self, submodules, node: fx.Node) -> bool:
def is_node_supported_custom(self, node: fx.Node) -> bool:
tosa_checks = get_registered_tosa_support_checks(self.tosa_spec)
if node.target in tosa_checks.keys():
return tosa_checks[node.target].is_node_supported(node, self.tosa_spec)
return tosa_checks[node.target].is_node_supported(node, self.tosa_spec) # type: ignore[index]
return False
6 changes: 3 additions & 3 deletions backends/arm/operators/node_visitor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -7,7 +7,7 @@

from typing import Dict, List

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.tosa_mapping import TosaArg
from executorch.backends.arm.tosa_specification import TosaSpecification
Expand Down Expand Up @@ -44,7 +44,7 @@ def define_node(


# container for all node visitors
_node_visitor_dicts = {
_node_visitor_dicts = { # type: ignore[var-annotated]
TosaSpecification.create_from_string("TOSA-0.80+BI"): {},
TosaSpecification.create_from_string("TOSA-0.80+MI"): {},
}
Expand Down
6 changes: 3 additions & 3 deletions backends/arm/operators/op_add.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -10,7 +10,7 @@
import executorch.backends.arm.tosa_quant_utils as tqutils
import executorch.backends.arm.tosa_utils as tutils

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
Expand Down Expand Up @@ -75,7 +75,7 @@ def define_node(
if output.dtype == ts.DType.INT8:
# Scale output back to 8 bit
# pyre-ignore
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node)
tqutils.insert_rescale_op_to_int8(tosa_graph, add_output, scale_back, node) # type: ignore[possibly-undefined]


@register_node_visitor
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operators/op_avg_pool2d.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
import torch

# pyre-fixme[21]: ' Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operators/op_batch_norm.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright 2023-2024 Arm Limited and/or its affiliates.
# Copyright 2023-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
import torch
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
Expand Down
8 changes: 4 additions & 4 deletions backends/arm/operators/op_bmm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
Expand All @@ -7,7 +7,7 @@
# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
import torch

# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
Expand Down Expand Up @@ -75,14 +75,14 @@ def define_node(
if output.dtype == ts.DType.INT8:
output_qparams = get_output_qparams(node)[0] # pyre-ignore[16]
final_output_scale = (
input_qparams[0].scale * input_qparams[1].scale # pyre-ignore[61]
input_qparams[0].scale * input_qparams[1].scale # type: ignore[possibly-undefined] # pyre-ignore[61]
) / output_qparams.scale

build_rescale(
tosa_fb=tosa_graph,
scale=final_output_scale,
# pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
input_node=bmm_result,
input_node=bmm_result, # type: ignore[possibly-undefined]
output_name=output.name,
output_type=ts.DType.INT8,
output_shape=bmm_result.shape,
Expand Down
4 changes: 2 additions & 2 deletions backends/arm/operators/op_cat.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2024 Arm Limited and/or its affiliates.
# Copyright 2024-2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
Expand All @@ -7,7 +7,7 @@

from typing import List

import serializer.tosa_serializer as ts
import serializer.tosa_serializer as ts # type: ignore
from executorch.backends.arm.operators.node_visitor import (
NodeVisitor,
register_node_visitor,
Expand Down
Loading