Skip to content

Commit 15f0fdf

Browse files
jackzhxnglarryliu0820
authored andcommitted
Fix flaky ET attention test
1 parent 9cd57d6 commit 15f0fdf

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

extension/llm/modules/test/test_attention.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,10 @@ def test_attention_eager(self):
9494
et_res = self.et_mha(self.x, self.x) # Self attention.
9595
tt_res = self.tt_mha(self.x, self.x) # Self attention.
9696

97-
self.assertTrue(torch.allclose(et_res, tt_res))
97+
self.assertTrue(
98+
torch.allclose(et_res, tt_res),
99+
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
100+
)
98101

99102
# test with kv cache
100103
self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
@@ -136,7 +139,10 @@ def test_attention_export(self):
136139
)
137140
et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
138141
tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
139-
self.assertTrue(torch.allclose(et_res, tt_res))
142+
self.assertTrue(
143+
torch.allclose(et_res, tt_res),
144+
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
145+
)
140146

141147
# TODO: KV cache.
142148

@@ -162,6 +168,9 @@ def test_attention_executorch(self):
162168
et_res = method.execute((self.x, self.x, self.input_pos))
163169
tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
164170

165-
self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
171+
self.assertTrue(
172+
torch.allclose(et_res[0], tt_res, atol=1e-05),
173+
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res[0]}",
174+
)
166175

167176
# TODO: KV cache.

pytest.ini

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ addopts =
3939
backends/xnnpack/test
4040
# extension/
4141
extension/llm/modules/test
42-
--ignore=extension/llm/modules/test/test_mha.py
4342
extension/pybindings/test
4443
# Runtime
4544
runtime

0 commit comments

Comments
 (0)