13
13
MultiHeadAttention as ETMultiHeadAttention ,
14
14
)
15
15
from executorch .runtime import Runtime
16
+ from torch .testing import assert_close
16
17
from torchtune .models .llama3_1 ._position_embeddings import Llama3ScaledRoPE
17
18
from torchtune .modules .attention import MultiHeadAttention as TTMultiHeadAttention
18
19
@@ -94,8 +95,9 @@ def test_attention_eager(self):
94
95
et_res = self .et_mha (self .x , self .x ) # Self attention.
95
96
tt_res = self .tt_mha (self .x , self .x ) # Self attention.
96
97
97
- self .assertTrue (
98
- torch .allclose (et_res , tt_res ),
98
+ assert_close (
99
+ et_res ,
100
+ tt_res ,
99
101
msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
100
102
)
101
103
@@ -127,7 +129,12 @@ def test_attention_eager(self):
127
129
tt_res = self .tt_mha (
128
130
self .x , self .x , input_pos = next_input_pos
129
131
) # Self attention with input pos.
130
- self .assertTrue (torch .allclose (et_res , tt_res ))
132
+
133
+ assert_close (
134
+ et_res ,
135
+ tt_res ,
136
+ msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
137
+ )
131
138
132
139
def test_attention_export (self ):
133
140
# Self attention.
@@ -139,8 +146,10 @@ def test_attention_export(self):
139
146
)
140
147
et_res = et_mha_ep .module ()(self .x , self .x , input_pos = self .input_pos )
141
148
tt_res = self .tt_mha (self .x , self .x , input_pos = self .input_pos )
142
- self .assertTrue (
143
- torch .allclose (et_res , tt_res ),
149
+
150
+ assert_close (
151
+ et_res ,
152
+ tt_res ,
144
153
msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
145
154
)
146
155
@@ -168,8 +177,9 @@ def test_attention_executorch(self):
168
177
et_res = method .execute ((self .x , self .x , self .input_pos ))
169
178
tt_res = self .tt_mha (self .x , self .x , input_pos = self .input_pos )
170
179
171
- self .assertTrue (
172
- torch .allclose (et_res [0 ], tt_res , atol = 1e-05 ),
180
+ assert_close (
181
+ et_res [0 ],
182
+ tt_res ,
173
183
msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res [0 ]} " ,
174
184
)
175
185
0 commit comments