Skip to content

[llama-mm] Add unit tests for exporting MultiHeadAttention with KVCache #6801

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions extension/llm/modules/kv_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,8 @@ def update(
f", but found new key tensors with batch size {k_val.shape[0]}!"
)

assert (
self.cache_pos[0] + seq_len
) <= self.max_seq_len, f"self.cache_pos[0]: {self.cache_pos[0]} + seq_len: {seq_len} > self.max_seq_len: {self.max_seq_len}"
assert (self.cache_pos[0] + seq_len) <= self.max_seq_len

k_out = self.k_cache
v_out = self.v_cache

Expand Down
67 changes: 48 additions & 19 deletions extension/llm/modules/test/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os
import tempfile
import unittest

import torch
Expand All @@ -13,6 +15,7 @@
MultiHeadAttention as ETMultiHeadAttention,
)
from executorch.runtime import Runtime
from torch._inductor.package import load_package, package_aoti
from torch.testing import assert_close
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
Expand Down Expand Up @@ -130,34 +133,62 @@ def test_attention_eager(self):

def test_attention_export(self):
# Self attention.
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.x),
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
)

# test with kv cache
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.x),
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
)
et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

assert_close(et_res, tt_res)

# TODO: KV cache.

def test_attention_aoti(self):
# TODO.
pass
# Self attention.

# test with kv cache
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
with torch.no_grad():
so = torch._export.aot_compile(
self.et_mha,
args=(self.x, self.x),
kwargs={"input_pos": self.input_pos},
options={"aot_inductor.package": True},
dynamic_shapes=self.dynamic_shapes,
)
with tempfile.TemporaryDirectory() as tempdir:
path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
mha_aoti = load_package(path)

aoti_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
assert_close(aoti_res, tt_res)

def test_attention_executorch(self):
# Self attention.
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.x),
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
)
# TODO: Fix kv cache
# self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
# self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)

with torch.no_grad():
et_mha_ep = torch.export.export(
self.et_mha,
(self.x, self.x),
kwargs={"input_pos": self.input_pos},
dynamic_shapes=self.dynamic_shapes,
)
et_program = to_edge(
et_mha_ep,
compile_config=EdgeCompileConfig(),
compile_config=EdgeCompileConfig(
_core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
),
).to_executorch()
runtime = Runtime.get()
program = runtime.load_program(et_program.buffer)
Expand All @@ -166,5 +197,3 @@ def test_attention_executorch(self):
tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

assert_close(et_res[0], tt_res)

# TODO: KV cache.
Loading