
Commit a2b7ee3

Update attention test
1 parent 73591f1 commit a2b7ee3

File tree

1 file changed: 18 additions, 16 deletions

extension/llm/modules/test/test_attention.py

Lines changed: 18 additions & 16 deletions
@@ -11,6 +11,8 @@
 import torch
 from executorch.exir import EdgeCompileConfig, to_edge

+from executorch.exir.capture._config import ExecutorchBackendConfig
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 from executorch.extension.llm.modules.attention import (
     MultiHeadAttention as ETMultiHeadAttention,
 )
@@ -114,7 +116,7 @@ def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)
         self.et_mha.reset_cache()
         self.tt_mha.reset_cache()

@@ -125,7 +127,7 @@ def test_attention_eager(self):
             self.x, self.x, input_pos=self.input_pos
         )  # Self attention with input pos.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)

         # test kv cache read. Input pos can be [10, 11, ..., 19]
         next_input_pos = torch.arange(10, 20).unsqueeze(0)
@@ -187,9 +189,8 @@ def test_attention_aoti(self):

     def test_attention_executorch(self):
         # Self attention.
-        # TODO: Fix kv cache
-        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)

         with torch.no_grad():
             et_mha_ep = torch.export.export(
@@ -202,9 +203,15 @@ def test_attention_executorch(self):
         et_program = to_edge(
             et_mha_ep,
             compile_config=EdgeCompileConfig(
-                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+                _check_ir_validity=False,
             ),
-        ).to_executorch()
+        ).to_executorch(
+            config=ExecutorchBackendConfig(
+                passes=[InitializedMutableBufferPass(["cache_pos"])],
+            )
+        )
+
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
         method = program.load_method("forward")
@@ -219,28 +226,23 @@ def test_attention_torch_cond_eager(self):
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

-        # mask
         mask = self.causal_mask[self.input_pos, :]
-        # First run
+        # First run.
         et_res = self.et_mha(
             self.x, self.x, mask=mask, input_pos=self.input_pos
         )  # Self attention with input pos.
         tt_res = self.tt_mha(
             self.x, self.x, mask=mask, input_pos=self.input_pos
         )  # Self attention with input pos.

-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)

         # Second run test kv cache read. Input pos is [10, 11, ..., 19]
         next_input_pos = torch.arange(10, 20).unsqueeze(0)

         empty_y = torch.full_like(self.x, torch.nan)
         mask = self.causal_mask[next_input_pos, :]
-        et_res = self.et_mha(
-            self.x, empty_y, mask=mask, input_pos=next_input_pos
-        )  # Self attention with input pos.
-        tt_res = self.tt_mha(
-            self.x, None, mask=mask, input_pos=next_input_pos
-        )  # Self attention with input pos.
+        et_res = self.et_mha(self.x, empty_y, mask=mask, input_pos=next_input_pos)
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

         assert_close(et_res, tt_res)
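The remaining changes in this commit swap self.assertTrue(torch.allclose(...)) for assert_close, which fails with a detailed mismatch report instead of a bare boolean assertion. A minimal usage sketch, assuming the test imports it from torch.testing (the import is not visible in this diff):

import torch
from torch.testing import assert_close  # assumed source of assert_close

et_res = torch.randn(1, 10, 32)
tt_res = et_res.clone()

# Passes silently when the tensors match within the default dtype-based
# rtol/atol; otherwise raises AssertionError listing the mismatched elements.
assert_close(et_res, tt_res)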
