Commit 3082691

Merge branch 'main' into vulkan-unused-lib
2 parents: 1dcc9a5 + 319e88d, commit 3082691

23 files changed: +493, -149 lines

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -608,7 +608,7 @@ endif()
 # any backends.
 #
 add_library(executorch ${_executorch__srcs})
-target_link_libraries(executorch PUBLIC executorch_core)
+target_link_libraries(executorch PRIVATE executorch_core)
 target_include_directories(executorch PUBLIC ${_common_include_directories})
 target_compile_definitions(executorch PUBLIC C10_USING_CUSTOM_GENERATED_MACROS)
 target_compile_options(executorch PUBLIC ${_common_compile_options})

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -57,4 +57,5 @@
 from .size_adjust_conv2d_pass import SizeAdjustConv2DPass  # noqa
 from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass  # noqa
 from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass  # noqa
+from .replace_inf_values_pass import ReplaceInfValues  # noqa  # usort: skip
 from .arm_pass_manager import ArmPassManager  # noqa  # usort: skip

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 9 additions & 5 deletions
@@ -70,17 +70,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
             if quantized_input:
                 matmul_args = matmul_node.all_input_nodes
                 for node in matmul_args:
+                    # Find the dq-node connected to this mm/bmm arg
                     input_node = self._match_partition_to_node(
                         node, partition.input_nodes
                     )
-
-                    # Remove partition input dq-node
-                    input_node.replace_all_uses_with(input_node.all_input_nodes[0])
-                    graph_module.graph.erase_node(input_node)
                     input_node_qargs = QuantArgs.from_operator(
                         input_node.target, input_node.args
                     )
-
+                    # Insert new dq-node just before the mm/bmm with input_node's qparams
                     with graph_module.graph.inserting_before(matmul_node):
                         # Create new dq-node before matmul
                         dq_node = create_node(
@@ -90,6 +87,13 @@ def call(self, graph_module: GraphModule) -> PassResult:
                         dq_node.args = (node, *input_node_qargs)
                         matmul_node.replace_input_with(node, dq_node)

+                for partition_input in partition.input_nodes:
+                    # Remove partition input dq-node
+                    partition_input.replace_all_uses_with(
+                        partition_input.all_input_nodes[0]
+                    )
+                    graph_module.graph.erase_node(partition_input)
+
             partition_output = list(partition.output_nodes[0].users)[0]
             quantized_output = partition_output.target == q_op
             if quantized_output:
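
Note: the change above defers erasing the partition-input dq-nodes until every mm/bmm argument has been rewired, so a dq-node shared by more than one argument is not erased while it still has users. A minimal standalone sketch of the same torch.fx idiom (the toy module and the "clone" stand-in node are illustrative, not part of this commit):

import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):
    def forward(self, x):
        y = x.clone()          # stand-in for a partition-input dq-node
        return torch.bmm(y, y)

gm = symbolic_trace(Toy())
for node in list(gm.graph.nodes):
    if node.op == "call_method" and node.target == "clone":
        # Rewire every user to the node's input first, then erase the node.
        node.replace_all_uses_with(node.all_input_nodes[0])
        gm.graph.erase_node(node)
gm.recompile()
print(gm.code)  # forward now calls torch.bmm(x, x)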

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
@@ -49,6 +49,7 @@
     MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
     RemoveClonePass,
+    ReplaceInfValues,
     ReplaceScalarWithTensorArgPassTOSABI,
     ReplaceScalarWithTensorArgPassTOSAMI,
     RetraceFoldedDtypesPass,
@@ -216,4 +217,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeSoftmaxPass())

         self.add_pass(ConvertMinMaxPass())
+        self.add_pass(ReplaceInfValues())
         return self._transform(graph_module)
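
ReplaceInfValues is appended at the end of the annotation pipeline so that inf/-inf values are already gone when the quantizer runs. A rough sketch of how such an FX pass slots into a pass pipeline, using the generic torch.fx PassManager as an analogue for ArmPassManager (whose internals are not shown in this diff); ReplaceInfValues is the class added below, and the toy module is illustrative only:

import torch
from torch.fx import symbolic_trace
from torch.fx.passes.infra.pass_manager import PassManager

# Assumes executorch at this commit, where the import is re-exported from _passes.
from executorch.backends.arm._passes import ReplaceInfValues

class Mask(torch.nn.Module):
    def forward(self, scores):
        return scores + float("-inf")  # literal that the pass rewrites to -255

gm = symbolic_trace(Mask())
pm = PassManager(passes=[ReplaceInfValues()])
print(pm(gm).graph_module.code)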

backends/arm/_passes/convert_split_to_slice.py

Lines changed: 6 additions & 5 deletions
@@ -1,14 +1,15 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
-# All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

 # pyre-unsafe

 import torch.fx
-from executorch.backends.arm._passes.arm_pass_utils import create_node
-from executorch.backends.arm.tosa_mapping import extract_tensor_meta
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

@@ -34,7 +35,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             split_node = node
             input_node = split_node.all_input_nodes[0]
             output_nodes = split_node.users.copy()
-            _, shape, _ = extract_tensor_meta(input_node.meta)
+            shape = get_first_fake_tensor(input_node).shape
             rank = len(shape)
             split_lengths = split_node.args[1]
             dim = split_node.args[2] if len(split_node.args) > 2 else 0
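
The split-to-slice pass now reads the shape from the node's FakeTensor metadata (via get_first_fake_tensor) rather than from the TOSA tensor-meta tuple. A small sketch of where that shape lives on an exported graph (module and tensor sizes are illustrative):

import torch
from torch.export import export

class Split(torch.nn.Module):
    def forward(self, x):
        return torch.split(x, 2, dim=0)

ep = export(Split(), (torch.rand(6, 4),))
for node in ep.graph.nodes:
    if node.op == "placeholder":
        # node.meta["val"] is a FakeTensor; its .shape is the static shape the pass needs.
        print(node.name, node.meta["val"].shape)  # e.g. x torch.Size([6, 4])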

backends/arm/_passes/replace_inf_values_pass.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This pass is based on backends/qualcomm/_passes/replace_inf_values.py
+# with some modification to replaced inf values.
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ReplaceInfValues(ExportPass):
+    """
+    Due to limitation in Quantizer, we need to change inf/-inf to more quantizable values.
+    """
+
+    def __init__(self):
+        super(ReplaceInfValues, self).__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        for buf_name, tensor in graph_module.named_buffers():
+            if tensor.is_floating_point():
+                modified = True
+                # 255 here is mainly for attention_mask in Llama for reasonable quant scale
+                tensor[tensor == float("inf")] = 255
+                tensor[tensor == float("-inf")] = -255
+                setattr(graph_module, buf_name, tensor)
+
+        for node in graph_module.graph.nodes:
+            arg_list = list(node.args)
+            for index, arg in enumerate(arg_list):
+                if arg == float("-inf"):
+                    modified = True
+                    arg_list[index] = -255
+                elif arg == float("inf"):
+                    modified = True
+                    arg_list[index] = +255
+            node.args = tuple(arg_list)
+
+        if modified:
+            graph_module.recompile()
+        return PassResult(graph_module, modified)
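
A quick way to see the pass in action outside the Arm pipeline is to run its call() directly on a traced module. The toy attention-mask module below is illustrative only (not part of this commit); it exercises the buffer path, which is the attention_mask case the 255 comment refers to:

import torch
from torch.fx import symbolic_trace

# Re-exported by this commit via backends/arm/_passes/__init__.py
from executorch.backends.arm._passes import ReplaceInfValues

class Attention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        mask = torch.triu(torch.full((4, 4), float("-inf")), diagonal=1)
        self.register_buffer("mask", mask)

    def forward(self, scores):
        return scores + self.mask

gm = symbolic_trace(Attention())
result = ReplaceInfValues().call(gm)
print(result.modified)  # True
print(gm.mask)          # -inf entries in the buffer are now -255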

backends/arm/operator_support/slice_copy_support.py

Lines changed: 2 additions & 3 deletions
@@ -12,7 +12,6 @@
     SupportedTOSAOperatorCheck,
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
-from executorch.backends.arm.tosa_utils import getNodeArgs
 from executorch.exir.dialects._ops import ops as exir_ops

 logger = logging.getLogger(__name__)
@@ -33,8 +32,8 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) ->
         if tosa_spec not in self.tosa_specs:
             return False

-        inputs = getNodeArgs(node)
-        if len(inputs) == 5 and (step := inputs[4].number) != 1:
+        args = node.args
+        if len(args) == 5 and (step := args[4]) != 1:
             logging.warning(f"{node.target} with step size of {step} not supported.")
             return False
         return True
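
The support check now reads the step directly from node.args instead of going through getNodeArgs: for the slice op, the fifth positional argument, when present, is the step, and only a step of 1 is supported. A small sketch of what those args look like on an exported graph (module and shapes are illustrative):

import torch
from torch.export import export

class StridedSlice(torch.nn.Module):
    def forward(self, x):
        return x[:, 1:9:2]  # step of 2, which the check above rejects

ep = export(StridedSlice(), (torch.rand(2, 10),))
for node in ep.graph.nodes:
    if node.op == "call_function" and "slice" in str(node.target):
        print(node.target, node.args)  # last arg is the step when it is explicit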

backends/arm/operators/op_rescale.py

Lines changed: 7 additions & 7 deletions
@@ -13,7 +13,7 @@
     NodeVisitor,
     register_node_visitor,
 )
-from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
+from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.backends.arm.tosa_quant_utils import create_const_ops_for_rescale

 from executorch.backends.arm.tosa_specification import TosaSpecification
@@ -35,15 +35,15 @@ def define_node(
     ) -> None:
         import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore

-        input_dtype = inputs[0].dtype
+        input_dtype = node.all_input_nodes[0].meta["val"].dtype
         output_dtype = cast(torch.dtype, node.args[1])
         scale = cast(float, node.args[2])
         input_zp = cast(int, node.args[3])
         output_zp = cast(int, node.args[4])

-        if input_dtype != map_dtype(torch.int8) and input_zp != 0:
+        if input_dtype != torch.int8 and input_zp != 0:
             raise ValueError(
-                f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}"
+                f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
             )
         if output_dtype != torch.int8 and output_zp != 0:
             raise ValueError(
@@ -91,15 +91,15 @@ def define_node(
         import serializer.tosa_serializer as ts  # type: ignore
         from tosa.RoundingMode import RoundingMode  # type: ignore

-        input_dtype = inputs[0].dtype
+        input_dtype = node.all_input_nodes[0].meta["val"].dtype
         output_dtype = cast(torch.dtype, node.args[1])
         scale = cast(float, node.args[2])
         input_zp = cast(int, node.args[3])
         output_zp = cast(int, node.args[4])

-        if input_dtype != map_dtype(torch.int8) and input_zp != 0:
+        if input_dtype != torch.int8 and input_zp != 0:
             raise ValueError(
-                f"If input dtype is not int8, input_zp must be 0. Got input_dtype{ts.DTypeNames[input_dtype]}, {input_zp=}"
+                f"If input dtype is not int8, input_zp must be 0. Got input_dtype{input_dtype=}, {input_zp=}"
             )
         if output_dtype != torch.int8 and output_zp != 0:
             raise ValueError(
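
With the input dtype now taken from the node's FakeTensor metadata, the zero-point validation compares plain torch dtypes. A distilled restatement of that constraint as a standalone helper (the function name is illustrative, not part of the commit):

import torch

def check_rescale_zero_points(input_dtype, output_dtype, input_zp, output_zp):
    # Only int8 tensors may carry a non-zero zero-point on a rescale.
    if input_dtype != torch.int8 and input_zp != 0:
        raise ValueError(f"If input dtype is not int8, input_zp must be 0. Got {input_dtype=}, {input_zp=}")
    if output_dtype != torch.int8 and output_zp != 0:
        raise ValueError(f"If output dtype is not int8, output_zp must be 0. Got {output_dtype=}, {output_zp=}")

check_rescale_zero_points(torch.int32, torch.int8, 0, -128)    # OK
# check_rescale_zero_points(torch.int8, torch.int32, -128, 7)  # would raise ValueError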

backends/arm/process_node.py

Lines changed: 12 additions & 8 deletions
@@ -36,11 +36,11 @@ def process_call_function(
     tosa_spec: TosaSpecification,
 ):
     # Unpack arguments and convert
-    inputs = getNodeArgs(node)
+    inputs = getNodeArgs(node, tosa_spec)

     # Convert output (this node itself)
     try:
-        output = TosaArg(node)
+        output = TosaArg(node, tosa_spec)
     except ValueError as e:
         raise ValueError(
             f"Failed processing call_function: {node.name}. "
@@ -78,7 +78,7 @@ def process_inputs(
             f"Expected dim_order: {tuple(range(meta.dim()))}, but got: {meta.dim_order()} for node {node.name}"
         )
     try:
-        tosa_arg = TosaArg(node)
+        tosa_arg = TosaArg(node, tosa_spec)
     except ValueError as e:
         raise ValueError(
             f"Failed processing input placeholder: {node.name}. "
@@ -112,7 +112,7 @@ def process_inputs_to_parameters(
 ):
     """Serialize bias and non-quantized weights"""
     try:
-        tosa_arg = TosaArg(node)
+        tosa_arg = TosaArg(node, tosa_spec)
     except ValueError as e:
         raise ValueError(
             f"Failed processing parameter placeholder: {node.name}. "
@@ -137,10 +137,11 @@ def process_inputs_to_buffers(
     node: torch.fx.Node,
     tosa_graph: Any,
     edge_program: ExportedProgram,
+    tosa_spec: TosaSpecification,
 ):
     """Serialize quantized weights"""
     try:
-        tosa_arg = TosaArg(node)
+        tosa_arg = TosaArg(node, tosa_spec)
     except ValueError as e:
         raise ValueError(
             f"Failed processing buffer placeholder: {node.name}. "
@@ -165,9 +166,10 @@ def process_inputs_to_lifted_tensor_constants(
     node: torch.fx.Node,
     tosa_graph: Any,
     edge_program: ExportedProgram,
+    tosa_spec: TosaSpecification,
 ):
     try:
-        tosa_arg = TosaArg(node)
+        tosa_arg = TosaArg(node, tosa_spec)
     except ValueError as e:
         raise ValueError(
             f"Failed processing lifted tensor constant placeholder: {node.name}. "
@@ -196,9 +198,11 @@ def process_placeholder(
     elif is_param(edge_program, node):
         process_inputs_to_parameters(node, tosa_graph, edge_program, tosa_spec)
     elif is_buffer(edge_program, node):
-        process_inputs_to_buffers(node, tosa_graph, edge_program)
+        process_inputs_to_buffers(node, tosa_graph, edge_program, tosa_spec)
     elif is_lifted_tensor_constant(edge_program, node):
-        process_inputs_to_lifted_tensor_constants(node, tosa_graph, edge_program)
+        process_inputs_to_lifted_tensor_constants(
+            node, tosa_graph, edge_program, tosa_spec
+        )
     elif node.name in edge_program.graph_signature.inputs_to_lifted_custom_objs:
         raise NotImplementedError(
             "Placeholder is of type 'lifted custom object' which is not supported."

backends/arm/quantizer/quantization_annotator.py

Lines changed: 4 additions & 0 deletions
@@ -411,6 +411,9 @@ def any_or_hardtanh_min_zero(n: Node):
         shared_qspec = SharedQuantizationSpec(node.args[0])
         quant_properties.quant_inputs = [_QuantProperty(0, shared_qspec)]  # type: ignore[arg-type]
         quant_properties.quant_output = _QuantProperty(0, shared_qspec)  # type: ignore[arg-type]
+    elif node.target in [torch.ops.aten.scalar_tensor.default]:
+        quant_properties.quant_inputs = []
+        quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
     else:
         return None

@@ -458,5 +461,6 @@ def annotate_graph(  # type: ignore[return]
         if node.target in [
             torch.ops.aten.full_like.default,
             torch.ops.aten.full.default,
+            torch.ops.aten.scalar_tensor.default,
         ]:
             node.kwargs = {}
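
scalar_tensor creates a 0-d tensor from a Python scalar, so there are no tensor inputs to annotate; only its output gets a quantization spec, and, like full/full_like, its kwargs are cleared before annotation. A one-line illustration of the op itself:

import torch

t = torch.ops.aten.scalar_tensor.default(1.0)
print(t, t.shape, t.dtype)  # tensor(1.) torch.Size([]) torch.float32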

backends/arm/test/models/test_llama.py

Lines changed: 2 additions & 3 deletions
@@ -105,7 +105,6 @@ def test_llama_tosa_MI(self):
                 )
             )

-    @pytest.mark.xfail(reason="KeyError: scalar_tensor_1 (MLETORCH-907)")
     def test_llama_tosa_BI(self):
         llama_model, llama_inputs, llama_meta = self.prepare_model()

@@ -126,7 +125,7 @@ def test_llama_tosa_BI(self):
                 .to_executorch()
                 .run_method_and_compare_outputs(
                     inputs=llama_inputs,
-                    atol=4.3,
-                    rtol=1.1,  # TODO: Tolerance needs to be updated after MLETORCH-907
+                    atol=9.9,
+                    rtol=1.5,  # TODO: Tolerance needs to be updated after MLETORCH-907
                 )
             )

backends/arm/test/ops/test_bmm.py

Lines changed: 0 additions & 19 deletions
@@ -32,15 +32,6 @@ class BMM(torch.nn.Module):
         def forward(self, x, y):
             return torch.bmm(x, y)

-    class MatMul(torch.nn.Module):
-        test_data_generators = [
-            lambda: (torch.rand(2, 3, 5), torch.rand(2, 5, 2)),
-            lambda: (torch.rand(1, 2, 3, 5), torch.rand(1, 2, 5, 2)),
-        ]
-
-        def forward(self, x, y):
-            return torch.matmul(x, y)
-
     class BMMSingleInput(torch.nn.Module):
         test_data_generators = [
             lambda: (torch.rand(20, 3, 3),),
@@ -129,16 +120,6 @@ def test_bmm_single_input_tosa_MI(self, test_data_generator: Callable[[], Tuple]
         test_data = test_data_generator()
         self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data)

-    @parameterized.expand(MatMul.test_data_generators)
-    def test_matmul_tosa_MI(self, test_data_generator: Callable[[], Tuple]):
-        test_data = test_data_generator()
-        self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data)
-
-    @parameterized.expand(MatMul.test_data_generators)
-    def test_matmul_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
-        test_data = test_data_generator()
-        self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data)
-
     @parameterized.expand(BMM.test_data_generators)
     def test_bmm_tosa_BI(self, test_data_generator: Callable[[], Tuple]):
         test_data = test_data_generator()
