
Commit f943856

Fix flaky ET attention test (#6795)
* Fix flaky ET attention test
* Use assert_close
* Remove msg from assert_close

Co-authored-by: Mengwei Liu <[email protected]>
1 parent b6ebd3c commit f943856
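For context, here is a minimal sketch (not part of the commit; the tensors are made up) of what the switch from self.assertTrue(torch.allclose(...)) to torch.testing.assert_close(...) buys: allclose returns a bare bool and uses the same fixed tolerances (rtol=1e-05, atol=1e-08) for every dtype, so a failing assertTrue reports nothing about how far the tensors diverged, while assert_close picks dtype-aware default tolerances and raises an AssertionError describing the mismatch.

import torch
from torch.testing import assert_close

# Two attention outputs that agree up to small numerical noise, standing
# in for the ExecuTorch and torchtune results compared in the test.
et_res = torch.randn(1, 8, 64)
tt_res = et_res + 1e-7 * torch.randn_like(et_res)

# Old style: a bare bool with fixed tolerances; on failure the test only
# prints "False is not true".
assert torch.allclose(et_res, tt_res, atol=1e-06)

# New style: for float32 the defaults are rtol=1.3e-06, atol=1e-05; on
# failure assert_close reports the number of mismatched elements and the
# greatest absolute and relative differences.
assert_close(et_res, tt_res)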

File tree

2 files changed (+7, -5 lines)


extension/llm/modules/test/test_attention.py (7 additions, 4 deletions)

@@ -13,6 +13,7 @@
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
+from torch.testing import assert_close
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
 
@@ -94,7 +95,7 @@ def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.
 
-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)
 
         # test with kv cache
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
@@ -124,7 +125,8 @@ def test_attention_eager(self):
         tt_res = self.tt_mha(
             self.x, self.x, input_pos=next_input_pos
         )  # Self attention with input pos.
-        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        assert_close(et_res, tt_res)
 
     def test_attention_export(self):
         # Self attention.
@@ -136,7 +138,8 @@ def test_attention_export(self):
         )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
-        self.assertTrue(torch.allclose(et_res, tt_res))
+
+        assert_close(et_res, tt_res)
 
         # TODO: KV cache.
 
@@ -162,6 +165,6 @@ def test_attention_executorch(self):
         et_res = method.execute((self.x, self.x, self.input_pos))
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
-        self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
+        assert_close(et_res[0], tt_res)
 
         # TODO: KV cache.
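A note on the last hunk above: the old check passed an explicit atol=1e-06 to torch.allclose, while the new assert_close call relies on its float32 defaults (rtol=1.3e-06, atol=1e-05). Since the default atol is an order of magnitude looser than the old override, this presumably gives the comparison enough slack to absorb the run-to-run numerical noise that made the test flaky.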

pytest.ini (0 additions, 1 deletion)

@@ -39,7 +39,6 @@ addopts =
     backends/xnnpack/test
     # extension/
     extension/llm/modules/test
-    --ignore=extension/llm/modules/test/test_mha.py
     extension/pybindings/test
     # Runtime
     runtime
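For context on the pytest.ini change: paths listed under addopts are collected by default, and an --ignore entry excludes a single file from collection, so deleting that line re-enables the previously skipped test_mha.py. A minimal sketch of the pattern (a fragment for illustration, not the full file):

[pytest]
addopts =
    # directories to collect
    extension/llm/modules/test
    # exclude one file from collection (the line this commit deletes)
    --ignore=extension/llm/modules/test/test_mha.py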
