 import torch
 from executorch.exir import EdgeCompileConfig, to_edge
 
+from executorch.exir.capture._config import ExecutorchBackendConfig
+from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
 from executorch.extension.llm.modules.attention import (
     MultiHeadAttention as ETMultiHeadAttention,
 )
@@ -114,7 +116,7 @@ def test_attention_eager(self):
         et_res = self.et_mha(self.x, self.x)  # Self attention.
         tt_res = self.tt_mha(self.x, self.x)  # Self attention.
 
-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)
         self.et_mha.reset_cache()
         self.tt_mha.reset_cache()
 
@@ -125,7 +127,7 @@ def test_attention_eager(self):
             self.x, self.x, input_pos=self.input_pos
         )  # Self attention with input pos.
 
-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)
 
         # test kv cache read. Input pos can be [10, 11, ..., 19]
         next_input_pos = torch.arange(10, 20).unsqueeze(0)
@@ -187,9 +189,8 @@ def test_attention_aoti(self):
 
     def test_attention_executorch(self):
         # Self attention.
-        # TODO: Fix kv cache
-        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
-        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
 
         with torch.no_grad():
             et_mha_ep = torch.export.export(
@@ -202,9 +203,15 @@ def test_attention_executorch(self):
         et_program = to_edge(
             et_mha_ep,
             compile_config=EdgeCompileConfig(
-                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
+                _check_ir_validity=False,
             ),
-        ).to_executorch()
+        ).to_executorch(
+            config=ExecutorchBackendConfig(
+                passes=[InitializedMutableBufferPass(["cache_pos"])],
+            )
+        )
+
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
         method = program.load_method("forward")
@@ -219,28 +226,23 @@ def test_attention_torch_cond_eager(self):
         self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
         self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=self.max_seq_len)
 
-        # mask
         mask = self.causal_mask[self.input_pos, :]
-        # First run
+        # First run.
         et_res = self.et_mha(
             self.x, self.x, mask=mask, input_pos=self.input_pos
         )  # Self attention with input pos.
         tt_res = self.tt_mha(
             self.x, self.x, mask=mask, input_pos=self.input_pos
         )  # Self attention with input pos.
 
-        self.assertTrue(torch.allclose(et_res, tt_res))
+        assert_close(et_res, tt_res)
 
         # Second run test kv cache read. Input pos is [10, 11, ..., 19]
         next_input_pos = torch.arange(10, 20).unsqueeze(0)
 
         empty_y = torch.full_like(self.x, torch.nan)
         mask = self.causal_mask[next_input_pos, :]
-        et_res = self.et_mha(
-            self.x, empty_y, mask=mask, input_pos=next_input_pos
-        )  # Self attention with input pos.
-        tt_res = self.tt_mha(
-            self.x, None, mask=mask, input_pos=next_input_pos
-        )  # Self attention with input pos.
+        et_res = self.et_mha(self.x, empty_y, mask=mask, input_pos=next_input_pos)
+        tt_res = self.tt_mha(self.x, None, mask=mask, input_pos=next_input_pos)
 
         assert_close(et_res, tt_res)