@@ -95,11 +95,7 @@ def test_attention_eager(self):
95
95
et_res = self .et_mha (self .x , self .x ) # Self attention.
96
96
tt_res = self .tt_mha (self .x , self .x ) # Self attention.
97
97
98
- assert_close (
99
- et_res ,
100
- tt_res ,
101
- msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
102
- )
98
+ assert_close (et_res , tt_res )
103
99
104
100
# test with kv cache
105
101
self .et_mha .setup_cache (1 , dtype = torch .float32 , max_seq_len = 20 )
@@ -130,11 +126,7 @@ def test_attention_eager(self):
130
126
self .x , self .x , input_pos = next_input_pos
131
127
) # Self attention with input pos.
132
128
133
- assert_close (
134
- et_res ,
135
- tt_res ,
136
- msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
137
- )
129
+ assert_close (et_res , tt_res )
138
130
139
131
def test_attention_export (self ):
140
132
# Self attention.
@@ -147,11 +139,7 @@ def test_attention_export(self):
147
139
et_res = et_mha_ep .module ()(self .x , self .x , input_pos = self .input_pos )
148
140
tt_res = self .tt_mha (self .x , self .x , input_pos = self .input_pos )
149
141
150
- assert_close (
151
- et_res ,
152
- tt_res ,
153
- msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res } " ,
154
- )
142
+ assert_close (et_res , tt_res )
155
143
156
144
# TODO: KV cache.
157
145
@@ -177,10 +165,6 @@ def test_attention_executorch(self):
177
165
et_res = method .execute ((self .x , self .x , self .input_pos ))
178
166
tt_res = self .tt_mha (self .x , self .x , input_pos = self .input_pos )
179
167
180
- assert_close (
181
- et_res [0 ],
182
- tt_res ,
183
- msg = f"TorchTune output is not close to ET output.\n \n TorchTune: { tt_res } \n ET output: { et_res [0 ]} " ,
184
- )
168
+ assert_close (et_res [0 ], tt_res )
185
169
186
170
# TODO: KV cache.
0 commit comments