Skip to content

Commit 9cdfb43

Browse files
committed
Lint
1 parent 5db136c commit 9cdfb43

File tree

4 files changed

+37
-20
lines changed

4 files changed

+37
-20
lines changed

examples/models/llama3_2_vision/runner/native.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
TorchTuneLlamaRunner,
1919
)
2020

21-
from executorch.extension.pybindings.portable_lib import _load_for_executorch
22-
from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer
21+
from executorch.extension.pybindings.portable_lib import (
22+
_load_for_executorch,
23+
_load_for_executorch_from_buffer,
24+
)
2325

2426
# Load custom ops and quantized ops.
2527
from executorch.extension.pybindings import portable_lib # noqa # usort: skip

exir/passes/init_mutable_buffer_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
1010
from executorch.exir.passes.spec_prop_pass import make_spec
1111

12+
1213
class InitMutableBufferPass(ExportPass):
1314
def __init__(self) -> None:
1415
super().__init__()
@@ -18,4 +19,3 @@ def placeholder(self, name: str, arg, meta):
1819
meta["et_init_buffer"] = True
1920

2021
return super().placeholder(name, arg, meta)
21-

exir/program/_program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
OpReplacePass,
3535
)
3636
from executorch.exir.passes.external_constants_pass import external_constants_pass
37+
from executorch.exir.passes.init_mutable_buffer_pass import InitMutableBufferPass
3738
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
3839
insert_write_back_for_buffers_pass,
3940
)
@@ -46,7 +47,6 @@
4647
from executorch.exir.passes.replace_view_copy_with_view_pass import (
4748
ReplaceViewCopyWithViewPass,
4849
)
49-
from executorch.exir.passes.init_mutable_buffer_pass import InitMutableBufferPass
5050
from executorch.exir.passes.spec_prop_pass import SpecPropPass
5151
from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass
5252
from executorch.exir.print_program import pretty_print, print_program

extension/llm/modules/test/test_kv_cache.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,19 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import tempfile
78
import unittest
89
from typing import Callable, Tuple
910

1011
import torch
11-
1212
from executorch.exir import EdgeCompileConfig, to_edge
13+
14+
from executorch.extension.export_util.utils import save_pte_program
1315
from executorch.extension.llm.modules.kv_cache import KVCache as InferenceKVCache
16+
17+
from executorch.extension.pybindings.portable_lib import (
18+
_load_for_executorch_from_buffer,
19+
)
1420
from executorch.runtime import Runtime
1521
from torch.testing import assert_close
1622
from torchtune.modules.kv_cache import KVCache
@@ -67,21 +73,10 @@ def _test_kv_cache(self, et_cache_module: Callable):
6773
prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim
6874
)
6975

70-
print()
71-
print("Prefilling...")
72-
print()
73-
7476
et_res = et_cache_module(k_val, v_val)
7577
tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
7678
tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))
7779

78-
print()
79-
print("Final tt kv_cache.cache_pos")
80-
print(self.tt_kv_cache.cache_pos)
81-
print("Final tt kv_cache.k_cache")
82-
print(self.tt_kv_cache.k_cache)
83-
print()
84-
8580
# Check torchtune matches executorch.
8681
assert_close(et_res, tt_res_transposed)
8782

@@ -112,7 +107,6 @@ def _test_kv_cache(self, et_cache_module: Callable):
112107

113108
self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0)
114109

115-
116110
def export_kv_cache(
117111
self,
118112
kv_cache: torch.nn.Module,
@@ -179,9 +173,6 @@ def test_kv_cache_executorch(self):
179173
)
180174
et_program = edge_program.to_executorch()
181175

182-
"""DEBUG the executorch program"""
183-
et_program.dump_executorch_program(verbose=True)
184-
185176
runtime = Runtime.get()
186177
program = runtime.load_program(et_program.buffer)
187178
method = program.load_method("forward")
@@ -192,3 +183,27 @@ def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor:
192183

193184
self._test_kv_cache(wrapped_callable)
194185

186+
def test_kv_cache_executorch_from_file(self):
187+
exported_kv_cache = self.export_kv_cache(self.et_kv_cache)
188+
edge_program = to_edge(
189+
exported_kv_cache,
190+
compile_config=EdgeCompileConfig(
191+
_core_aten_ops_exception_list=[torch.ops.aten._assert_async.msg],
192+
_check_ir_validity=False,
193+
),
194+
)
195+
et_program = edge_program.to_executorch()
196+
197+
with tempfile.TemporaryDirectory() as tempdir:
198+
pte_path = save_pte_program(et_program, "test_et_kv_cache", tempdir)
199+
with open(pte_path, "rb") as f:
200+
model_bytes = f.read()
201+
loaded_et_program = _load_for_executorch_from_buffer(model_bytes)
202+
203+
# Since loaded_et_program.forward expects a tuple of args.
204+
def wrapped_callable(
205+
k_val: torch.Tensor, v_val: torch.Tensor
206+
) -> torch.Tensor:
207+
return loaded_et_program.forward((k_val, v_val))
208+
209+
self._test_kv_cache(wrapped_callable)

0 commit comments

Comments
 (0)