Skip to content

Commit d066a25

Browse files
jackzhxnglarryliu0820
authored andcommitted
Use assert_close
1 parent 15f0fdf commit d066a25

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

extension/llm/modules/test/test_attention.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
MultiHeadAttention as ETMultiHeadAttention,
1414
)
1515
from executorch.runtime import Runtime
16+
from torch.testing import assert_close
1617
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
1718
from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
1819

@@ -94,8 +95,9 @@ def test_attention_eager(self):
9495
et_res = self.et_mha(self.x, self.x) # Self attention.
9596
tt_res = self.tt_mha(self.x, self.x) # Self attention.
9697

97-
self.assertTrue(
98-
torch.allclose(et_res, tt_res),
98+
assert_close(
99+
et_res,
100+
tt_res,
99101
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
100102
)
101103

@@ -127,7 +129,12 @@ def test_attention_eager(self):
127129
tt_res = self.tt_mha(
128130
self.x, self.x, input_pos=next_input_pos
129131
) # Self attention with input pos.
130-
self.assertTrue(torch.allclose(et_res, tt_res))
132+
133+
assert_close(
134+
et_res,
135+
tt_res,
136+
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
137+
)
131138

132139
def test_attention_export(self):
133140
# Self attention.
@@ -139,8 +146,10 @@ def test_attention_export(self):
139146
)
140147
et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
141148
tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
142-
self.assertTrue(
143-
torch.allclose(et_res, tt_res),
149+
150+
assert_close(
151+
et_res,
152+
tt_res,
144153
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
145154
)
146155

@@ -168,8 +177,9 @@ def test_attention_executorch(self):
168177
et_res = method.execute((self.x, self.x, self.input_pos))
169178
tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
170179

171-
self.assertTrue(
172-
torch.allclose(et_res[0], tt_res, atol=1e-05),
180+
assert_close(
181+
et_res[0],
182+
tt_res,
173183
msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res[0]}",
174184
)
175185

0 commit comments

Comments
 (0)