
Commit e3a4197

Update on "[ET-VK][ez] Test specific sizes of linear sizes in generated operator tests"

## Context

Recent changes related to checking SPIR-V capability support at runtime have made it possible to test the 8-bit quantized linear compute shader on Android devices. Previously, the test would be skipped automatically because the operator potentially uses 8-bit data types. To make the generated tests more useful, the 8-bit linear test case now tests real sizes of linear layers found in a sample model instead of arbitrary shapes.

Differential Revision: [D68192068](https://our.internmc.facebook.com/intern/diff/D68192068/)

[ghstack-poisoned]

2 parents 9ec7e2d + b540e0e commit e3a4197
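
To make the intent concrete, here is a hypothetical sketch of driving the 8-bit linear case with fixed layer sizes. The helper names and shapes are illustrative only; they are not the actual ET-VK test-generator API nor the sample model's real dimensions.

```python
# Hypothetical sketch -- names and shapes are illustrative, not the actual
# ET-VK test-generator API or the sample model's real linear dimensions.
SAMPLE_MODEL_LINEAR_SIZES = [
    # (flattened batch/sequence, in_features, out_features)
    (1, 2048, 2048),
    (1, 2048, 8192),
    (128, 8192, 2048),
]

def make_q8_linear_test_cases():
    """Build one test case per real linear layer size, so the quantized
    linear shader is exercised on realistic workloads rather than a few
    arbitrary shapes."""
    return [
        {"input_shape": (m, k), "weight_shape": (n, k)}
        for (m, k, n) in SAMPLE_MODEL_LINEAR_SIZES
    ]
```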

38 files changed (+745, −205 lines)

.ci/docker/requirements-ci.txt

Lines changed: 4 additions & 7 deletions

```diff
@@ -1,17 +1,14 @@
 mpmath==1.3.0
-numpy==1.21.3; python_version == '3.10'
-numpy==1.23.2; python_version == '3.11'
-numpy; python_version >= '3.12'
+numpy==2.0.0; python_version >= '3.10'
 PyYAML==6.0.1
 ruamel.yaml==0.17.32
 sympy==1.12
 timm==0.6.13
 tomli==2.0.1
 torchsr==1.0.4
-transformers==4.38.0
+transformers==4.47.1
 zstd==1.5.5.1
-pandas==2.0.3; python_version == '3.10'
-pandas; python_version >= '3.11'
+pandas==2.2.2; python_version >= '3.10'
 pytest==7.2.0
 pytest-cov==4.1.0
 expecttest==0.1.6
@@ -24,7 +21,7 @@ sphinx-gallery==0.14.0
 breathe==4.34.0
 exhale==0.2.3
 docutils==0.16
-matplotlib==3.7.2
+matplotlib==3.9.4
 # PyTorch Theme
 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
 myst-parser==0.18.1
```

.github/workflows/android-perf.yml

Lines changed: 2 additions & 2 deletions

```diff
@@ -260,7 +260,7 @@ jobs:
           --output_name="${OUT_ET_MODEL_NAME}.pte"
         ls -lh "${OUT_ET_MODEL_NAME}.pte"
       elif [[ ${{ matrix.config }} == "llama3_qnn_htp" ]]; then
-        export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728
+        export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
         export LD_LIBRARY_PATH=$QNN_SDK_ROOT/lib/x86_64-linux-clang/
         export PYTHONPATH=$(pwd)/..
@@ -347,7 +347,7 @@ jobs:
       PYTHON_EXECUTABLE=python bash .ci/scripts/build-qnn-sdk.sh

       export ANDROID_ABIS="arm64-v8a"
-      PYTHON_EXECUTABLE=python EXECUTORCH_BUILD_QNN=ON QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728 bash build/build_android_llm_demo.sh ${ARTIFACTS_DIR_NAME}
+      PYTHON_EXECUTABLE=python EXECUTORCH_BUILD_QNN=ON QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029 bash build/build_android_llm_demo.sh ${ARTIFACTS_DIR_NAME}

   # Let's see how expensive this job is, we might want to tone it down by running it periodically
   benchmark-on-device:
```

backends/apple/coreml/README.md

Lines changed: 3 additions & 3 deletions

```diff
@@ -93,14 +93,14 @@ class Model(torch.nn.Module):
 source_model = Model()
 example_inputs = (torch.randn((1, 3, 256, 256)), )

-pre_autograd_aten_dialect = export_for_training(model, example_inputs).module()
+pre_autograd_aten_dialect = export_for_training(source_model, example_inputs).module()

 quantization_config = LinearQuantizerConfig.from_dict(
     {
         "global_config": {
             "quantization_scheme": QuantizationScheme.symmetric,
-            "activation_dtype": torch.uint8,
-            "weight_dtype": torch.int8,
+            "activation_dtype": torch.quint8,
+            "weight_dtype": torch.qint8,
             "weight_per_channel": True,
         }
     }
```
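
For context, the corrected snippet feeds into the standard PT2E quantization flow. A minimal sketch of the continuation, assuming the `CoreMLQuantizer` import and calibration pattern used elsewhere in this README:

```python
# Sketch of the continuation, assuming the README's surrounding flow.
from executorch.backends.apple.coreml.quantizer import CoreMLQuantizer
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Annotate the captured graph with the config above, calibrate on the
# example inputs, then fold observers into quantize/dequantize ops.
quantizer = CoreMLQuantizer(quantization_config)
prepared_graph = prepare_pt2e(pre_autograd_aten_dialect, quantizer)
prepared_graph(*example_inputs)  # calibration pass
converted_graph = convert_pt2e(prepared_graph)
```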

backends/apple/coreml/scripts/install_requirements.sh

Lines changed: 1 addition & 5 deletions

```diff
@@ -47,11 +47,7 @@ cmake --build "$COREMLTOOLS_DIR_PATH/build" --parallel

 echo "${green}ExecuTorch: Installing coremltools."
 pip install "$COREMLTOOLS_DIR_PATH"
-# CoreMLTools have started supporting numpy 2.0,
-# but ExecuTorch example model test env is still using older transformers,
-# so for now we will need to downgrade numpy to 1.x
-# TODO: Remove this numpy downgrade once later transformers starts to be used
-pip install numpy==1.26.4
+
 STATUS=$?
 if [ $STATUS -ne 0 ]; then
     echo "${red}ExecuTorch: Failed to install coremltools."
```

backends/arm/_passes/size_adjust_conv2d_pass.py

Lines changed: 51 additions & 52 deletions

```diff
@@ -1,73 +1,74 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

 # pyre-unsafe

-from typing import cast, Optional
+from typing import cast

 import torch.fx
+from executorch.backends.arm._passes.arm_pass_utils import create_node
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
-from torch._ops import OpOverload


 def conv_remainder(input_length, pad, dilation, weight, stride):
     """
-    Returns the size
+    Returns the remainder of input_length; given the padding, dilation, stride,
+    and kernel size.
     """
     return (input_length + 2 * pad - dilation * (weight - 1) - 1) % stride


-def insert_q_dq_pair(
-    graph: torch.fx.Graph,
-    anchor: torch.fx.Node,
-    q_params: tuple,
-):
-    with graph.inserting_after(anchor):
-        q = create_node(
-            graph=graph,
-            op_target=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
-            args=(),  # We add the argument last
-        )
-        q.meta = anchor.meta
-
-    with graph.inserting_after(q):
-        dq = create_node(
-            graph=graph,
-            op_target=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
-            args=(q,) + q_params,
-        )
-        dq.meta = q.meta
-
-    anchor.replace_all_uses_with(dq)
-    # We add this last so the replace all uses above does not replace the quantized
-    # node's first use
-    q.args = (anchor,) + q_params
-    return dq
-
-
-def create_node(
-    graph: torch.fx.Graph,
-    op_target: OpOverload,
-    args: tuple = (),
-    kwargs: Optional[dict] = None,
-):
-    return graph.create_node(
-        "call_function",
-        op_target,
-        args=args,
-        kwargs=kwargs or {},
-    )
-
-
 class SizeAdjustConv2DPass(ExportPass):
     """
-    Adjust the convolution input size to match perfectly with the
-    weight size, padding, stride and dilation parameters.
-    This is done by inserting a slice op to remove the uneven end of the input.
+    Adjust the convolution input size to match the kernel size, padding, stride,
+    and dilation parameters. Pytorch allows the input and kernel shape to not
+    "match", in which case the remaining rows/columns are truncated. However,
+    matching the size is a requirement in the TOSA specification. In case the
+    input and kernel shape do not match, the following is done to meet the
+    specification:
+
+    1) The padding is truncated (done in the node visitor)
+    2) (if neccessary) The input is truncated (done in this pass)."
+
+    A simple example would be a 2x2 kernel (no padding, stride=2) and a 5x5
+    input:
+
+    ┌───┬───┬───┬───┬───┐   ┌───┬───┬───┬───┬───┐   ┌───┬───┬───┬───┬───┐
+    │ X │ X │   │   │   │   │   │   │ X │ X │   │   │   │   │   │   │ - │
+    ├───┼───┼───┼───┼───┤   ├───┼───┼───┼───┼───┤   ├───┼───┼───┼───┼───┤
+    │ X │ X │   │   │   │   │   │   │ X │ X │   │   │   │   │   │   │ - │
+    ├───┼───┼───┼───┼───┤   ├───┼───┼───┼───┼───┤   ├───┼───┼───┼───┼───┤
+    │   │   │   │   │   │ ->│   │   │   │   │   │ ->│ X │ X │   │   │   │ ->
+    ├───┼───┼───┼───┼───┤   ├───┼───┼───┼───┼───┤   ├───┼───┼───┼───┼───┤
+    │   │   │   │   │   │   │   │   │   │   │   │   │ X │ X │   │   │   │
+    ├───┼───┼───┼───┼───┤   ├───┼───┼───┼───┼───┤   ├───┼───┼───┼───┼───┤
+    │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │
+    └───┴───┴───┴───┴───┘   └───┴───┴───┴───┴───┘   └───┴───┴───┴───┴───┘
+         First pass              second pass             third pass
+
+    ┌───┬───┬───┬───┬───┐   ┌───┬───┬───┬───┬───┐
+    │   │   │   │   │   │   │   │   │   │   │ - │
+    ├───┼───┼───┼───┼───┤   ├───┼───┼───┼───┼───┤
+    │   │   │   │   │   │   │   │   │   │   │ - │
+    ├───┼───┼───┼───┼───┤   ├───┼───┼───┼───┼───┤
+    │   │   │ X │ X │   │ ->│   │   │   │   │ - │
+    ├───┼───┼───┼───┼───┤   ├───┼───┼───┼───┼───┤
+    │   │   │ X │ X │   │   │   │   │   │   │ - │
+    ├───┼───┼───┼───┼───┤   ├───┼───┼───┼───┼───┤
+    │   │   │   │   │   │   │ - │ - │ - │ - │ - │
+    └───┴───┴───┴───┴───┘   └───┴───┴───┴───┴───┘
+         Fourth pass             Unvisited cells
+
+    Cells that are never visited are marked with `-` and are never considered
+    when the kernel traverses over the input, hence they can be removed.
+
+    To match the shape of the kernel (and all parameters) with the input, a
+    slice op is inserted to remove the remaining edges (rows and columns) of the
+    input.
     """

     conv2d_op = exir_ops.edge.aten.convolution.default
@@ -109,9 +110,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             with graph_module.graph.inserting_before(node):
                 last_node = cast(torch.fx.Node, input_node)
                 for args in slice_args:
-                    slice_node = graph.create_node(
-                        "call_function", self.slice_op, (last_node,) + args
-                    )
+                    slice_node = create_node(graph, self.slice_op, (last_node,) + args)
                     last_node = slice_node
                 conv_node.replace_input_with(cast(torch.fx.Node, input_node), last_node)
                 modified_graph = True
```
modified_graph = True

backends/arm/process_node.py

Lines changed: 8 additions & 58 deletions

```diff
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -11,11 +11,6 @@
 import serializer.tosa_serializer as ts
 import torch
 import torch.fx
-
-# pyre-fixme[21]: 'Could not find a module corresponding to import `executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass`.'
-from executorch.backends.arm._passes.fold_qdq_with_annotated_qparams_pass import (
-    get_input_qparams,
-)
 from executorch.backends.arm.operators.node_visitor import NodeVisitor
 from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
 from executorch.backends.arm.tosa_quant_utils import (
@@ -24,11 +19,7 @@
     is_node_quantized,
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
-from executorch.backends.arm.tosa_utils import (
-    getNodeArgs,
-    is_bias_node_for_quantized_conv,
-    tosa_shape,
-)
+from executorch.backends.arm.tosa_utils import getNodeArgs, tosa_shape
 from torch.export.exported_program import ExportedProgram


@@ -99,41 +90,6 @@ def process_inputs(
     tosa_graph.addInputTensor(tensor)


-def process_quantized_bias(
-    node: torch.fx.Node,
-    tosa_graph: ts.TosaSerializer,
-    parameter_values,
-):
-    """
-    Serialize bias node that needs to be quantized.
-    """
-    consumer_node = list(node.users)[0]
-    (
-        input_node,
-        weight_node,
-        _,
-    ) = consumer_node.all_input_nodes
-
-    input_qargs = get_input_qparams(  # pyre-ignore[16]: Module `executorch.backends.arm` has no attribute `_passes`.
-        consumer_node
-    )
-
-    input_node_scale = input_qargs[0].scale
-    weight_node_scale = input_qargs[1].scale
-    bias_values_quantized = (
-        (parameter_values / (input_node_scale * weight_node_scale))
-        .round()
-        .astype(np.int32)
-    )
-
-    tosa_graph.addConst(
-        bias_values_quantized.shape,
-        ts.DType.INT32,
-        bias_values_quantized,
-        name=node.name,
-    )
-
-
 def process_inputs_to_parameters(
     node: torch.fx.Node,
     tosa_graph: ts.TosaSerializer,
@@ -148,20 +104,14 @@ def process_inputs_to_parameters(
     assert isinstance(parameter_data, torch.Tensor), "Expect Attr to be tensor"
     parameter_values = parameter_data.detach().numpy()

-    if is_bias_node_for_quantized_conv(node):
-        # BI bias
-        assert tosa_spec.support_integer(), f"{tosa_spec} doesnt't support integer"
-        process_quantized_bias(node, tosa_graph, parameter_values)
-    else:
-        # MI weights or bias
-        if inputs[0].dtype == torch.float32:
-            assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"
+    if inputs[0].dtype == torch.float32:
+        assert tosa_spec.support_float(), f"{tosa_spec} doesn't support float"

-        parameter_values = np.transpose(parameter_values, inputs[0].dim_order)
+    parameter_values = np.transpose(parameter_values, inputs[0].dim_order)

-        tosa_graph.addConst(
-            parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
-        )
+    tosa_graph.addConst(
+        parameter_values.shape, inputs[0].dtype, parameter_values, name=node.name
+    )


 def process_inputs_to_buffers(
```
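
For reference, the deleted `process_quantized_bias` helper rescaled a float bias into int32 using the product of the input and weight scales. A sketch of that arithmetic in isolation (the scales and bias values here are made up):

```python
import numpy as np

# Made-up qparams and bias, purely to illustrate the deleted helper's math.
input_scale, weight_scale = 0.02, 0.005
bias = np.array([0.25, -0.1], dtype=np.float32)

# Same computation as the removed process_quantized_bias:
# bias_q = round(bias / (input_scale * weight_scale)), stored as int32.
bias_q = np.round(bias / (input_scale * weight_scale)).astype(np.int32)
print(bias_q)  # [ 2500 -1000]
```

The same product-of-scales convention now lives in the quantizer (see `quantization_config.py` below), so the TOSA serializer no longer needs to special-case quantized biases.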

backends/arm/quantizer/quantization_annotator.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,4 +1,4 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -196,7 +196,7 @@ def get_quant_properties(  # noqa: C901
     input_act_qspec = quantization_config.get_input_act_qspec()
     weight_qspec = quantization_config.get_weight_qspec()
     output_act_qspec = quantization_config.get_output_act_qspec()
-    bias_qspec = quantization_config.get_bias_qspec()
+    bias_qspec = quantization_config.get_bias_qspec(node)

     quant_properties = _OpQuantProperties()

```
backends/arm/quantizer/quantization_config.py

Lines changed: 38 additions & 2 deletions

```diff
@@ -1,5 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024 Arm Limited and/or its affiliates.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -9,8 +9,10 @@
 from dataclasses import dataclass

 import torch
+from torch.ao.quantization import ObserverOrFakeQuantize

 from torch.ao.quantization.quantizer import (
+    DerivedQuantizationSpec,
     FixedQParamsQuantizationSpec,
     QuantizationSpec,
 )
@@ -53,8 +55,42 @@ def get_weight_qspec(self) -> QuantizationSpec | None:
         ], f"Unsupported quantization_spec {self.weight} for weight"
         return self.weight

-    def get_bias_qspec(self) -> QuantizationSpec | None:
+    def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None:
         """Returns QuantizationSpec 'bias' after asserting that bias.dtype is torch.float."""
+
+        def _derive_qparams_fn(
+            obs_or_fqs: list[ObserverOrFakeQuantize],
+        ) -> tuple[torch.Tensor, torch.Tensor]:
+            assert (
+                len(obs_or_fqs) == 2
+            ), "Expecting two obs/fqs, one for activation and one for weight, got: {}".format(
+                len(obs_or_fqs)
+            )
+            act_obs_or_fq = obs_or_fqs[0]
+            weight_obs_or_fq = obs_or_fqs[1]
+            act_scale, act_zp = act_obs_or_fq.calculate_qparams()
+            weight_scale, weight_zp = weight_obs_or_fq.calculate_qparams()
+            return torch.tensor([act_scale * weight_scale]).to(
+                torch.float32
+            ), torch.tensor([0]).to(torch.int32)
+
+        if node.target in [
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.linear.default,
+        ]:
+            input_act = node.args[0]
+            weight = node.args[1]
+            quantization_spec = DerivedQuantizationSpec(
+                derived_from=[(input_act, node), (weight, node)],
+                derive_qparams_fn=_derive_qparams_fn,
+                dtype=torch.int32,
+                quant_min=torch.iinfo(torch.int32).min,
+                quant_max=torch.iinfo(torch.int32).max - 1,
+                qscheme=torch.per_tensor_symmetric,
+            )
+            return quantization_spec
+
         if self.bias is None:
             return None
         assert (
```
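
A minimal sketch of what the new `_derive_qparams_fn` receives and returns, using toy `MinMaxObserver` instances in place of the real activation and weight observers (the data is random and purely illustrative):

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

# Toy observers standing in for those prepare_pt2e would attach.
act_obs = MinMaxObserver(dtype=torch.quint8)
weight_obs = MinMaxObserver(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric)
act_obs(torch.randn(16))     # record activation min/max
weight_obs(torch.randn(16))  # record weight min/max

act_scale, _ = act_obs.calculate_qparams()
weight_scale, _ = weight_obs.calculate_qparams()

# The bias scale is derived as the product of the two; the zero point is
# fixed at 0, matching the int32 per-tensor-symmetric spec returned above
# for conv1d/conv2d/linear nodes.
bias_scale = torch.tensor([act_scale * weight_scale]).to(torch.float32)
bias_zp = torch.tensor([0]).to(torch.int32)
```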
