Skip to content

Commit 3cd4a8b

Browse files
committed
Use assert_close
1 parent 6d61e36 commit 3cd4a8b

File tree

1 file changed

+16
-7
lines changed

1 file changed

+16
-7
lines changed

extension/llm/modules/test/test_mha.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
MultiHeadAttention as ETMultiHeadAttention,
1414
)
1515
from executorch.runtime import Runtime
16+
from torch.testing import assert_close
1617
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
1718
from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
1819

@@ -92,8 +93,9 @@ def test_attention_eager(self):
9293
et_res = self.et_mha(self.x, self.x) # Self attention.
9394
tt_res = self.tt_mha(self.x, self.x) # Self attention.
9495

95-
self.assertTrue(
96-
torch.allclose(et_res, tt_res),
96+
assert_close(
97+
et_res,
98+
tt_res,
9799
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
98100
)
99101

@@ -104,7 +106,11 @@ def test_attention_eager(self):
104106
# et_res = self.et_mha(self.x, self.x) # Self attention.
105107
# tt_res = self.tt_mha(self.x, self.x) # Self attention.
106108

107-
# self.assertTrue(torch.allclose(et_res, tt_res))
109+
# assert_close(
110+
# et_res,
111+
# tt_res,
112+
# msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}"
113+
# )
108114

109115
def test_attention_export(self):
110116
# Self attention.
@@ -116,8 +122,10 @@ def test_attention_export(self):
116122
)
117123
et_res = et_mha_ep.module()(self.x, self.x)
118124
tt_res = self.tt_mha(self.x, self.x)
119-
self.assertTrue(
120-
torch.allclose(et_res, tt_res),
125+
126+
assert_close(
127+
et_res,
128+
tt_res,
121129
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
122130
)
123131

@@ -145,8 +153,9 @@ def test_attention_executorch(self):
145153
et_res = method.execute((self.x, self.x))
146154
tt_res = self.tt_mha(self.x, self.x)
147155

148-
self.assertTrue(
149-
torch.allclose(et_res[0], tt_res, atol=1e-05),
156+
assert_close(
157+
et_res[0],
158+
tt_res,
150159
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res[0]}",
151160
)
152161

0 commit comments

Comments
 (0)