
Commit 9666ee8

Fix executorch kv cache incompatibility with to_executorch lowering (#7279)
1 parent: 0c4053e

File tree: 8 files changed, +365 additions, −31 deletions


examples/models/llama/export_llama_lib.py

Lines changed: 8 additions & 2 deletions
@@ -23,6 +23,7 @@
 import torch
 
 from executorch.devtools.etrecord import generate_etrecord
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 
 from executorch.extension.llm.export.builder import DType, LLMEdgeManager
 
@@ -775,6 +776,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")
 
+    additional_passes = []
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -789,7 +793,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
            canonicalize_program(builder.edge_manager.exported_program())
 
-        builder = builder.to_executorch()
+        builder = builder.to_executorch(
+            passes=additional_passes,
+        )
 
         # Generate ETRecord
         if edge_manager_copy:
@@ -807,7 +813,7 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
            canonicalize_program(builder.edge_manager.exported_program())
 
-        builder = builder.to_executorch()
+        builder = builder.to_executorch(passes=additional_passes)
 
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")

examples/models/llama3_2_vision/runner/native.py

Lines changed: 15 additions & 3 deletions
@@ -18,13 +18,15 @@
     TorchTuneLlamaRunner,
 )
 
-from executorch.extension.pybindings.portable_lib import _load_for_executorch
+from executorch.extension.pybindings.portable_lib import (
+    _load_for_executorch_from_buffer,
+)
 
 # Load custom ops and quantized ops.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
 
 # Note: import this after portable_lib
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
+from executorch.extension.llm.custom_ops import custom_ops  # noqa # usort: skip
 from executorch.kernels import quantized  # noqa
 
 
@@ -43,7 +45,17 @@ def __init__(self, args):
             use_kv_cache=args.kv_cache,
             vocab_size=params["vocab_size"],
         )
-        self.model = _load_for_executorch(args.pte)
+        # Save the loaded model bytes to prevent data from going out of
+        # scope after the `with` and getting cleaned up by Python's
+        # garbage collector.
+        self.model_bytes = None
+        with open(args.pte, "rb") as f:
+            self.model_bytes = f.read()
+        # Need to use _load_for_executorch_from_buffer instead of
+        # _load_for_executorch because the latter uses MmapDataLoader,
+        # which doesn't have load_into() implemented, which is needed
+        # for loading initialized mutable buffers.
+        self.model = _load_for_executorch_from_buffer(self.model_bytes)
         self.use_kv_cache = args.kv_cache
 
     def forward(

exir/emit/_emitter.py

Lines changed: 7 additions & 2 deletions
@@ -1575,7 +1575,8 @@ def _find_fqn_for_placeholder(
        warnings.warn(
            "Mutation on a buffer in the model is detected. ExecuTorch assumes "
            "buffers that are mutated in the graph have a meaningless initial state, "
-           "only the shape and dtype will be serialized.",
+           "only the shape and dtype will be serialized, unless a pass which sets "
+           'meta["et_init_buffer"] to True such as InitializedMutableBufferPass is run.',
            UserWarning,
            stacklevel=1,
        )
@@ -1602,6 +1603,7 @@ def placeholder(
        """
        spec = self.node.meta["spec"]
        constant_tag = self.node.meta.get("constant_tag", None)
+       initialize_buffer = self.node.meta.get("et_init_buffer", None)
        is_user_input = True
 
        if isinstance(target, str) and isinstance(spec, TensorSpec):
@@ -1655,7 +1657,10 @@ def placeholder(
                spec.storage = real_tensor.untyped_storage()
 
            # User inputs and mutable buffers are not constants, other buffers or parameters are.
-           spec.const = not (is_user_input or is_mutable_buffer)
+           if initialize_buffer and is_mutable_buffer:
+               spec.const = True
+           else:
+               spec.const = not (is_user_input or is_mutable_buffer)
 
        evalue = (
            self._tensor_spec_to_evalue(spec, constant_tag)

exir/emit/test/test_emit.py

Lines changed: 53 additions & 0 deletions
@@ -9,6 +9,7 @@
 import typing
 import unittest
 from contextlib import contextmanager
+from copy import deepcopy
 from typing import List, Optional, Tuple
 
 import executorch.exir as exir
@@ -31,6 +32,7 @@
 from executorch.exir.error import InternalError
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.constant_prop_pass import constant_prop_pass
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.exir.print_program import pretty_print, print_program  # noqa
 from executorch.exir.schema import (
@@ -56,6 +58,7 @@
 from executorch.extension.pybindings.portable_lib import (
     _load_for_executorch_from_buffer,
 )
+from executorch.runtime import Runtime
 
 from functorch.experimental import control_flow
 from torch import nn
@@ -243,6 +246,56 @@ def forward(self, x):
        )
        self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)
 
+    def test_initialized_mutable_buffer(self):
+        """Test that mutable buffers can hold meaningful initialized state."""
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Mutable buffer with non-empty initial state.
+                self.register_buffer("cache_pos", torch.arange(0, 10))
+
+            def forward(self, x):
+                self.cache_pos.add_(1)
+                return self.cache_pos
+
+        m = TestModule()
+        example_inputs = (torch.ones(10),)
+        ep = torch.export.export(m, example_inputs)
+        edge = to_edge(
+            ep,
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False,
+            ),
+        )
+
+        # Save a copy of the edge program since to_executorch is
+        # stateful to some degree.
+        edge_copy = deepcopy(edge)
+        et_config = ExecutorchBackendConfig(
+            passes=[InitializedMutableBufferPass(["cache_pos"])],
+        )
+        et_program_init_pass = edge.to_executorch(config=et_config)
+        et_program_regular = edge_copy.to_executorch()
+
+        runtime = Runtime.get()
+        program_init_pass = runtime.load_program(et_program_init_pass.buffer)
+        method_init_pass = program_init_pass.load_method("forward")
+
+        program_regular = runtime.load_program(et_program_regular.buffer)
+        method_regular = program_regular.load_method("forward")
+
+        # Test that the mutable buffer is initialized.
+        torch.allclose(
+            method_init_pass.execute((example_inputs))[0], torch.arange(1, 11)
+        )
+        # Test that the mutable buffer is uninitialized and starts with default zeros,
+        # we test equality with torch.ones because of the mutation += 1 in the model forward.
+        torch.allclose(
+            method_regular.execute((example_inputs))[0],
+            torch.ones(10, dtype=torch.int64),
+        )
+
     def test_int_list_input(self):
         class M(torch.nn.Module):
             def forward(self, x, y, z):

exir/passes/init_mutable_pass.py

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import List
+
+from executorch.exir.pass_base import ExportPass
+
+
+class InitializedMutableBufferPass(ExportPass):
+    """
+    If a buffer has a name that within a specified list, set meta["et_init_buffer"]
+    to True, which provides the mutable buffer with an initialized state.
+
+    As an example, a module with `self.register_buffer("cache_pos", torch.arange(10))`
+    when patterns = ["cache_pos"] would have its initial state set instead of being
+    left uninitialized by default.
+    """
+
+    def __init__(self, patterns: List[str]) -> None:
+        super().__init__()
+        self.patterns = patterns
+
+    def placeholder(self, name: str, arg, meta):
+        for pattern in self.patterns:
+            if pattern in name:
+                meta["et_init_buffer"] = True
+
+        return super().placeholder(name, arg, meta)
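
For orientation, here is a minimal usage sketch of the new pass, closely mirroring the test added in exir/emit/test/test_emit.py above. The `Counter` module and its `cache_pos` buffer are illustrative stand-ins, not part of this commit.

```python
import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass


class Counter(torch.nn.Module):  # illustrative module, not from this commit
    def __init__(self):
        super().__init__()
        # Mutable buffer whose initial contents should survive serialization.
        self.register_buffer("cache_pos", torch.arange(0, 10))

    def forward(self, x):
        self.cache_pos.add_(1)
        return self.cache_pos


ep = torch.export.export(Counter(), (torch.ones(10),))
edge = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))

# Buffers whose names match a pattern ("cache_pos") get meta["et_init_buffer"] = True,
# so their initial values are serialized instead of only shape and dtype.
et_program = edge.to_executorch(
    config=ExecutorchBackendConfig(
        passes=[InitializedMutableBufferPass(["cache_pos"])],
    )
)
```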

extension/llm/export/builder.py

Lines changed: 15 additions & 8 deletions
@@ -27,6 +27,7 @@
 from executorch.exir.backend.utils import format_delegated_graph
 from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
 
+from executorch.exir.pass_manager import PassType
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
@@ -415,21 +416,27 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":
 
         return self
 
-    def to_executorch(self) -> "LLMEdgeManager":
+    def to_executorch(
+        self, passes: Optional[List[PassType]] = None
+    ) -> "LLMEdgeManager":
         """
         Lower the model to executorch and get an ExecutorchProgram.
         """
         assert self.edge_manager, "Need to run export_to_edge() first"
+        to_executorch_passes = [
+            # If there are Linear operations left in the graph, let's execute
+            # them with the optimized op_linear rather than materializing a
+            # transpose followed by a regular op_mm.
+            ConvertToLinearPass(),
+            QuantFusionPass(),
+        ]
+        if passes:
+            to_executorch_passes.extend(passes)
+
         self.export_program = self.edge_manager.to_executorch(
             ExecutorchBackendConfig(
                 extract_delegate_segments=True,
-                passes=[
-                    # If there are Linear operations left in the graph, let's execute
-                    # them with the optimized op_linear rather than materializing a
-                    # transpose followed by a regular op_mm.
-                    ConvertToLinearPass(),
-                    QuantFusionPass(),
-                ],
+                passes=to_executorch_passes,
                 memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
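
As a usage sketch of the new `passes` argument (mirroring the export_llama_lib.py change above), extra passes can now be threaded into the final lowering. Here `manager` is assumed to be an LLMEdgeManager that has already gone through export_to_edge() and, optionally, to_backend().

```python
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

# Assumption: `manager` is an LLMEdgeManager produced by the usual export flow
# (export() -> export_to_edge() -> to_backend()), as in export_llama_lib.py.
additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
manager = manager.to_executorch(passes=additional_passes)
```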

extension/llm/modules/test/test_attention.py

Lines changed: 18 additions & 16 deletions
@@ -11,6 +11,8 @@
 import torch
 from executorch.exir import EdgeCompileConfig, to_edge
 
+from executorch.exir.capture._config import ExecutorchBackendConfig
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 from executorch.extension.llm.modules.attention import (
     MultiHeadAttention as ETMultiHeadAttention,
 )
@@ -114,7 +116,7 @@ def test_attention_eager(self):
        et_res = self.et_mha(self.x, self.x)  # Self attention.
        tt_res = self.tt_mha(self.x, self.x)  # Self attention.
 
-       self.assertTrue(torch.allclose(et_res, tt_res))
+       assert_close(et_res, tt_res)
        self.et_mha.reset_cache()
        self.tt_mha.reset_cache()
 
@@ -125,7 +127,7 @@ def test_attention_eager(self):
            self.x, self.x, input_pos=self.input_pos
        )  # Self attention with input pos.
 
-       self.assertTrue(torch.allclose(et_res, tt_res))
+       assert_close(et_res, tt_res)
 
        # test kv cache read. Input pos can be [10, 11, ..., 19]
        next_input_pos = torch.arange(10, 20).unsqueeze(0)
@@ -187,9 +189,8 @@ def test_attention_aoti(self):
 
    def test_attention_executorch(self):
        # Self attention.
-       # TODO: Fix kv cache
-       # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-       # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+       self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+       self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
 
        with torch.no_grad():
            et_mha_ep = torch.export.export(
@@ -202,9 +203,15 @@ def test_attention_executorch(self):
        et_program = to_edge(
            et_mha_ep,
            compile_config=EdgeCompileConfig(
-               _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+               _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+               _check_ir_validity=False,
            ),
-       ).to_executorch()
+       ).to_executorch(
+           config=ExecutorchBackendConfig(
+               passes=[InitializedMutableBufferPass(["cache_pos"])],
+           )
+       )
+
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
        method = program.load_method("forward")
@@ -219,28 +226,23 @@ def test_attention_torch_cond_eager(self):
        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
 
-       # mask
        mask = self.causal_mask[self.input_pos, :]
-       # First run
+       # First run.
        et_res = self.et_mha(
            self.x, self.x, mask=mask, input_pos=self.input_pos
        )  # Self attention with input pos.
        tt_res = self.tt_mha(
            self.x, self.x, mask=mask, input_pos=self.input_pos
        )  # Self attention with input pos.
 
-       self.assertTrue(torch.allclose(et_res, tt_res))
+       assert_close(et_res, tt_res)
 
        # Second run test kv cache read. Input pos is [10, 11, ..., 19]
        next_input_pos = torch.arange(10, 20).unsqueeze(0)
 
        empty_y = torch.full_like(self.x, torch.nan)
        mask = self.causal_mask[next_input_pos, :]
-       et_res = self.et_mha(
-           self.x, empty_y, mask=mask, input_pos=next_input_pos
-       )  # Self attention with input pos.
-       tt_res = self.tt_mha(
-           self.x, None, mask=mask, input_pos=next_input_pos
-       )  # Self attention with input pos.
+       et_res = self.et_mha(self.x, empty_y, mask=mask, input_pos=next_input_pos)
+       tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)
 
        assert_close(et_res, tt_res)
