13
13
MultiHeadAttention as ETMultiHeadAttention ,
14
14
)
15
15
from executorch .runtime import Runtime
16
+ from torch .testing import assert_close
16
17
from torchtune .models .llama3_1 ._position_embeddings import Llama3ScaledRoPE
17
18
from torchtune .modules .attention import MultiHeadAttention as TTMultiHeadAttention
18
19
@@ -92,8 +93,9 @@ def test_attention_eager(self):
92
93
et_res = self .et_mha (self .x , self .x ) # Self attention.
93
94
tt_res = self .tt_mha (self .x , self .x ) # Self attention.
94
95
95
- self .assertTrue (
96
- torch .allclose (et_res , tt_res ),
96
+ assert_close (
97
+ et_res ,
98
+ tt_res ,
97
99
msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
98
100
)
99
101
@@ -104,7 +106,11 @@ def test_attention_eager(self):
104
106
# et_res = self.et_mha(self.x, self.x) # Self attention.
105
107
# tt_res = self.tt_mha(self.x, self.x) # Self attention.
106
108
107
- # self.assertTrue(torch.allclose(et_res, tt_res))
109
+ # assert_close(
110
+ # et_res,
111
+ # tt_res,
112
+ # msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}"
113
+ # )
108
114
109
115
def test_attention_export (self ):
110
116
# Self attention.
@@ -116,8 +122,10 @@ def test_attention_export(self):
116
122
)
117
123
et_res = et_mha_ep .module ()(self .x , self .x )
118
124
tt_res = self .tt_mha (self .x , self .x )
119
- self .assertTrue (
120
- torch .allclose (et_res , tt_res ),
125
+
126
+ assert_close (
127
+ et_res ,
128
+ tt_res ,
121
129
msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
122
130
)
123
131
@@ -145,8 +153,9 @@ def test_attention_executorch(self):
145
153
et_res = method .execute ((self .x , self .x ))
146
154
tt_res = self .tt_mha (self .x , self .x )
147
155
148
- self .assertTrue (
149
- torch .allclose (et_res [0 ], tt_res , atol = 1e-05 ),
156
+ assert_close (
157
+ et_res [0 ],
158
+ tt_res ,
150
159
msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res [0 ]} " ,
151
160
)
152
161
0 commit comments