
Commit 05308af

[llama-mm] Add unit tests for exporting MultiHeadAttention with KVCache

Summary: Ensure that MultiHeadAttention can always be exported, including with the KV cache enabled.

1 parent 9cd57d6 commit 05308af
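For orientation, here is a minimal sketch of the pattern the new test_attention_export exercises: build the ExecuTorch MultiHeadAttention, enable its KV cache via setup_cache, and run torch.export.export with input_pos passed as a keyword argument. The import path, projection layers, dimensions, and input shapes below are illustrative assumptions rather than the test's actual setUp, and the dynamic sequence-length shapes that the test registers are omitted for brevity.

    import torch
    # Import path assumed; the test aliases this class as ETMultiHeadAttention.
    from executorch.extension.llm.modules.attention import (
        MultiHeadAttention as ETMultiHeadAttention,
    )
    from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE

    # Illustrative dimensions; embed_dim must equal num_heads * head_dim.
    embed_dim, num_heads, num_kv_heads, head_dim, max_seq_len = 2048, 32, 8, 64, 100

    et_mha = ETMultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=torch.nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=torch.nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=torch.nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=torch.nn.Linear(embed_dim, embed_dim, bias=False),
        pos_embeddings=Llama3ScaledRoPE(dim=head_dim, max_seq_len=max_seq_len),
        max_seq_len=max_seq_len,
    )

    # Enable the KV cache before export, exactly as the test does.
    et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=max_seq_len)

    seq_len = 10
    x = torch.randn(1, seq_len, embed_dim)          # [batch, seq_len, embed_dim]
    input_pos = torch.arange(seq_len).unsqueeze(0)  # [batch, seq_len]

    with torch.no_grad():
        et_mha_ep = torch.export.export(
            et_mha,
            (x, x),
            kwargs={"input_pos": input_pos},
        )

    # The exported program can then be called like the eager module.
    out = et_mha_ep.module()(x, x, input_pos=input_pos)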

File tree

1 file changed (+48, -19 lines)


extension/llm/modules/test/test_attention.py

Lines changed: 48 additions & 19 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import os
+import tempfile
 import unittest

 import torch
@@ -13,6 +15,7 @@
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
+from torch._inductor.package import load_package, package_aoti
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention

@@ -128,33 +131,61 @@ def test_attention_eager(self):

     def test_attention_export(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
         self.assertTrue(torch.allclose(et_res, tt_res))

-        # TODO: KV cache.
-
     def test_attention_aoti(self):
-        # TODO.
-        pass
+        # Self attention.
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            so = torch._export.aot_compile(
+                self.et_mha,
+                args=(self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                options={"aot_inductor.package": True},
+                dynamic_shapes=self.dynamic_shapes,
+            )
+        with tempfile.TemporaryDirectory() as tempdir:
+            path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
+            mha_aoti = load_package(path)
+
+            et_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
+            tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
+            self.assertTrue(torch.allclose(et_res, tt_res))

     def test_attention_executorch(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+        # TODO: Fix kv cache
+        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_program = to_edge(
             et_mha_ep,
-            compile_config=EdgeCompileConfig(),
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+            ),
         ).to_executorch()
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
@@ -163,5 +194,3 @@ def test_attention_executorch(self):
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

         self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
-
-        # TODO: KV cache.
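The executorch path in test_attention_executorch lowers the exported program with to_edge and to_executorch and loads it through the Python Runtime; the actual execution call falls outside the diff context above. Continuing the sketch from the commit header, here is a hedged approximation of that flow, assuming the executorch.runtime load_method/execute API and that exported kwargs are flattened after the positional inputs:

    import torch
    from executorch.exir import EdgeCompileConfig, to_edge
    from executorch.runtime import Runtime

    # Lower the exported program (et_mha_ep from the sketch above) to ExecuTorch.
    # The exception list mirrors the diff; presumably export with the KV cache
    # emits aten._assert_async.msg, which is not a core ATen op.
    et_program = to_edge(
        et_mha_ep,
        compile_config=EdgeCompileConfig(
            _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
        ),
    ).to_executorch()

    # Load and run the serialized program with the Python runtime.
    runtime = Runtime.get()
    program = runtime.load_program(et_program.buffer)
    method = program.load_method("forward")
    # Assumption: exported kwargs are flattened after the positional args,
    # so inputs are passed as (x, x, input_pos).
    et_res = method.execute((x, x, input_pos))

    # Compare against eager execution of the exported module, analogous to how
    # the test compares against the torchtune reference.
    eager_res = et_mha_ep.module()(x, x, input_pos=input_pos)
    assert torch.allclose(et_res[0], eager_res, atol=1e-06)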
