Commit 3fd9963

[llama-mm] Add unit tests for exporting MultiHeadAttention with KVCache
Summary: Make sure we can always export MultiHeadAttention with KVCache.
1 parent f943856 commit 3fd9963

File tree: 2 files changed (+50, -22 lines)

extension/llm/modules/kv_cache.py

Lines changed: 2 additions & 3 deletions
@@ -105,9 +105,8 @@ def update(
             f", but found new key tensors with batch size {k_val.shape[0]}!"
         )
 
-        assert (
-            self.cache_pos[0] + seq_len
-        ) <= self.max_seq_len, f"self.cache_pos[0]: {self.cache_pos[0]} + seq_len: {seq_len} > self.max_seq_len: {self.max_seq_len}"
+        assert (self.cache_pos[0] + seq_len) <= self.max_seq_len
+
         k_out = self.k_cache
         v_out = self.v_cache
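A note on this kv_cache.py change (an inference from the diff, not stated in the commit message): the dropped f-string message interpolates self.cache_pos[0], a tensor value that cannot be formatted while tracing, whereas a bare assert on a tensor condition can be captured by torch.export as an assert op, which the ExecuTorch test below then allowlists via _core_aten_ops_exception_list. A minimal sketch of that behavior, assuming strict (Dynamo-based) export; the Guarded module and shapes are made up for illustration:

import torch


class Guarded(torch.nn.Module):
    # Hypothetical stand-in for the KV-cache bounds check: the asserted
    # condition depends on a tensor value, like cache_pos[0] + seq_len.
    def forward(self, x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
        assert pos[0] + x.shape[1] <= 16
        return x + pos[0]


# strict=True uses the Dynamo tracer, which rewrites the plain assert into
# an assert op instead of trying to evaluate the tensor condition eagerly.
ep = torch.export.export(
    Guarded(), (torch.ones(1, 4), torch.tensor([0])), strict=True
)
# The graph should contain an _assert_async-style node (hence the
# _core_aten_ops_exception_list entry in the ExecuTorch test below).
print(ep.graph_module.code)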

extension/llm/modules/test/test_attention.py

Lines changed: 48 additions & 19 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import os
+import tempfile
 import unittest
 
 import torch
@@ -13,6 +15,7 @@
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
+from torch._inductor.package import load_package, package_aoti
 from torch.testing import assert_close
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
@@ -130,34 +133,62 @@ def test_attention_eager(self):
 
     def test_attention_export(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
         assert_close(et_res, tt_res)
 
-        # TODO: KV cache.
-
     def test_attention_aoti(self):
-        # TODO.
-        pass
+        # Self attention.
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            so = torch._export.aot_compile(
+                self.et_mha,
+                args=(self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                options={"aot_inductor.package": True},
+                dynamic_shapes=self.dynamic_shapes,
+            )
+        with tempfile.TemporaryDirectory() as tempdir:
+            path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
+            mha_aoti = load_package(path)
+
+            et_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
+            tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
+            self.assertTrue(torch.allclose(et_res, tt_res))
 
     def test_attention_executorch(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+        # TODO: Fix kv cache
+        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_program = to_edge(
             et_mha_ep,
-            compile_config=EdgeCompileConfig(),
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+            ),
         ).to_executorch()
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
@@ -166,5 +197,3 @@ def test_attention_executorch(self):
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
         assert_close(et_res[0], tt_res)
-
-        # TODO: KV cache.
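For context on the new test_attention_aoti path: torch._export.aot_compile with the "aot_inductor.package" option produces AOTInductor artifacts, package_aoti bundles them into a .pt2 archive, and load_package turns that archive back into a callable. A stripped-down sketch of the same flow, using a made-up Toy module in place of MultiHeadAttention (it needs a working Inductor toolchain to actually compile):

import os
import tempfile

import torch
from torch._inductor.package import load_package, package_aoti


class Toy(torch.nn.Module):
    # Hypothetical module, standing in for the attention layer in the test.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(x) + 1


with torch.no_grad():
    # Compile ahead of time; "aot_inductor.package": True makes aot_compile
    # return artifacts suitable for package_aoti, as in the test above.
    so = torch._export.aot_compile(
        Toy(),
        args=(torch.randn(4, 8),),
        options={"aot_inductor.package": True},
    )

with tempfile.TemporaryDirectory() as tmpdir:
    # Bundle the compiled artifacts into a .pt2 package, then reload the
    # package as a callable that runs the compiled code.
    pt2 = package_aoti(os.path.join(tmpdir, "toy.pt2"), so)
    toy_aoti = load_package(pt2)
    out = toy_aoti(torch.randn(4, 8))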
