@@ -33,6 +33,7 @@ def setUp(self):
         self.num_kv_heads = 8
         self.head_dim = 64
         self.max_seq_len = 128
+        self.encoder_max_seq_len = 128
         self.rope_base = 500_000
         self.scale_factor = 32

@@ -86,16 +87,26 @@ def setUp(self):
             max_seq_len=self.max_seq_len,
         )
         self.et_mha.load_state_dict(self.tt_mha.state_dict())
+
         # Common inputs.
         seq_len = 10
         self.x = torch.randn(1, seq_len, self.embed_dim)
+        self.y = torch.randn(1, seq_len, self.embed_dim)
         self.input_pos = torch.arange(seq_len).unsqueeze(0)  # shape [1, seq_len]
-        seq_len_dim = torch.export.Dim("seq_len", min=1, max=100)
-        self.dynamic_shapes = (
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim, 2: torch.export.Dim.STATIC},
-            {0: torch.export.Dim.STATIC, 1: seq_len_dim},
-        )
+        self.seq_len_dim = torch.export.Dim("seq_len", min=1, max=self.max_seq_len)
+        self.dynamic_shapes = {
+            "x": {
+                0: torch.export.Dim.STATIC,
+                1: self.seq_len_dim,
+                2: torch.export.Dim.STATIC,
+            },
+            "y": {
+                0: torch.export.Dim.STATIC,
+                1: self.seq_len_dim,
+                2: torch.export.Dim.STATIC,
+            },
+            "input_pos": {0: torch.export.Dim.STATIC, 1: self.seq_len_dim},
+        }
         self.causal_mask = torch.tril(
             torch.ones(
                 size=(self.max_seq_len, self.max_seq_len),
@@ -110,8 +121,8 @@ def test_attention_eager(self):
         assert_close(et_res, tt_res)

         # test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=20)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.
@@ -144,12 +155,12 @@ def test_attention_export(self):
         # Self attention.

         # test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         with torch.no_grad():
             et_mha_ep = torch.export.export(
                 self.et_mha,
-                (self.x, self.x),
+                (self.x, self.y),
                 kwargs={"input_pos": self.input_pos},
                 dynamic_shapes=self.dynamic_shapes,
                 strict=True,
@@ -166,8 +177,8 @@ def test_attention_aoti(self):
         # Self attention.

         # test with kv cache
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         with torch.no_grad():
             so = torch._export.aot_compile(
                 self.et_mha,
@@ -189,13 +200,13 @@ def test_attention_aoti(self):

     def test_attention_executorch(self):
         # Self attention.
-        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

         with torch.no_grad():
             et_mha_ep = torch.export.export(
                 self.et_mha,
-                (self.x, self.x),
+                (self.x, self.y),
                 kwargs={"input_pos": self.input_pos},
                 dynamic_shapes=self.dynamic_shapes,
                 strict=True,
@@ -222,22 +233,18 @@ def test_attention_executorch(self):

     def test_attention_torch_cond_eager(self):
         # Different from vanilla torchtune MHA, we rewrite the if condition with torch.cond. We need to make sure they are giving the same results regarding the if condition.
-        # For the first run of MHA we provide `y` (self.x) but for the second run it will be a tensor full of nan.
+        # For the first run of MHA we provide `y` but for the second run it will be a tensor full of nan.
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)

         mask = self.causal_mask[self.input_pos, :]
         # First run.
-        et_res = self.et_mha(
-            self.x, self.x, mask=mask, input_pos=self.input_pos
-        )  # Self attention with input pos.
-        tt_res = self.tt_mha(
-            self.x, self.x, mask=mask, input_pos=self.input_pos
-        )  # Self attention with input pos.
+        et_res = self.et_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)

         assert_close(et_res, tt_res)

-        # Second run test kv cache read. Input pos is [10, 11, ..., 19]
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
         next_input_pos = torch.arange(10, 20).unsqueeze(0)

         empty_y = torch.full_like(self.x, torch.nan)
@@ -246,3 +253,101 @@ def test_attention_torch_cond_eager(self):
         tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)

         assert_close(et_res, tt_res)
+
+    def test_attention_torch_cond_export(self):
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        mask = self.causal_mask[self.input_pos, :]
+        dynamic_shapes = {
+            **self.dynamic_shapes,
+            **{
+                "mask": {
+                    0: torch.export.Dim.STATIC,
+                    1: self.seq_len_dim,
+                    2: torch.export.Dim.STATIC,
+                }
+            },
+        }
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.y),
+                kwargs={
+                    "mask": mask,
+                    "input_pos": self.input_pos,
+                },
+                dynamic_shapes=dynamic_shapes,
+                strict=True,
+            )
+
+        # First run.
+        et_res = et_mha_ep.module()(self.x, self.y, mask=mask, input_pos=self.input_pos)
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+
+        assert_close(et_res, tt_res)
+
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+        empty_y = torch.full_like(self.y, torch.nan)
+        mask = self.causal_mask[next_input_pos, :]
+        et_res = et_mha_ep.module()(
+            self.x, empty_y, mask=mask, input_pos=next_input_pos
+        )
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)
+
+        assert_close(et_res, tt_res)
+
+    def test_attention_torch_cond_executorch(self):
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
+        mask = self.causal_mask[self.input_pos, :]
+        dynamic_shapes = {
+            **self.dynamic_shapes,
+            **{
+                "mask": {
+                    0: torch.export.Dim.STATIC,
+                    1: self.seq_len_dim,
+                    2: torch.export.Dim.STATIC,
+                }
+            },
+        }
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.y),
+                kwargs={
+                    "mask": mask,
+                    "input_pos": self.input_pos,
+                },
+                dynamic_shapes=dynamic_shapes,
+                strict=True,
+            )
+        et_program = to_edge(
+            et_mha_ep,
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+                _check_ir_validity=False,
+            ),
+        ).to_executorch(
+            config=ExecutorchBackendConfig(
+                passes=[InitializedMutableBufferPass(["cache_pos"])],
+            )
+        )
+
+        # First run.
+        runtime = Runtime.get()
+        program = runtime.load_program(et_program.buffer)
+        method = program.load_method("forward")
+        et_res = method.execute((self.x, self.y, mask, self.input_pos))
+        tt_res = self.tt_mha(self.x, self.y, mask=mask, input_pos=self.input_pos)
+
+        assert_close(et_res[0], tt_res)
+
+        # Second run tests kv cache read. Input pos is [10, 11, ..., 19]
+        next_input_pos = torch.arange(10, 20).unsqueeze(0)
+        empty_y = torch.full_like(self.y, torch.nan)
+        mask = self.causal_mask[next_input_pos, :]
+        et_res = method.execute((self.x, empty_y, mask, next_input_pos))
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)
+
+        assert_close(et_res[0], tt_res)
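Note: the `test_attention_torch_cond_*` tests above exercise an MHA whose data-dependent "is there new encoder input `y`?" branch has been rewritten with `torch.cond`, with a NaN-filled `y` serving as the "no new input, reuse the cache" placeholder so both branches stay traceable under `torch.export`. The snippet below is a minimal, self-contained sketch of that pattern for readers unfamiliar with `torch.cond`; the function name `cross_attn_kv`, the shapes, and the NaN check are illustrative assumptions and are not part of this diff or of the actual attention module.

# Illustrative sketch only -- not code from this commit. Toy version of the
# "write new encoder KV vs. reuse cached KV" decision expressed with torch.cond.
import torch


def cross_attn_kv(y: torch.Tensor, k_cache: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    def write_branch(y, k_cache):
        # Real encoder output arrived: copy it into the cache at input_pos.
        return k_cache.index_copy(1, input_pos[0], y)

    def reuse_branch(y, k_cache):
        # `y` is the NaN placeholder: return the cached values unchanged.
        return k_cache.clone()

    # The eager equivalent would be `if not torch.isnan(y).all(): ...`;
    # torch.cond keeps both branches traceable for torch.export.
    has_new_y = torch.logical_not(torch.isnan(y).all())
    return torch.cond(has_new_y, write_branch, reuse_branch, (y, k_cache))


# Quick eager check mirroring the two runs in the tests above.
cache = torch.zeros(1, 16, 8)
pos = torch.arange(4).unsqueeze(0)  # shape [1, 4], like self.input_pos
y = torch.randn(1, 4, 8)
updated = cross_attn_kv(y, cache, pos)  # first run: writes y into positions 0..3
reused = cross_attn_kv(torch.full_like(y, torch.nan), updated, pos)  # second run: cache kept
assert torch.equal(reused, updated)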