# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

+import os
+import tempfile
import unittest

import torch
    MultiHeadAttention as ETMultiHeadAttention,
)
from executorch.runtime import Runtime
+from torch._inductor.package import load_package, package_aoti
from torch.testing import assert_close
from torchtune.models.llama3_1._position_embeddings import Llama3ScaledRoPE
from torchtune.modules.attention import MultiHeadAttention as TTMultiHeadAttention
@@ -130,34 +133,62 @@ def test_attention_eager(self):

    def test_attention_export(self):
        # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
        et_res = et_mha_ep.module()(self.x, self.x, input_pos=self.input_pos)
        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

        assert_close(et_res, tt_res)

-        # TODO: KV cache.
-
    def test_attention_aoti(self):
-        # TODO.
-        pass
+        # Self attention.
+
+        # test with kv cache
+        self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        with torch.no_grad():
+            so = torch._export.aot_compile(
+                self.et_mha,
+                args=(self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                options={"aot_inductor.package": True},
+                dynamic_shapes=self.dynamic_shapes,
+            )
+        with tempfile.TemporaryDirectory() as tempdir:
+            path = package_aoti(os.path.join(tempdir, "mha.pt2"), so)
+            mha_aoti = load_package(path)
+
+            et_res = mha_aoti(self.x, self.x, input_pos=self.input_pos)
+            tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)
+            self.assertTrue(torch.allclose(et_res, tt_res))

    def test_attention_executorch(self):
        # Self attention.
-        et_mha_ep = torch.export.export(
-            self.et_mha,
-            (self.x, self.x),
-            kwargs={"input_pos": self.input_pos},
-            dynamic_shapes=self.dynamic_shapes,
-        )
+        # TODO: Fix kv cache
+        # self.et_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+        # self.tt_mha.setup_cache(1, dtype=torch.float32, max_seq_len=100)
+
+        with torch.no_grad():
+            et_mha_ep = torch.export.export(
+                self.et_mha,
+                (self.x, self.x),
+                kwargs={"input_pos": self.input_pos},
+                dynamic_shapes=self.dynamic_shapes,
+            )
        et_program = to_edge(
            et_mha_ep,
-            compile_config=EdgeCompileConfig(),
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg]
+            ),
        ).to_executorch()
        runtime = Runtime.get()
        program = runtime.load_program(et_program.buffer)
@@ -166,5 +197,3 @@ def test_attention_executorch(self):
        tt_res = self.tt_mha(self.x, self.x, input_pos=self.input_pos)

        assert_close(et_res[0], tt_res)
-
-        # TODO: KV cache.
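
For reference, the AOTI path added above follows a plain export -> AOT-compile -> package -> load flow using the same torch APIs the test imports. Below is a minimal standalone sketch of that flow; the TinyModel class, tensor shape, and "tiny.pt2" file name are illustrative assumptions and not part of this change.

# Minimal sketch of the AOT-compile + package flow exercised by test_attention_aoti.
import os
import tempfile

import torch
from torch._inductor.package import load_package, package_aoti


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


model = TinyModel().eval()
x = torch.randn(2, 8)

with torch.no_grad():
    # Compile ahead of time; "aot_inductor.package": True keeps the compiled
    # artifacts in a form that package_aoti can bundle into a .pt2 archive.
    artifacts = torch._export.aot_compile(
        model,
        args=(x,),
        options={"aot_inductor.package": True},
    )

with tempfile.TemporaryDirectory() as tempdir:
    # Bundle the artifacts, then load the package and call it like a module.
    path = package_aoti(os.path.join(tempdir, "tiny.pt2"), artifacts)
    compiled = load_package(path)
    torch.testing.assert_close(compiled(x), model(x))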