4
4
# This source code is licensed under the BSD-style license found in the
5
5
# LICENSE file in the root directory of this source tree.
6
6
7
+ import tempfile
7
8
import unittest
8
9
from typing import Callable , Tuple
9
10
10
11
import torch
11
-
12
12
from executorch .exir import EdgeCompileConfig , to_edge
13
+
14
+ from executorch .extension .export_util .utils import save_pte_program
13
15
from executorch .extension .llm .modules .kv_cache import KVCache as InferenceKVCache
16
+
17
+ from executorch .extension .pybindings .portable_lib import (
18
+ _load_for_executorch_from_buffer ,
19
+ )
14
20
from executorch .runtime import Runtime
15
21
from torch .testing import assert_close
16
22
from torchtune .modules .kv_cache import KVCache
@@ -67,21 +73,10 @@ def _test_kv_cache(self, et_cache_module: Callable):
67
73
prefill_seq_len , self .batch_size , self .num_kv_heads , self .head_dim
68
74
)
69
75
70
- print ()
71
- print ("Prefilling..." )
72
- print ()
73
-
74
76
et_res = et_cache_module (k_val , v_val )
75
77
tt_res = self .tt_kv_cache .update (k_val_trans , v_val_trans )
76
78
tt_res_transposed = (tt_res [0 ].transpose (1 , 2 ), tt_res [1 ].transpose (1 , 2 ))
77
79
78
- print ()
79
- print ("Final tt kv_cache.cache_pos" )
80
- print (self .tt_kv_cache .cache_pos )
81
- print ("Final tt kv_cache.k_cache" )
82
- print (self .tt_kv_cache .k_cache )
83
- print ()
84
-
85
80
# Check torchtune matches executorch.
86
81
assert_close (et_res , tt_res_transposed )
87
82
@@ -112,7 +107,6 @@ def _test_kv_cache(self, et_cache_module: Callable):
112
107
113
108
self .assertTrue (et_k_cache [0 ][prefill_seq_len + 1 ][0 ][0 ] == 0 )
114
109
115
-
116
110
def export_kv_cache (
117
111
self ,
118
112
kv_cache : torch .nn .Module ,
@@ -179,9 +173,6 @@ def test_kv_cache_executorch(self):
179
173
)
180
174
et_program = edge_program .to_executorch ()
181
175
182
- """DEBUG the executorch program"""
183
- et_program .dump_executorch_program (verbose = True )
184
-
185
176
runtime = Runtime .get ()
186
177
program = runtime .load_program (et_program .buffer )
187
178
method = program .load_method ("forward" )
@@ -192,3 +183,27 @@ def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor:
192
183
193
184
self ._test_kv_cache (wrapped_callable )
194
185
186
+ def test_kv_cache_executorch_from_file (self ):
187
+ exported_kv_cache = self .export_kv_cache (self .et_kv_cache )
188
+ edge_program = to_edge (
189
+ exported_kv_cache ,
190
+ compile_config = EdgeCompileConfig (
191
+ _core_aten_ops_exception_list = [torch .ops .aten ._assert_async .msg ],
192
+ _check_ir_validity = False ,
193
+ ),
194
+ )
195
+ et_program = edge_program .to_executorch ()
196
+
197
+ with tempfile .TemporaryDirectory () as tempdir :
198
+ pte_path = save_pte_program (et_program , "test_et_kv_cache" , tempdir )
199
+ with open (pte_path , "rb" ) as f :
200
+ model_bytes = f .read ()
201
+ loaded_et_program = _load_for_executorch_from_buffer (model_bytes )
202
+
203
+ # Since method.execute expects a tuple of args.
204
+ def wrapped_callable (
205
+ k_val : torch .Tensor , v_val : torch .Tensor
206
+ ) -> torch .Tensor :
207
+ return loaded_et_program .forward ((k_val , v_val ))
208
+
209
+ self ._test_kv_cache (wrapped_callable )
0 commit comments