@@ -92,7 +92,10 @@ def test_attention_eager(self):
92
92
et_res = self .et_mha (self .x , self .x ) # Self attention.
93
93
tt_res = self .tt_mha (self .x , self .x ) # Self attention.
94
94
95
- self .assertTrue (torch .allclose (et_res , tt_res ))
95
+ self .assertTrue (
96
+ torch .allclose (et_res , tt_res ),
97
+ msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
98
+ )
96
99
97
100
# TODO: KV cache.
98
101
# self.et_mha.setup_cache(1, dtype=torch.float16, max_seq_len=20)
@@ -113,7 +116,10 @@ def test_attention_export(self):
113
116
)
114
117
et_res = et_mha_ep .module ()(self .x , self .x )
115
118
tt_res = self .tt_mha (self .x , self .x )
116
- self .assertTrue (torch .allclose (et_res , tt_res ))
119
+ self .assertTrue (
120
+ torch .allclose (et_res , tt_res ),
121
+ msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
122
+ )
117
123
118
124
# TODO: KV cache.
119
125
@@ -139,6 +145,9 @@ def test_attention_executorch(self):
139
145
et_res = method .execute ((self .x , self .x ))
140
146
tt_res = self .tt_mha (self .x , self .x )
141
147
142
- self .assertTrue (torch .allclose (et_res [0 ], tt_res , atol = 1e-06 ))
148
+ self .assertTrue (
149
+ torch .allclose (et_res [0 ], tt_res , atol = 1e-05 ),
150
+ msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res [0 ]} " ,
151
+ )
143
152
144
153
# TODO: KV cache.
0 commit comments