
Fix executorch kv cache incompatibility with to_executorch lowering #7279

Merged: 19 commits, Jan 10, 2025
10 changes: 8 additions & 2 deletions examples/models/llama/export_llama_lib.py
@@ -23,6 +23,7 @@
 import torch

 from executorch.devtools.etrecord import generate_etrecord
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

 from executorch.extension.llm.export.builder import DType, LLMEdgeManager

@@ -775,6 +776,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")

+    additional_passes = []
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -789,7 +793,9 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())

-        builder = builder.to_executorch()
+        builder = builder.to_executorch(
+            passes=additional_passes,
+        )

         # Generate ETRecord
         if edge_manager_copy:
@@ -807,7 +813,7 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
             # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())

-        builder = builder.to_executorch()
+        builder = builder.to_executorch(passes=additional_passes)

     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")
18 changes: 15 additions & 3 deletions examples/models/llama3_2_vision/runner/native.py
@@ -18,13 +18,15 @@
     TorchTuneLlamaRunner,
 )

-from executorch.extension.pybindings.portable_lib import _load_for_executorch
+from executorch.extension.pybindings.portable_lib import (
+    _load_for_executorch_from_buffer,
+)

 # Load custom ops and quantized ops.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip

 # Note: import this after portable_lib.
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache  # noqa # usort: skip
+from executorch.extension.llm.custom_ops import custom_ops  # noqa # usort: skip
 from executorch.kernels import quantized  # noqa


@@ -43,7 +45,17 @@ def __init__(self, args):
             use_kv_cache=args.kv_cache,
             vocab_size=params["vocab_size"],
         )
-        self.model = _load_for_executorch(args.pte)
+        # Keep the loaded model bytes alive so they don't go out of scope
+        # after the `with` block and get cleaned up by Python's garbage
+        # collector.
+        self.model_bytes = None
+        with open(args.pte, "rb") as f:
+            self.model_bytes = f.read()
+        # Use _load_for_executorch_from_buffer instead of
+        # _load_for_executorch because the latter uses MmapDataLoader,
+        # which doesn't implement load_into(), needed for loading
+        # initialized mutable buffers.
+        self.model = _load_for_executorch_from_buffer(self.model_bytes)
         self.use_kv_cache = args.kv_cache

     def forward(
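For reference, a minimal standalone sketch of this loading pattern; the "model.pte" path is a placeholder, not part of the diff:

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)

# Read the serialized program fully into memory and keep a reference to
# the bytes for as long as the module is in use.
with open("model.pte", "rb") as f:  # placeholder path
    program_bytes = f.read()

# Buffer-based loading supports load_into(), which initialized mutable
# buffers require; the mmap-based file loader does not implement it.
module = _load_for_executorch_from_buffer(program_bytes)
# module.forward(inputs) can then be called as usual.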
9 changes: 7 additions & 2 deletions exir/emit/_emitter.py
@@ -1575,7 +1575,8 @@ def _find_fqn_for_placeholder(
         warnings.warn(
             "Mutation on a buffer in the model is detected. ExecuTorch assumes "
             "buffers that are mutated in the graph have a meaningless initial state, "
-            "only the shape and dtype will be serialized.",
+            "only the shape and dtype will be serialized, unless a pass which sets "
+            'meta["et_init_buffer"] to True, such as InitializedMutableBufferPass, is run.',
             UserWarning,
             stacklevel=1,
         )
@@ -1602,6 +1603,7 @@ def placeholder(
"""
spec = self.node.meta["spec"]
constant_tag = self.node.meta.get("constant_tag", None)
initialize_buffer = self.node.meta.get("et_init_buffer", None)
is_user_input = True

if isinstance(target, str) and isinstance(spec, TensorSpec):
@@ -1655,7 +1657,10 @@ def placeholder(
             spec.storage = real_tensor.untyped_storage()

         # User inputs and mutable buffers are not constants; other buffers and parameters are.
-        spec.const = not (is_user_input or is_mutable_buffer)
+        if initialize_buffer and is_mutable_buffer:
+            spec.const = True
+        else:
+            spec.const = not (is_user_input or is_mutable_buffer)

         evalue = (
             self._tensor_spec_to_evalue(spec, constant_tag)
53 changes: 53 additions & 0 deletions exir/emit/test/test_emit.py
@@ -9,6 +9,7 @@
 import typing
 import unittest
 from contextlib import contextmanager
+from copy import deepcopy
 from typing import List, Optional, Tuple

 import executorch.exir as exir
@@ -31,6 +32,7 @@
 from executorch.exir.error import InternalError
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.constant_prop_pass import constant_prop_pass
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
 from executorch.exir.print_program import pretty_print, print_program  # noqa
 from executorch.exir.schema import (
@@ -56,6 +58,7 @@
 from executorch.extension.pybindings.portable_lib import (
     _load_for_executorch_from_buffer,
 )
+from executorch.runtime import Runtime

 from functorch.experimental import control_flow
 from torch import nn
@@ -243,6 +246,56 @@ def forward(self, x):
         )
         self.assertIsInstance(program.execution_plan[0].values[outputs[6]].val, Null)

+    def test_initialized_mutable_buffer(self):
+        """Test that mutable buffers can hold meaningful initialized state."""
+
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                # Mutable buffer with non-empty initial state.
+                self.register_buffer("cache_pos", torch.arange(0, 10))
+
+            def forward(self, x):
+                self.cache_pos.add_(1)
+                return self.cache_pos
+
+        m = TestModule()
+        example_inputs = (torch.ones(10),)
+        ep = torch.export.export(m, example_inputs)
+        edge = to_edge(
+            ep,
+            compile_config=EdgeCompileConfig(
+                _check_ir_validity=False,
+            ),
+        )
+
+        # Save a copy of the edge program, since to_executorch is
+        # stateful to some degree.
+        edge_copy = deepcopy(edge)
+        et_config = ExecutorchBackendConfig(
+            passes=[InitializedMutableBufferPass(["cache_pos"])],
+        )
+        et_program_init_pass = edge.to_executorch(config=et_config)
+        et_program_regular = edge_copy.to_executorch()
+
+        runtime = Runtime.get()
+        program_init_pass = runtime.load_program(et_program_init_pass.buffer)
+        method_init_pass = program_init_pass.load_method("forward")
+
+        program_regular = runtime.load_program(et_program_regular.buffer)
+        method_regular = program_regular.load_method("forward")
+
+        # With the pass, the mutable buffer keeps its initialized state
+        # (arange(0, 10)), which forward() increments to arange(1, 11).
+        self.assertTrue(
+            torch.allclose(
+                method_init_pass.execute(example_inputs)[0], torch.arange(1, 11)
+            )
+        )
+        # Without the pass, the buffer starts uninitialized as default
+        # zeros; after the add_(1) in forward() the result is all ones.
+        self.assertTrue(
+            torch.allclose(
+                method_regular.execute(example_inputs)[0],
+                torch.ones(10, dtype=torch.int64),
+            )
+        )

Review thread on the torch.ones comparison:

@dbort (Contributor, Jan 9, 2025): Should this be zeros, based on the comment? If not, please update the comment to clarify why this is ones. And if it should be zeros, did this test fail?

Author: Oh, it's because in the forward of the model we do self.cache_pos += 1; I'll specify this.

     def test_int_list_input(self):
         class M(torch.nn.Module):
             def forward(self, x, y, z):
32 changes: 32 additions & 0 deletions exir/passes/init_mutable_pass.py
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+from typing import List
+
+from executorch.exir.pass_base import ExportPass
+
+
+class InitializedMutableBufferPass(ExportPass):
+    """
+    If a buffer's name contains one of the specified patterns, set
+    meta["et_init_buffer"] to True, which gives the mutable buffer an
+    initialized state. For example, a module with
+    `self.register_buffer("cache_pos", torch.arange(10))` exported with
+    patterns = ["cache_pos"] would have its initial state serialized,
+    instead of being left uninitialized by default.
+    """
+
+    def __init__(self, patterns: List[str]) -> None:
+        super().__init__()
+        self.patterns = patterns
+
+    def placeholder(self, name: str, arg, meta):
+        for pattern in self.patterns:
+            if pattern in name:
+                meta["et_init_buffer"] = True
+
+        return super().placeholder(name, arg, meta)
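In practice the pass is handed to to_executorch through ExecutorchBackendConfig. A minimal sketch, distilled from the test_emit.py test above; CacheModule is a hypothetical stand-in for any module with an initialized mutable buffer:

import torch

from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.capture._config import ExecutorchBackendConfig
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass


class CacheModule(torch.nn.Module):  # hypothetical, mirrors TestModule above
    def __init__(self):
        super().__init__()
        self.register_buffer("cache_pos", torch.arange(0, 10))

    def forward(self, x):
        self.cache_pos.add_(1)
        return self.cache_pos


ep = torch.export.export(CacheModule(), (torch.ones(10),))
edge = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))

# Serialize the initial contents of any buffer whose name contains
# "cache_pos", rather than only its shape and dtype.
et_program = edge.to_executorch(
    config=ExecutorchBackendConfig(
        passes=[InitializedMutableBufferPass(["cache_pos"])],
    )
)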
23 changes: 15 additions & 8 deletions extension/llm/export/builder.py
@@ -27,6 +27,7 @@
 from executorch.exir.backend.utils import format_delegated_graph
 from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig

+from executorch.exir.pass_manager import PassType
 from executorch.exir.passes import MemoryPlanningPass
 from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
@@ -415,21 +416,27 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManager":

         return self

-    def to_executorch(self) -> "LLMEdgeManager":
+    def to_executorch(
+        self, passes: Optional[List[PassType]] = None
+    ) -> "LLMEdgeManager":
         """
         Lower the model to executorch and get an ExecutorchProgram.
         """
         assert self.edge_manager, "Need to run export_to_edge() first"
+        to_executorch_passes = [
+            # If there are Linear operations left in the graph, let's execute
+            # them with the optimized op_linear rather than materializing a
+            # transpose followed by a regular op_mm.
+            ConvertToLinearPass(),
+            QuantFusionPass(),
+        ]
+        if passes:
+            to_executorch_passes.extend(passes)
+
         self.export_program = self.edge_manager.to_executorch(
             ExecutorchBackendConfig(
                 extract_delegate_segments=True,
-                passes=[
-                    # If there are Linear operations left in the graph, let's execute
-                    # them with the optimized op_linear rather than materializing a
-                    # transpose followed by a regular op_mm.
-                    ConvertToLinearPass(),
-                    QuantFusionPass(),
-                ],
+                passes=to_executorch_passes,
                 memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
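Downstream callers can now thread model-specific passes through the builder. A short sketch mirroring the export_llama_lib.py change above, assuming builder is an LLMEdgeManager already exported to edge and lowered:

from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass

# For torchtune-defined models, keep the kv-cache position buffer initialized.
additional_passes = [InitializedMutableBufferPass(["cache_pos"])]

# The extra passes run after ConvertToLinearPass and QuantFusionPass.
builder = builder.to_executorch(passes=additional_passes)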
34 changes: 18 additions & 16 deletions extension/llm/modules/test/test_attention.py
@@ -11,6 +11,8 @@
 import torch
 from executorch.exir import EdgeCompileConfig, to_edge

+from executorch.exir.capture._config import ExecutorchBackendConfig
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 from executorch.extension.llm.modules.attention import (
     MultiHeadAttention as ETMultiHeadAttention,
 )
@@ -114,7 +116,7 @@ def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)
         self.et_mha.reset_cache()
         self.tt_mha.reset_cache()
@@ -125,7 +127,7 @@ def test_attention_eager(self):
             self.x, self.x, input_pos=self.input_pos
         )  # Self attention with input pos.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)

         # Test kv cache read. Input pos can be [10, 11, ..., 19].
         next_input_pos = torch.arange(10, 20).unsqueeze(0)
@@ -187,9 +189,8 @@ def test_attention_aoti(self):

     def test_attention_executorch(self):
         # Self attention.
-        # TODO: Fix kv cache
-        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)

         with torch.no_grad():
             et_mha_ep = torch.export.export(
@@ -202,9 +203,15 @@ def test_attention_executorch(self):
         et_program = to_edge(
             et_mha_ep,
             compile_config=EdgeCompileConfig(
-                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+                _check_ir_validity=False,
             ),
-        ).to_executorch()
+        ).to_executorch(
+            config=ExecutorchBackendConfig(
+                passes=[InitializedMutableBufferPass(["cache_pos"])],
+            )
+        )

         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
         method = program.load_method("forward")
@@ -219,28 +226,23 @@ def test_attention_torch_cond_eager(self):
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

-        # mask
         mask = self.causal_mask[self.input_pos, :]
-        # First run
+        # First run.
         et_res = self.et_mha(
             self.x, self.x, mask=mask, input_pos=self.input_pos
         )  # Self attention with input pos.
         tt_res = self.tt_mha(
             self.x, self.x, mask=mask, input_pos=self.input_pos
         )  # Self attention with input pos.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)

         # Second run tests kv cache read. Input pos is [10, 11, ..., 19].
         next_input_pos = torch.arange(10, 20).unsqueeze(0)

         empty_y = torch.full_like(self.x, torch.nan)
         mask = self.causal_mask[next_input_pos, :]
-        et_res = self.et_mha(
-            self.x, empty_y, mask=mask, input_pos=next_input_pos
-        )  # Self attention with input pos.
-        tt_res = self.tt_mha(
-            self.x, None, mask=mask, input_pos=next_input_pos
-        )  # Self attention with input pos.
+        et_res = self.et_mha(self.x, empty_y, mask=mask, input_pos=next_input_pos)
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

         assert_close(et_res, tt_res)