
Commit 9d84a42

Merge branch 'main' into jz/fix-prefill
2 parents 5c53856 + 25d8f15

13 files changed: +184 -14 lines

.ci/docker/ci_commit_pins/pytorch.txt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-2ea4b56ec872424e486c4fe2d55da061067a2ed3
+0a94bb432ed75cc2d950d81b2921363218a7e459

backends/arm/README.md

Lines changed: 22 additions & 0 deletions
@@ -122,6 +122,28 @@ Then you can run the tests with
 pytest -c /dev/null -v -n auto backends/arm/test --arm_quantize_io --arm_run_corstoneFVP
 ```
 
+### Code coverage
+
+To get code coverage:
+
+```
+coverage run --source=<SRC> --rcfile=backends/arm/test/.coveragerc -m pytest \
+    --config-file=/dev/null backends/arm/test/
+```
+
+All files in `SRC` and its child directories will be analysed for code coverage,
+unless explicitly excluded in the .coveragerc file. If using a venv, this might be
+under `env/lib/python<VERSION_NUMBER>/site-packages/executorch/`. To get the
+absolute path, run:
+
+```
+python -c "import executorch; print(executorch.__path__)"
+```
+
+This prints a list of paths where the source directory is located. Pick the
+one located in `env/lib`. If that does not work, try the others. Add
+`backends/arm` to the path in `--source` to only get code coverage for the Arm
+backend.
 
 ### A note on unit tests
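As a side note on the workflow this section describes, here is a minimal Python sketch (not part of the commit) that resolves the installed executorch path and prints a ready-made coverage command scoped to the Arm backend; it assumes executorch is importable from the active venv:

```
# Sketch: resolve the installed executorch path and assemble the coverage
# command described in the README section above.
import executorch

# executorch.__path__ may list several locations; prefer the site-packages one.
candidates = list(executorch.__path__)
src = next((p for p in candidates if "site-packages" in p), candidates[0])

print(
    f"coverage run --source={src}/backends/arm "
    "--rcfile=backends/arm/test/.coveragerc "
    "-m pytest --config-file=/dev/null backends/arm/test/"
)
```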

backends/arm/_passes/arm_pass_manager.py

Lines changed: 4 additions & 0 deletions
@@ -37,6 +37,9 @@
     QuantizeFullArgument,
     RetraceFoldedDtypesPass,
 )
+from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
+    FuseQuantizedActivationPass,
+)
 from executorch.backends.arm._passes.insert_table_ops import InsertTableOpsPass
 from executorch.backends.arm._passes.keep_dims_false_to_squeeze_pass import (
     KeepDimsFalseToSqueezePass,
@@ -73,6 +76,7 @@ def transform_to_backend_pipeline(
         self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
     ):
         """Apply passes before transforming program to backend"""
+        self.add_pass(FuseQuantizedActivationPass())
         self.add_pass(DecomposeLinearPass())
         self.add_pass(RemoveGetItemPass())
         self.add_pass(DecomposeLayerNormPass())
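Note that the new pass is registered first in the pipeline, so quantized activations are fused before linear layers are decomposed. A minimal sketch of running the pass on its own, assuming an `exported_program` produced by torch.export and the standard PassResult protocol:

```
# Sketch (assumed setup): apply the new pass directly to a graph module.
from executorch.backends.arm._passes.fuse_quantized_activation_pass import (
    FuseQuantizedActivationPass,
)

result = FuseQuantizedActivationPass()(exported_program.graph_module)
print("graph modified:", result.modified)  # PassResult carries the fused graph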
backends/arm/_passes/fuse_quantized_activation_pass.py

Lines changed: 60 additions & 0 deletions

@@ -0,0 +1,60 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+from executorch.backends.arm.tosa_quant_utils import q_op
+from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.pass_base import ExportPass, PassResult
+from torch.fx import Node
+
+
+class FuseQuantizedActivationPass(ExportPass):
+    def _is_fuseable_quantized_activation(self, node: Node):
+        """Fuse activations that have a 0 lower bound and are quantized with a zero point equal to qmin."""
+        is_fuseable = node.target == exir_ops.edge.aten.relu.default
+        if node.target == exir_ops.edge.aten.hardtanh.default:
+            min_val = node.args[1]
+            is_fuseable = min_val == 0
+
+        is_quantized = len(node.users) == 1 and next(iter(node.users)).target == q_op
+        if is_quantized:
+            quant_node = next(iter(node.users))
+            zp = quant_node.args[2]
+            qmin = quant_node.args[3]
+
+        return is_fuseable and is_quantized and zp == qmin
+
+    def _is_fuseable_input(self, node: Node):
+        return (
+            node.target
+            in (
+                exir_ops.edge.aten.convolution.default,
+                exir_ops.edge.aten.linear.default,
+            )
+            and len(node.users) == 1
+        )
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        for node in graph_module.graph.nodes:
+            if node.op != "call_function":
+                continue
+
+            if not self._is_fuseable_quantized_activation(node):
+                continue
+
+            input_node = node.args[0]
+            if not self._is_fuseable_input(input_node):
+                continue
+
+            node.replace_all_uses_with(input_node)
+            graph_module.graph.erase_node(node)
+            modified = True
+
+        if modified:
+            graph_module.recompile()
+            graph_module = super().call(graph_module).graph_module
+
+        return PassResult(graph_module, modified)
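The fusion is sound because a quantize node whose zero point equals qmin already clamps every negative pre-activation to qmin, so a ReLU (or hardtanh with a zero lower bound) in front of it is a no-op. A small numeric sketch with assumed int8-style parameters:

```
# Illustrative check (values assumed): quantize(relu(x)) == quantize(x)
# whenever zero_point == qmin, since clamping to qmin subsumes the ReLU.
import torch

scale, zero_point, qmin, qmax = 0.1, -128, -128, 127
x = torch.tensor([-1.0, -0.05, 0.0, 0.7])

def quantize(t):
    # Affine quantization: clamp(round(t / scale) + zero_point, qmin, qmax)
    return torch.clamp(torch.round(t / scale) + zero_point, qmin, qmax)

assert torch.equal(quantize(torch.relu(x)), quantize(x))
```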

backends/arm/quantizer/quantization_annotator.py

Lines changed: 65 additions & 1 deletion
@@ -89,6 +89,41 @@ def _annotate_output(node: Node, quant_property: _QuantProperty):
     _annotate_output_qspec(node, quant_property.qspec)
 
 
+def _match_pattern(
+    node: Node, pattern: List[List], filter_fn: Optional[Callable[[Node], bool]] = None
+) -> bool:
+    """
+    Check if there is a parent -> child chain through 'node' that matches the
+    chain provided in 'pattern'. If 'filter_fn' is provided, check that all the
+    nodes in the chain pass the filtering.
+
+    Each 'pattern' element is composed of a list of disjunctive node types.
+    """
+    assert len(pattern) == 2, "Only two-node patterns supported currently"
+
+    if node.target in pattern[0]:
+        assert len(node.users) != 0
+        parent = node
+        child = next(iter(node.users))
+    elif node.target in pattern[1]:
+        assert len(node.args) != 0
+        parent = node.args[0]
+        child = node
+    else:
+        return False
+
+    if len(parent.users) != 1:
+        return False
+
+    if parent.target not in pattern[0] or child.target not in pattern[1]:
+        return False
+
+    if filter_fn is not None:
+        return filter_fn(parent) and filter_fn(child)
+
+    return True
+
+
 _one_to_one = [
     torch.ops.aten.exp.default,
     torch.ops.aten.log.default,
@@ -164,7 +199,36 @@ def get_quant_properties(  # noqa: C901
     bias_qspec = quantization_config.get_bias_qspec()
 
     quant_properties = _OpQuantProperties()
-    if node.target in (
+
+    def any_or_hardtanh_min_zero(n: Node):
+        # Check that if the node is a hardtanh, its min_val is zero
+        return n.target != torch.ops.aten.hardtanh.default or n.args[1] == 0
+
+    if _match_pattern(
+        node,
+        [
+            [
+                torch.ops.aten.conv1d.default,
+                torch.ops.aten.conv2d.default,
+                torch.ops.aten.linear.default,
+            ],
+            [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
+        ],
+        any_or_hardtanh_min_zero,
+    ):
+        if node.target in (
+            torch.ops.aten.conv1d.default,
+            torch.ops.aten.conv2d.default,
+            torch.ops.aten.linear.default,
+        ):
+            quant_properties.quant_inputs = [
+                _QuantProperty(0, input_act_qspec),
+                _QuantProperty(1, weight_qspec, mark_annotated=True),
+                _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
+            ]
+        else:
+            quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
+    elif node.target in (
         torch.ops.aten.conv1d.default,
         torch.ops.aten.conv2d.default,
         torch.ops.aten.linear.default,
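For orientation, a sketch of how `_match_pattern` is driven by the branch above; `node` stands for any FX node the annotator visits, and the filter is the `any_or_hardtanh_min_zero` helper defined inside `get_quant_properties`:

```
# Sketch: the two-element pattern mirrors the conv/linear -> relu/hardtanh
# chain matched above. `node` is assumed to be a torch.fx.Node.
pattern = [
    [torch.ops.aten.conv2d.default, torch.ops.aten.linear.default],
    [torch.ops.aten.relu.default, torch.ops.aten.hardtanh.default],
]
if _match_pattern(node, pattern, any_or_hardtanh_min_zero):
    # True both when `node` is the conv/linear (parent) and when it is the
    # activation (child), so the pair is annotated consistently.
    pass
```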

backends/arm/test/.coveragerc

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+[run]
+omit =
+    *__init__.py*
+
+[report]
+skip_covered = true
+exclude_also =
+    raise NotImplementedError
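For reference (not part of the commit): `skip_covered = true` hides fully covered files from the report, and the `exclude_also` pattern keeps stubs like the hypothetical one below from being counted as missed lines:

```
# Hypothetical stub: its body matches `raise NotImplementedError`, so
# coverage.py excludes it from the report rather than flagging it as missed.
def op_not_yet_supported(op):
    raise NotImplementedError
```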

backends/arm/test/ops/test_conv_combos.py

Lines changed: 4 additions & 3 deletions
@@ -137,10 +137,11 @@ class ComboConvRelu6(torch.nn.Module):
     ]
 
     test_data = [
-        (20 * torch.randn(1, 3, 256, 256),),
-        (5 * torch.randn(1, 3, 256, 256),),
+        (2 * torch.randn(1, 3, 256, 256),),
+        (0.5 * torch.randn(1, 3, 256, 256),),
         (torch.randn(1, 3, 256, 256),),
-        (-5 * torch.randn(1, 3, 256, 256),),
+        (-0.5 * torch.randn(1, 3, 256, 256),),
+        (-2 * torch.randn(1, 3, 256, 256),),
     ]
 
     def __init__(self):

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw.glsl

Lines changed: 5 additions & 1 deletion
@@ -35,7 +35,11 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 * output at a single output location.
 */
 void main() {
-  const ivec3 pos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits.x, out_limits.y);
+  const uint div_by_x = gl_GlobalInvocationID.x / out_limits.x;
+  const ivec3 pos = ivec3(
+      gl_GlobalInvocationID.x % out_limits.x,
+      div_by_x % out_limits.y,
+      div_by_x / out_limits.y);
 
   if (any(greaterThanEqual(pos, out_limits))) {
     return;
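This change, and the two shader changes that follow, inline the flat-index-to-position mapping formerly provided by `idx_to_ipos_x_wise` (removed from indexing_utils.h below). A Python rendering of the arithmetic, with x varying fastest:

```
# Sketch of the inlined GLSL index math: map a flat invocation index to an
# (x, y, z) position, x varying fastest.
def idx_to_pos(idx: int, size_x: int, size_y: int) -> tuple:
    div_by_x = idx // size_x
    return (idx % size_x, div_by_x % size_y, div_by_x // size_y)

# Round trip: idx == x + size_x * (y + size_y * z)
assert idx_to_pos(7, size_x=3, size_y=2) == (1, 0, 1)  # 7 = 1 + 3 * (0 + 2 * 1)
```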

backends/vulkan/runtime/graph/ops/glsl/conv2d_dw_output_tile.glsl

Lines changed: 5 additions & 1 deletion
@@ -47,7 +47,11 @@ void main() {
   // since work size is calculated by x * ((y + B_Y - 1) / B_Y) * z
   const ivec2 out_limits_xy_scaled = (out_limits.xy + ivec2(BATCH_SIZE_X, BATCH_SIZE_Y) - 1) / ivec2(BATCH_SIZE_X, BATCH_SIZE_Y);
 
-  ivec3 pos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits_xy_scaled.x, out_limits_xy_scaled.y);
+  const uint div_by_x = gl_GlobalInvocationID.x / out_limits_xy_scaled.x;
+  ivec3 pos = ivec3(
+      gl_GlobalInvocationID.x % out_limits_xy_scaled.x,
+      div_by_x % out_limits_xy_scaled.y,
+      div_by_x / out_limits_xy_scaled.y);
 
   // scale pos.xy by batch sizes, because that's the top pixel to be processed
   pos.x *= BATCH_SIZE_X;

backends/vulkan/runtime/graph/ops/glsl/conv2d_pw.glsl

Lines changed: 5 additions & 1 deletion
@@ -44,7 +44,11 @@ void main() {
   const ivec2 out_limits_scaled = (out_limits.xy + TILE_SIZE - 1) / TILE_SIZE;
   const uint shared_mem_stride = gl_WorkGroupSize.x * gl_WorkGroupSize.y * gl_WorkGroupSize.z;
 
-  const ivec3 gpos = idx_to_ipos_x_wise(gl_GlobalInvocationID.x, out_limits_scaled.x, out_limits_scaled.y);
+  const uint div_by_x = gl_GlobalInvocationID.x / out_limits_scaled.x;
+  const ivec3 gpos = ivec3(
+      gl_GlobalInvocationID.x % out_limits_scaled.x,
+      div_by_x % out_limits_scaled.y,
+      div_by_x / out_limits_scaled.y);
 
   // Output position for TILE_SIZE = 2
   // +--------+--------+

backends/vulkan/runtime/graph/ops/glsl/indexing_utils.h

Lines changed: 0 additions & 5 deletions
@@ -223,11 +223,6 @@ ivec3 lpos_to_pos(const ivec3 lpos, const ivec4 axis_map) {
   return pos;
 }
 
-ivec3 idx_to_ipos_x_wise(uint idx, int size_x, int size_y) {
-  const uint div_by_x = idx / size_x;
-  return ivec3(idx % size_x, div_by_x % size_y, div_by_x / size_y);
-}
-
 #ifdef USING_BUFFER
 #define load_texel(buf, idx) buf[idx]
 #elif defined(USING_TEXTURE2D)

backends/xnnpack/test/tester/tester.py

Lines changed: 4 additions & 0 deletions
@@ -1,4 +1,5 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
+# Copyright 2025 Arm Limited and/or its affiliates.
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
@@ -679,6 +680,9 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
         for i in range(len(model_output)):
             model = model_output[i]
             ref = ref_output[i]
+            assert (
+                ref.shape == model.shape
+            ), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}"
             assert torch.allclose(
                 model,
                 ref,
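The added assertion makes shape mismatches fail with an explicit message before `torch.allclose` runs. An illustrative repro with made-up tensors:

```
# Sketch: mismatched output shapes now fail with a readable message.
import torch

model, ref = torch.zeros(1, 4), torch.zeros(1, 5)
try:
    assert (
        ref.shape == model.shape
    ), f"Output 0 shape {model.shape} does not match reference output shape {ref.shape}"
except AssertionError as err:
    print(err)
```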

install_requirements.py

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ def python_is_compatible():
 # NOTE: If a newly-fetched version of the executorch repo changes the value of
 # NIGHTLY_VERSION, you should re-run this script to install the necessary
 # package versions.
-NIGHTLY_VERSION = "dev20241218"
+NIGHTLY_VERSION = "dev20250104"
 
 # The pip repository that hosts nightly torch packages.
 TORCH_NIGHTLY_URL = "https://download.pytorch.org/whl/nightly/cpu"
