Skip to content

Commit abf1134

Browse files
committed
Update on "[Executorch][llm] Enable leveraging ring kv cache via module swap"
This allows some of the attention modules to use a sliding-window KV cache, which will help enable models like gemma3. Differential Revision: [D73891426](https://our.internmc.facebook.com/intern/diff/D73891426/) [ghstack-poisoned]
2 parents 5ed2284 + 1013001 commit abf1134

File tree

195 files changed

+9325
-8028
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

195 files changed

+9325
-8028
lines changed

.ci/docker/ci_commit_pins/buck2.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2024-12-16
1+
2025-05-06

.github/workflows/apple.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ on:
55
branches:
66
- main
77
- release/*
8+
tags:
9+
- ciflow/trunk/*
810
pull_request:
911
paths:
1012
- .ci/scripts/setup-ios.sh

.github/workflows/build-presets.yml

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,20 @@ on:
1111
concurrency:
1212
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
1313
cancel-in-progress: true
14+
15+
jobs:
16+
apple:
17+
uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
18+
strategy:
19+
matrix:
20+
preset: [macos-arm64]
21+
with:
22+
job-name: build
23+
runner: macos-latest-xlarge
24+
python-version: 3.12
25+
submodules: recursive
26+
script: |
27+
set -eux
28+
${CONDA_RUN} ./install_requirements.sh > /dev/null
29+
${CONDA_RUN} cmake --preset ${{ matrix.preset }}
30+
${CONDA_RUN} cmake --build cmake-out --parallel

.github/workflows/pull.yml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,7 @@ jobs:
434434
output=$(ls -la cmake-out/test/size_test)
435435
arr=($output)
436436
size=${arr[4]}
437-
# threshold=48120 on devserver with gcc11.4
438-
# todo(lfq): update once binary size is below 50kb.
439-
threshold="47552"
437+
threshold="47560"
440438
if [[ "$size" -le "$threshold" ]]; then
441439
echo "Success $size <= $threshold"
442440
else

CMakeLists.txt

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,19 @@
4444

4545
cmake_minimum_required(VERSION 3.24)
4646
project(executorch)
47+
48+
# MARK: - Start EXECUTORCH_H12025_BUILD_MIGRATION --------------------------------------------------
49+
50+
include(${PROJECT_SOURCE_DIR}/tools/cmake/common/preset.cmake)
51+
52+
load_build_preset()
53+
include(${PROJECT_SOURCE_DIR}/tools/cmake/preset/default.cmake)
54+
55+
# Print all the configs that were called with announce_configured_options.
56+
print_configured_options()
57+
58+
# MARK: - End EXECUTORCH_H12025_BUILD_MIGRATION ----------------------------------------------------
59+
4760
include(tools/cmake/Utils.cmake)
4861
include(CMakeDependentOption)
4962

@@ -96,9 +109,6 @@ set(EXECUTORCH_PAL_DEFAULT
96109
"Which PAL default implementation to use: one of {posix, minimal}"
97110
)
98111

99-
option(EXECUTORCH_ENABLE_LOGGING "Build with ET_LOG_ENABLED"
100-
${_default_release_disabled_options}
101-
)
102112
if(NOT EXECUTORCH_ENABLE_LOGGING)
103113
# Avoid pulling in the logging strings, which can be large. Note that this
104114
# will set the compiler flag for all targets in this directory, and for all
@@ -170,8 +180,6 @@ option(EXECUTORCH_BUILD_ARM_BAREMETAL
170180
"Build the Arm Baremetal flow for Cortex-M and Ethos-U" OFF
171181
)
172182

173-
option(EXECUTORCH_BUILD_COREML "Build the Core ML backend" OFF)
174-
175183
option(EXECUTORCH_BUILD_KERNELS_CUSTOM "Build the custom kernels" OFF)
176184

177185
option(EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT "Build the custom ops lib for AOT"

CMakePresets.json

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
{
2+
"version": 10,
3+
"cmakeMinimumRequired": {
4+
"major": 3,
5+
"minor": 31,
6+
"patch": 0
7+
},
8+
"$comment": "On-device AI across mobile, embedded and edge for PyTorch.",
9+
"configurePresets": [
10+
{
11+
"name": "common",
12+
"hidden": true,
13+
"binaryDir": "${sourceDir}/cmake-out",
14+
"generator": "Unix Makefiles"
15+
},
16+
{
17+
"name": "macos-arm64",
18+
"inherits": ["common"],
19+
"generator": "Xcode",
20+
"cacheVariables": {
21+
"CMAKE_TOOLCHAIN_FILE": "${sourceDir}/third-party/ios-cmake/ios.toolchain.cmake",
22+
"EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/macos-arm64.cmake",
23+
"PLATFORM": "MAC_ARM64",
24+
"DEPLOYMENT_TARGET": "10.15"
25+
},
26+
"condition": {
27+
"lhs": "${hostSystemName}",
28+
"type": "equals",
29+
"rhs": "Darwin"
30+
}
31+
}
32+
]
33+
}

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .convert_squeezes_to_view import ConvertSqueezesToViewPass # noqa
2020
from .convert_to_clamp import ConvertToClampPass # noqa
2121
from .decompose_batchnorm_pass import DecomposeBatchNormPass # noqa
22+
from .decompose_cosine_similarity_pass import DecomposeCosineSimilarityPass # noqa
2223
from .decompose_div_pass import DecomposeDivPass # noqa
2324
from .decompose_gelu_pass import DecomposeGeluPass # noqa
2425
from .decompose_layernorm_pass import DecomposeLayerNormPass # noqa

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
# Copyright 2024-2025 Arm Limited and/or its affiliates.
2-
# All rights reserved.
32
#
43
# This source code is licensed under the BSD-style license found in the
54
# LICENSE file in the root directory of this source tree.
65

76
# pyre-unsafe
87

98
import itertools
10-
9+
import operator
1110
from typing import List
1211

1312
import torch
@@ -22,7 +21,7 @@
2221

2322
class AnnotateDecomposedMatmulPass(ExportPass):
2423
"""
25-
torch.matmul can be decomposed in many ways, for instance:
24+
torch.matmul and its equivalent operator @ can be decomposed in many ways, for instance:
2625
dq -> matmul -> q can become
2726
dq -> repeat -> view -> bmm -> view -> dq which makes quantization folding
2827
difficult. This helper function finds all matmul partitions and annotates its
@@ -50,6 +49,7 @@ def call(self, graph_module: GraphModule) -> PassResult:
5049
graph_module.graph,
5150
[
5251
torch.matmul,
52+
operator.matmul,
5353
],
5454
None,
5555
)

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
ConvertSqueezesToViewPass,
2525
ConvertToClampPass,
2626
DecomposeBatchNormPass,
27+
DecomposeCosineSimilarityPass,
2728
DecomposeDivPass,
2829
DecomposeGeluPass,
2930
DecomposeLayerNormPass,
@@ -205,6 +206,7 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
205206
self.add_pass(DecomposeVarPass())
206207
self.add_pass(DecomposeMeanDimPass())
207208
self.add_pass(DecomposeNotEqualPass())
209+
self.add_pass(DecomposeCosineSimilarityPass())
208210
self.add_pass(DecomposeDivPass())
209211
self.add_pass(DecomposeLeakyReLUPass())
210212
self.add_pass(DecomposeSqrtPass())
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Copyright 2025 Arm Limited and/or its affiliates.
2+
#
3+
# This source code is licensed under the BSD-style license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
import torch
7+
from executorch.exir.pass_base import ExportPass
8+
9+
torch_cosine_similarity = (torch.ops.aten.cosine_similarity.default,)
10+
11+
12+
class DecomposeCosineSimilarityPass(ExportPass):
    """
    Decomposition of aten.cosine_similarity:

      dot    = sum(mul(x1, x2), dims, keepdim=False)
      norm   = pow( sum(mul(x, x), dims, keepdim=False), 0.5 )
      eps    = full_like( norm1, eps_scalar )
      n1c    = max(norm1, eps)
      n2c    = max(norm2, eps)
      denom  = mul(n1c, n2c)
      out    = div(dot, denom)
    """

    def _l2_norm(self, x, dims, meta):
        """Emit pow(sum(mul(x, x), dims, keepdim=False), 0.5) for `x`."""
        squared = super().call_operator(torch.ops.aten.mul.Tensor, (x, x), {}, meta)
        summed = super().call_operator(
            torch.ops.aten.sum.dim_IntList, (squared, dims, False), {}, meta
        )
        return super().call_operator(
            torch.ops.aten.pow.Tensor_Scalar, (summed, 0.5), {}, meta
        )

    def call_operator(self, op, args, kwargs, meta):
        """
        Replace a cosine_similarity call with the decomposition described in the
        class docstring; every other operator is forwarded unchanged.

        Args:
            op: The operator being visited.
            args: Positional arguments of the call: (x1, x2[, dim[, eps]]).
            kwargs: Keyword arguments of the call; may carry "dim" and/or "eps".
            meta: Node metadata, forwarded to every emitted operator.

        Returns:
            The result of the final div operator, or the forwarded call when
            `op` is not a cosine_similarity overload.
        """
        if op not in torch_cosine_similarity:
            return super().call_operator(op, args, kwargs, meta)

        x1, x2 = args[0], args[1]
        # dim and eps are ordinary (not keyword-only) parameters of
        # cosine_similarity, so they may arrive positionally. Prefer the
        # positional slot and fall back to kwargs, then to the documented
        # defaults (dim=1, eps=1e-8).
        dim = args[2] if len(args) > 2 else kwargs.get("dim", 1)
        eps = args[3] if len(args) > 3 else kwargs.get("eps", 1e-8)
        dims = [dim] if isinstance(dim, int) else list(dim)

        # 1) dot
        prod = super().call_operator(torch.ops.aten.mul.Tensor, (x1, x2), {}, meta)
        dot = super().call_operator(
            torch.ops.aten.sum.dim_IntList, (prod, dims, False), {}, meta
        )

        # 2) norms = pow(sum(x*x), 0.5), emitted for x1 first and then x2
        norm1 = self._l2_norm(x1, dims, meta)
        norm2 = self._l2_norm(x2, dims, meta)

        # 3) eps scalar - we need to broadcast it ourselves, as TOSA does not do
        #    this for scalars
        eps_t = super().call_operator(
            torch.ops.aten.full_like.default, (norm1, eps), {}, meta
        )

        # 4) clamp to avoid zero division
        n1c = super().call_operator(
            torch.ops.aten.maximum.default, (norm1, eps_t), {}, meta
        )
        n2c = super().call_operator(
            torch.ops.aten.maximum.default, (norm2, eps_t), {}, meta
        )

        # 5) denom and divide
        denom = super().call_operator(torch.ops.aten.mul.Tensor, (n1c, n2c), {}, meta)
        out = super().call_operator(torch.ops.aten.div.Tensor, (dot, denom), {}, meta)

        return out

backends/arm/operator_support/pool_2d_support.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,11 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
5454
kernel = cast(tuple[int, int], node.args[1])
5555
stride = cast(tuple[int, int], node.args[2])
5656
if len(node.args) > 3:
57+
padding = cast(tuple[int, int], node.args[3])
5758
# Padding case
58-
if not all(1 <= k <= 8 for k in kernel):
59+
if not all(1 <= k <= 8 for k in kernel) and not all(
60+
v == 0 for v in padding
61+
):
5962
self.reporter.report_reject(
6063
node, f"Avgpool2d with padding needs kernel dims < 8, got {kernel}"
6164
)

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ def _is_matmul_node_supported(
335335
graph_module.graph,
336336
[
337337
torch.matmul,
338+
operator.matmul,
338339
],
339340
None,
340341
)
@@ -385,7 +386,7 @@ def is_node_supported(
385386
):
386387
source_fn_stack: tuple[typing.Any] = node.meta.get("source_fn_stack", [])
387388
if len(source_fn_stack) > 0:
388-
if source_fn_stack[-1][1] in (torch.matmul,):
389+
if source_fn_stack[-1][1] in (torch.matmul, operator.matmul):
389390
return self._is_matmul_node_supported(submodules, node)
390391

391392
elif node.target in (exir_ops.edge.aten.max_pool2d_with_indices.default,):

backends/arm/operators/TARGETS

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,19 @@ python_library(
1010
],
1111
)
1212

13+
python_library(
14+
name = "operator_validation_utils",
15+
srcs = ["operator_validation_utils.py"],
16+
)
17+
1318
python_library(
1419
name = "ops",
1520
srcs = glob(["op_*.py", "ops_*.py"]),
1621
deps = [
1722
"fbsource//third-party/tosa_tools/v0.80/serialization_lib/python/tosa:tosa",
1823
"fbsource//third-party/tosa_tools/v1.00/serialization_lib/python/tosa:tosa",
1924
":node_visitor",
25+
":operator_validation_utils",
2026
"//executorch/backends/arm:tosa_mapping",
2127
"//executorch/backends/arm:tosa_quant_utils",
2228
"//executorch/backends/arm:tosa_utils",

backends/arm/operators/op_max_pool2d.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,24 @@
2323
from executorch.backends.arm.tosa_specification import TosaSpecification
2424

2525

26+
# As with Conv2d, the TOSA spec requires `(input + 2 * pad - kernel_size) / stride`
# to divide exactly. PyTorch has no such requirement, so the padding is adjusted
# here when needed.
def adjust_pad_if_needed(
    input_size: int, kernel_size: int, stride: int, pad: int
) -> int:
    """Return `pad`, reduced by the remainder of the divisibility expression.

    Args:
        input_size: Size of the input along the pooled dimension.
        kernel_size: Size of the pooling kernel along that dimension.
        stride: Stride along that dimension.
        pad: Padding to adjust.

    Returns:
        `pad` unchanged when it is zero or when
        `(input_size + 2 * pad - kernel_size)` is already a multiple of
        `stride`; otherwise `pad` minus the remainder.
    """
    # A zero pad is never adjusted.
    if pad != 0:
        leftover = (input_size + 2 * pad - kernel_size) % stride
        # A zero remainder means no adjustment, so the subtraction is a no-op.
        return pad - leftover
    return pad
42+
43+
2644
@register_node_visitor
2745
class MaxPool2dVisitor_0_80(NodeVisitor):
2846
target = "aten.max_pool2d.default"
@@ -61,6 +79,20 @@ def define_node(
6179
except IndexError:
6280
pad_size_list = [0, 0, 0, 0]
6381

82+
# Adjust the padding as necessary
83+
pad_size_list[1] = adjust_pad_if_needed(
84+
input_tensor.shape[2],
85+
kernel_size[0],
86+
stride[0],
87+
pad_size_list[1],
88+
)
89+
pad_size_list[3] = adjust_pad_if_needed(
90+
input_tensor.shape[3],
91+
kernel_size[1],
92+
stride[1],
93+
pad_size_list[3],
94+
)
95+
6496
accumulator_type = output.dtype
6597

6698
# Initialize zero point to zero.
@@ -131,6 +163,20 @@ def define_node(
131163
except IndexError:
132164
pad_size_list = [0, 0, 0, 0]
133165

166+
# Adjust the padding as necessary
167+
pad_size_list[1] = adjust_pad_if_needed(
168+
input_tensor.shape[2],
169+
kernel_size[0],
170+
stride[0],
171+
pad_size_list[1],
172+
)
173+
pad_size_list[3] = adjust_pad_if_needed(
174+
input_tensor.shape[3],
175+
kernel_size[1],
176+
stride[1],
177+
pad_size_list[3],
178+
)
179+
134180
attr = ts.TosaSerializerAttribute()
135181
attr.MaxPool2dAttribute(
136182
kernel=kernel_size, stride=stride, pad=pad_size_list, nan_mode=1

0 commit comments

Comments
 (0)