@@ -94,7 +94,10 @@ def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.
 
-        self.assertTrue(torch.allclose(et_res, tt_res))
+        self.assertTrue(
+            torch.allclose(et_res, tt_res),
+            msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
+        )
 
         # test with kv cache
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
@@ -136,7 +139,10 @@ def test_attention_export(self):
         )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
-        self.assertTrue(torch.allclose(et_res, tt_res))
+        self.assertTrue(
+            torch.allclose(et_res, tt_res),
+            msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
+        )
 
         # TODO: KV cache.
 
@@ -162,6 +168,9 @@ def test_attention_executorch(self):
         et_res = method.execute((self.x, self.x, self.input_pos))
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
 
-        self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
+        self.assertTrue(
+            torch.allclose(et_res[0], tt_res, atol=1e-05),
+            msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res[0]}",
+        )
 
         # TODO: KV cache.
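
Note: each hunk replaces a bare assertTrue(torch.allclose(...)) with a form that passes msg=, so a failing comparison prints both tensors instead of an opaque "False is not true"; the last hunk also relaxes atol from 1e-06 to 1e-05. A minimal standalone sketch of the same pattern follows; the test class, tensor shapes, and values here are illustrative stand-ins, not part of this PR:

import unittest

import torch


class TestAllcloseWithMessage(unittest.TestCase):
    """Illustrative only: mirrors the assertion pattern used in the diff."""

    def test_outputs_match(self):
        # Stand-ins for the TorchTune and ExecuTorch attention outputs.
        tt_res = torch.ones(1, 3)
        et_res = torch.ones(1, 3)

        # On failure, unittest appends msg to the error, so both tensors
        # show up in the test log alongside the assertion failure.
        self.assertTrue(
            torch.allclose(et_res, tt_res, atol=1e-05),
            msg=f"TorchTune output is not close to ET output.\n\nTorchTune: {tt_res}\nET output: {et_res}",
        )


if __name__ == "__main__":
    unittest.main()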