Commit 3375f85

committed
Update on "[Executorch][llm] Enable leveraging ring kv cache via module swap"

This allows us to make some of the attention modules use a sliding-window KV cache, which will help enable models like Gemma3.

Differential Revision: [D73891426](https://our.internmc.facebook.com/intern/diff/D73891426/)

[ghstack-poisoned]
2 parents e47dfd9 + eb677e5 commit 3375f85
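
As background for the headline change, a minimal sketch of the module-swap idea follows: walk the model and replace each attention layer's plain KV cache with a ring (sliding-window) cache. Everything here (RingKVCache, the KVCache marker class, the buffer layout) is an illustrative assumption, not the actual ExecuTorch API.

# Illustrative sketch only: RingKVCache, the KVCache marker class, and the
# (1, n_heads, length, head_dim) buffer layout are assumptions for this example.
import torch
import torch.nn as nn


class RingKVCache(nn.Module):
    """Fixed-size KV cache whose writes wrap around a sliding window."""

    def __init__(self, window_size: int, n_heads: int, head_dim: int):
        super().__init__()
        self.window_size = window_size
        self.register_buffer("k_cache", torch.zeros(1, n_heads, window_size, head_dim))
        self.register_buffer("v_cache", torch.zeros(1, n_heads, window_size, head_dim))

    def update(self, pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        # pos: 1-D tensor of token positions; writes wrap modulo the window,
        # so memory stays bounded regardless of sequence length.
        idx = pos % self.window_size
        self.k_cache[:, :, idx] = k
        self.v_cache[:, :, idx] = v
        return self.k_cache, self.v_cache


def swap_kv_caches(model: nn.Module, window_size: int) -> nn.Module:
    # Recursively replace each plain KV-cache submodule with a ring cache.
    for name, child in model.named_children():
        if type(child).__name__ == "KVCache":  # assumed marker class name
            n_heads, head_dim = child.k_cache.shape[1], child.k_cache.shape[3]
            setattr(model, name, RingKVCache(window_size, n_heads, head_dim))
        else:
            swap_kv_caches(child, window_size)
    return model

The appeal of swapping at the module level is that the rest of the model definition stays untouched; only the cache submodules change.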

File tree: 175 files changed, +2347 −760 lines

.github/workflows/_link_check.yml

Lines changed: 30 additions & 14 deletions

@@ -7,35 +7,51 @@ on:
 
 jobs:
   lint-urls:
+    if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-url-lint') }}
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
       runner: linux.2xlarge
       docker-image: executorch-ubuntu-22.04-linter
-      submodules: 'none'
+      submodules: false
       fetch-depth: 0
       ref: ${{ inputs.ref }}
-      timeout: 90
+      timeout: 120
       script: |
         ./scripts/lint_urls.sh $(
-          [ "${{ github.event_name }}" = "pull_request" ] \
-            && git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
-            || [ "${{ github.event_name }}" = "push" ] \
-            && git diff --name-only ${{ github.event.before }} ${{ github.sha }}
-        )
+          { [ "${{ github.event_name }}" = "pull_request" ] \
+            && git diff --name-only "${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }}"; } \
+          || \
+          { [ "${{ github.event_name }}" = "push" ] \
+            && git diff --name-only "${{ github.event.before }}...${{ github.sha }}"; }
+        ) || {
+          echo
+          echo "URL lint failed."
+          echo "If this is a transient outage, you can bypass it by adding the \`skip-url-lint\` label to your PR."
+          echo "Or add \`@lint-ignore\` somewhere on the same line as the URL you want to skip checking."
+          exit 1
+        }
 
   lint-xrefs:
+    if: ${{ github.event_name != 'pull_request' || !contains(github.event.pull_request.labels.*.name, 'skip-xref-lint') }}
     uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
     with:
       runner: linux.2xlarge
       docker-image: executorch-ubuntu-22.04-linter
-      submodules: 'none'
+      submodules: false
       fetch-depth: 0
       ref: ${{ inputs.ref }}
-      timeout: 90
+      timeout: 60
       script: |
         ./scripts/lint_xrefs.sh $(
-          [ "${{ github.event_name }}" = "pull_request" ] \
-            && git diff --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
-            || [ "${{ github.event_name }}" = "push" ] \
-            && git diff --name-only ${{ github.event.before }} ${{ github.sha }}
-        )
+          { [ "${{ github.event_name }}" = "pull_request" ] \
+            && git diff --name-only "${{ github.event.pull_request.base.sha }}...${{ github.event.pull_request.head.sha }}"; } \
+          || \
+          { [ "${{ github.event_name }}" = "push" ] \
+            && git diff --name-only "${{ github.event.before }}...${{ github.sha }}"; }
+        ) || {
+          echo
+          echo "Xref lint failed."
+          echo "If this is a transient outage, you can bypass it by adding the \`skip-xref-lint\` label to your PR."
+          echo "Or add \`@lint-ignore\` somewhere on the same line as the reference you want to skip checking."
+          exit 1
+        }

.github/workflows/build-presets.yml

Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+name: Build Presets
+
+on:
+  pull_request:
+  push:
+    branches:
+      - main
+      - release/*
+  workflow_dispatch:
+
+concurrency:
+  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
+  cancel-in-progress: true

CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -608,7 +608,7 @@ endif()
 # any backends.
 #
 add_library(executorch ${_executorch__srcs})
-target_link_libraries(executorch PUBLIC executorch_core)
+target_link_libraries(executorch PRIVATE executorch_core)
 target_include_directories(executorch PUBLIC ${_common_include_directories})
 target_compile_definitions(executorch PUBLIC C10_USING_CUSTOM_GENERATED_MACROS)
 target_compile_options(executorch PUBLIC ${_common_compile_options})

backends/arm/_passes/TARGETS

Lines changed: 1 addition & 0 deletions

@@ -11,5 +11,6 @@ python_library(
         "//executorch/backends/xnnpack/_passes:xnnpack_passes",
         "//executorch/exir:lib",
         "//executorch/backends/transforms:utils",
+        "//executorch/backends/transforms:decompose_sdpa",
     ],
 )

backends/arm/_passes/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -57,4 +57,5 @@
 from .size_adjust_conv2d_pass import SizeAdjustConv2DPass  # noqa
 from .unsqueeze_before_repeat_pass import UnsqueezeBeforeRepeatPass  # noqa
 from .unsqueeze_scalar_placeholders_pass import UnsqueezeScalarPlaceholdersPass  # noqa
+from .replace_inf_values_pass import ReplaceInfValues  # noqa  # usort: skip
 from .arm_pass_manager import ArmPassManager  # noqa  # usort: skip

backends/arm/_passes/annotate_decomposed_matmul.py

Lines changed: 9 additions & 5 deletions

@@ -70,17 +70,14 @@ def call(self, graph_module: GraphModule) -> PassResult:
             if quantized_input:
                 matmul_args = matmul_node.all_input_nodes
                 for node in matmul_args:
+                    # Find the dq-node connected to this mm/bmm arg
                     input_node = self._match_partition_to_node(
                         node, partition.input_nodes
                     )
-
-                    # Remove partition input dq-node
-                    input_node.replace_all_uses_with(input_node.all_input_nodes[0])
-                    graph_module.graph.erase_node(input_node)
                     input_node_qargs = QuantArgs.from_operator(
                         input_node.target, input_node.args
                     )
-
+                    # Insert new dq-node just before the mm/bmm with input_node's qparams
                     with graph_module.graph.inserting_before(matmul_node):
                         # Create new dq-node before matmul
                         dq_node = create_node(
@@ -90,6 +87,13 @@ def call(self, graph_module: GraphModule) -> PassResult:
                         dq_node.args = (node, *input_node_qargs)
                         matmul_node.replace_input_with(node, dq_node)
 
+                for partition_input in partition.input_nodes:
+                    # Remove partition input dq-node
+                    partition_input.replace_all_uses_with(
+                        partition_input.all_input_nodes[0]
+                    )
+                    graph_module.graph.erase_node(partition_input)
+
             partition_output = list(partition.output_nodes[0].users)[0]
             quantized_output = partition_output.target == q_op
             if quantized_output:

backends/arm/_passes/arm_pass_manager.py

Lines changed: 2 additions & 0 deletions

@@ -49,6 +49,7 @@
     MatchWhereSelfDtypePass,
     QuantizeOperatorArguments,
     RemoveClonePass,
+    ReplaceInfValues,
     ReplaceScalarWithTensorArgPassTOSABI,
     ReplaceScalarWithTensorArgPassTOSAMI,
     RetraceFoldedDtypesPass,
@@ -216,4 +217,5 @@ def transform_for_annotation_pipeline(self, graph_module: GraphModule):
         self.add_pass(DecomposeSoftmaxPass())
 
         self.add_pass(ConvertMinMaxPass())
+        self.add_pass(ReplaceInfValues())
         return self._transform(graph_module)

backends/arm/_passes/convert_split_to_slice.py

Lines changed: 6 additions & 5 deletions

@@ -1,14 +1,15 @@
-# Copyright 2024 Arm Limited and/or its affiliates.
-# All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
 # pyre-unsafe
 
 import torch.fx
-from executorch.backends.arm._passes.arm_pass_utils import create_node
-from executorch.backends.arm.tosa_mapping import extract_tensor_meta
+from executorch.backends.arm._passes.arm_pass_utils import (
+    create_node,
+    get_first_fake_tensor,
+)
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 
@@ -34,7 +35,7 @@ def call(self, graph_module: torch.fx.GraphModule):
             split_node = node
             input_node = split_node.all_input_nodes[0]
             output_nodes = split_node.users.copy()
-            _, shape, _ = extract_tensor_meta(input_node.meta)
+            shape = get_first_fake_tensor(input_node).shape
             rank = len(shape)
             split_lengths = split_node.args[1]
             dim = split_node.args[2] if len(split_node.args) > 2 else 0

backends/arm/_passes/replace_inf_values_pass.py

Lines changed: 45 additions & 0 deletions

@@ -0,0 +1,45 @@
+# Copyright (c) Qualcomm Innovation Center, Inc.
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# This pass is based on backends/qualcomm/_passes/replace_inf_values.py
+# with some modifications to the replaced inf values.
+
+import torch
+from executorch.exir.pass_base import ExportPass, PassResult
+
+
+class ReplaceInfValues(ExportPass):
+    """
+    Due to limitations in the Quantizer, we need to change inf/-inf to more quantizable values.
+    """
+
+    def __init__(self):
+        super(ReplaceInfValues, self).__init__()
+
+    def call(self, graph_module: torch.fx.GraphModule):
+        modified = False
+        for buf_name, tensor in graph_module.named_buffers():
+            if tensor.is_floating_point():
+                modified = True
+                # 255 here is mainly for attention_mask in Llama for a reasonable quant scale
+                tensor[tensor == float("inf")] = 255
+                tensor[tensor == float("-inf")] = -255
+                setattr(graph_module, buf_name, tensor)
+
+        for node in graph_module.graph.nodes:
+            arg_list = list(node.args)
+            for index, arg in enumerate(arg_list):
+                if arg == float("-inf"):
+                    modified = True
+                    arg_list[index] = -255
+                elif arg == float("inf"):
+                    modified = True
+                    arg_list[index] = +255
+            node.args = tuple(arg_list)
+
+        if modified:
+            graph_module.recompile()
+        return PassResult(graph_module, modified)
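
As a quick usage sketch, assuming an executorch install: the pass can be exercised on a traced GraphModule that carries a -inf literal. The import path follows the __init__.py change above; the toy MaskedScores module is illustrative, not from the commit.

# Sketch only: exercises ReplaceInfValues via the import added to
# backends/arm/_passes/__init__.py above; MaskedScores is a made-up toy.
import torch
from executorch.backends.arm._passes import ReplaceInfValues


class MaskedScores(torch.nn.Module):
    def forward(self, scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # The -inf literal is captured as a constant arg on the masked_fill node.
        return scores.masked_fill(mask, float("-inf"))


gm = torch.fx.symbolic_trace(MaskedScores())
result = ReplaceInfValues().call(gm)
print(result.modified)           # True: the -inf arg was rewritten
print(result.graph_module.code)  # masked_fill now uses -255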

backends/arm/operator_support/slice_copy_support.py

Lines changed: 2 additions & 3 deletions

@@ -12,7 +12,6 @@
     SupportedTOSAOperatorCheck,
 )
 from executorch.backends.arm.tosa_specification import TosaSpecification
-from executorch.backends.arm.tosa_utils import getNodeArgs
 from executorch.exir.dialects._ops import ops as exir_ops
 
 logger = logging.getLogger(__name__)
@@ -33,8 +32,8 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) ->
         if tosa_spec not in self.tosa_specs:
             return False
 
-        inputs = getNodeArgs(node)
-        if len(inputs) == 5 and (step := inputs[4].number) != 1:
+        args = node.args
+        if len(args) == 5 and (step := args[4]) != 1:
             logging.warning(f"{node.target} with step size of {step} not supported.")
             return False
         return True

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 2 additions & 0 deletions

@@ -194,6 +194,7 @@ def is_node_supported(
             exir_ops.edge.aten.mul.Tensor,
             exir_ops.edge.aten.ne.Tensor,
             exir_ops.edge.aten.ne.Scalar,
+            exir_ops.edge.aten.neg.default,
             exir_ops.edge.aten.add.Scalar,
             exir_ops.edge.aten.sub.Scalar,
             exir_ops.edge.aten.mul.Scalar,
@@ -311,6 +312,7 @@ class CheckProperQuantization(OperatorSupportBase):
             exir_ops.edge.aten.max_pool2d_with_indices.default,
             exir_ops.edge.aten.mm.default,
             exir_ops.edge.aten.mul.Tensor,
+            exir_ops.edge.aten.neg.default,
             exir_ops.edge.aten.relu.default,
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.upsample_bilinear2d.vec,

backends/arm/operators/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -31,6 +31,7 @@
     op_maximum,
     op_minimum,
     op_mul,
+    op_neg,
    op_permute,
    op_pow,
    op_reciprocal,

backends/arm/operators/op_abs.py

Lines changed: 9 additions & 0 deletions

@@ -13,6 +13,9 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.arm.operators.operator_validation_utils import (
+    validate_num_inputs,
+)
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from torch.fx import Node
@@ -39,6 +42,7 @@ def define_node(
 
         import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 
+        validate_num_inputs(self.target, inputs, 1)
         # Specification (0.80) states that input and output types
         # should all be the same
         if not (inputs[0].dtype == output.dtype):
@@ -105,6 +109,7 @@ def define_node(
 
         import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 
+        validate_num_inputs(self.target, inputs, 1)
         # Specification (0.80) states that input and output types
         # should all be the same
         if not (inputs[0].dtype == output.dtype):
@@ -157,6 +162,8 @@ def define_node(
 
         import serializer.tosa_serializer as ts  # type: ignore
 
+        validate_num_inputs(self.target, inputs, 1)
+
         # Specification (1.0) states that input and output types
         # should all be the same
         if not (inputs[0].dtype == output.dtype):
@@ -224,6 +231,8 @@ def define_node(
 
         import serializer.tosa_serializer as ts  # type: ignore
 
+        validate_num_inputs(self.target, inputs, 1)
+
         # Specification (1.0) states that input and output types
         # should all be the same
         if not (inputs[0].dtype == output.dtype):

backends/arm/operators/op_add.py

Lines changed: 9 additions & 0 deletions

@@ -14,6 +14,9 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.arm.operators.operator_validation_utils import (
+    validate_num_inputs,
+)
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.backends.arm.tosa_specification import TosaSpecification
 from torch.fx import Node
@@ -40,6 +43,7 @@ def define_node(
 
         import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 
+        validate_num_inputs(self.target, inputs, 2)
         # Specification (0.80) states that input and output types
         # should all be the same
         if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -118,6 +122,7 @@ def define_node(
 
         import tosa_tools.v0_80.serializer.tosa_serializer as ts  # type: ignore
 
+        validate_num_inputs(self.target, inputs, 2)
         # Specification (0.80) states that input and output types
         # should all be the same
         if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -169,6 +174,8 @@ def define_node(
 
         import serializer.tosa_serializer as ts  # type: ignore
 
+        validate_num_inputs(self.target, inputs, 2)
+
         # Specification (1.0) states that input and output types
         # should all be the same
         if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:
@@ -237,6 +244,8 @@ def define_node(
 
         import serializer.tosa_serializer as ts  # type: ignore
 
+        validate_num_inputs(self.target, inputs, 2)
+
         # Specification (1.0) states that input and output types
         # should all be the same
         if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype:

backends/arm/operators/op_amax.py

Lines changed: 7 additions & 0 deletions

@@ -9,6 +9,9 @@
     NodeVisitor,
     register_node_visitor,
 )
+from executorch.backends.arm.operators.operator_validation_utils import (
+    validate_num_inputs,
+)
 from executorch.backends.arm.tosa_mapping import TosaArg
 from torch.fx import Node
 
@@ -31,6 +34,8 @@ def define_node(
     ) -> None:
         import tosa_tools.v0_80.serializer.tosa_serializer as ts
 
+        validate_num_inputs(self.target, inputs, 3)
+
         input = inputs[0]
         dim = inputs[1].number
 
@@ -71,6 +76,8 @@ def define_node(
     ) -> None:
         import serializer.tosa_serializer as ts
 
+        validate_num_inputs(self.target, inputs, 3)
+
         input = inputs[0]
         dim = inputs[1].number

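The validate_num_inputs helper that op_abs, op_add, and op_amax now call is not itself part of this diff. Judging purely from the call sites above, a plausible sketch follows; the signature is taken from the calls, while the raise-on-mismatch behavior and message are assumptions about the real helper in operator_validation_utils.

# Hypothetical reconstruction, inferred only from the call sites above;
# the real operator_validation_utils.validate_num_inputs may differ.
from typing import Any, Sequence


def validate_num_inputs(op_target: str, inputs: Sequence[Any], expected: int) -> None:
    # Fail fast with a clear error instead of an IndexError deeper in define_node.
    if len(inputs) != expected:
        raise ValueError(
            f"{op_target}: expected {expected} input(s), got {len(inputs)}"
        )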