 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import os
+import tempfile
 import unittest

 import torch
     MultiHeadAttention as ETMultiHeadAttention,
 )
 from executorch.runtime import Runtime
+from torch._inductor.package import load_package, package_aoti
 from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
 from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention

@@ -128,33 +131,61 @@ def test_attention_eager(self):

     def test_attention_export(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
         self.assertTrue(torch.allclose(et_res, tt_res))

-        # TODO: KV cache.
-
     def test_attention_aoti(self):
-        # TODO.
-        pass
+        # Self attention.
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            so = torch._export.aot_compile(
+                self.et_mha,
+                args=(self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                options={"aot_inductor.package": True},
+                dynamic_shapes=self.dynamic_shapes,
+            )
+        with tempfile.TemporaryDirectory() as tempdir:
+            path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
+            mha_aoti = load_package(path)
+
+            et_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
+            tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
+            self.assertTrue(torch.allclose(et_res, tt_res))

     def test_attention_executorch(self):
         # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+        # TODO: Fix kv cache
+        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
         et_program = to_edge(
             et_mha_ep,
-            compile_config=EdgeCompileConfig(),
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+            ),
         ).to_executorch()
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
@@ -163,5 +194,3 @@ def test_attention_executorch(self):
         tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

         self.assertTrue(torch.allclose(et_res[0], tt_res, atol=1e-06))
-
-        # TODO: KV cache.
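
For reference, the AOTInductor round trip that the new test_attention_aoti exercises (aot_compile with the aot_inductor.package option, package_aoti into a .pt2 archive, load_package back into a callable) can be reproduced in isolation. The sketch below is illustrative only and is not part of the test: TinyLinear and its shapes are hypothetical stand-ins, while the compile, package, and load calls are the same ones this diff introduces and assume a recent PyTorch build that ships torch._inductor.package.

import os
import tempfile

import torch
import torch._export  # explicit import; aot_compile lives under torch._export
from torch._inductor.package import load_package, package_aoti


class TinyLinear(torch.nn.Module):
    # Hypothetical stand-in for the attention module under test.
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


model = TinyLinear().eval()
x = torch.randn(2, 8)

with torch.no_grad():
    # Compile to an AOTInductor artifact that can be packaged into a .pt2 file.
    so = torch._export.aot_compile(
        model,
        args=(x,),
        options={"aot_inductor.package": True},
    )

with tempfile.TemporaryDirectory() as tempdir:
    # Bundle the artifact, load it back as a callable, and compare with eager.
    path = package_aoti(os.path.join(tempdir, "tiny.pt2"), so)
    runner = load_package(path)
    assert torch.allclose(runner(x), model(x))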
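
Similarly, the lowering path that test_attention_executorch follows (torch.export.export, to_edge with an EdgeCompileConfig, to_executorch, then the Python Runtime bindings) can be sketched end to end on a toy module. TinyRelu is a hypothetical stand-in, the executorch.exir import is the usual public entry point rather than something shown in the diff, and the load_method/execute calls are assumptions based on the executorch.runtime Python API; only Runtime.get and load_program appear in the hunks above.

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.runtime import Runtime


class TinyRelu(torch.nn.Module):
    # Hypothetical stand-in for the attention module under test.
    def forward(self, x):
        return torch.nn.functional.relu(x)


model = TinyRelu().eval()
x = torch.randn(2, 8)

with torch.no_grad():
    ep = torch.export.export(model, (x,))

# Lower to an ExecuTorch program. The exception list mirrors the one added in the
# diff, which tolerates aten._assert_async.msg nodes produced by the kv-cache path.
et_program = to_edge(
    ep,
    compile_config=EdgeCompileConfig(
        _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
    ),
).to_executorch()

# Load the serialized program into the Python runtime and run it.
runtime = Runtime.get()
program = runtime.load_program(et_program.buffer)
method = program.load_method("forward")  # assumed runtime API, not shown in the diff
(et_out,) = method.execute([x])
assert torch.allclose(et_out, model(x))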