Commit 917fb0d

Fixes test but not model
1 parent aac90a0 commit 917fb0d

9 files changed, +84 -6 lines changed

examples/models/llama3_2_vision/runner/native.py

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@
 from executorch.extension.pybindings import portable_lib # noqa # usort: skip

 # Note: import this after portable_lib
-from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa # usort: skip
+from executorch.extension.llm.custom_ops import custom_ops # noqa # usort: skip
 from executorch.kernels import quantized # noqa


exir/emit/_emitter.py

Lines changed: 3 additions & 1 deletion

@@ -1566,6 +1566,7 @@ def _find_fqn_for_placeholder(
             fqn = self.exported_program.graph_signature.inputs_to_parameters[target]

         elif target in self.exported_program.graph_signature.inputs_to_buffers:
+            breakpoint()
             fqn = self.exported_program.graph_signature.inputs_to_buffers[target]

             # if the buffer is mutated then record that
@@ -1606,6 +1607,7 @@ def placeholder(

         if isinstance(target, str) and isinstance(spec, TensorSpec):
             fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)
+            print(f"fqn: {fqn}, is_mutable_buffer: {is_mutable_buffer}")

             # If the placeholder has a constant_tag, it is external to the PTE file
             # and requires a fqn and location=TensorDataLocation.EXTERNAL
@@ -1655,7 +1657,7 @@ def placeholder(
             spec.storage = real_tensor.untyped_storage()

         # User inputs and mutable buffers are not constants, other buffers or parameters are.
-        spec.const = not (is_user_input or is_mutable_buffer)
+        spec.const = not is_user_input

         evalue = (
             self._tensor_spec_to_evalue(spec, constant_tag)
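
Note: for context on the spec.const change above, whether a placeholder is a user input, a parameter, or a buffer is read off the export graph signature, and persistent buffers are the ones whose fully qualified name appears in the program's state dict. A minimal standalone sketch of those lookups, assuming only public torch.export APIs (the Toy module and names below are illustrative, not from this commit):

    import torch

    class Toy(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            # Persistent buffers appear in state_dict; non-persistent ones do not.
            self.register_buffer("pos", torch.arange(4), persistent=True)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x + self.pos

    ep = torch.export.export(Toy(), (torch.zeros(4),))
    sig = ep.graph_signature
    for node in ep.graph.nodes:
        if node.op != "placeholder":
            continue
        if node.target in sig.inputs_to_buffers:
            fqn = sig.inputs_to_buffers[node.target]
            print(f"{node.target}: buffer {fqn}, in state_dict: {fqn in ep.state_dict}")
        elif node.target in sig.inputs_to_parameters:
            print(f"{node.target}: parameter {sig.inputs_to_parameters[node.target]}")
        else:
            print(f"{node.target}: user input")
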
exir/passes/init_mutable_buffer_pass.py

Lines changed: 47 additions & 0 deletions

@@ -0,0 +1,47 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue
+from executorch.exir.passes.spec_prop_pass import make_spec
+
+class InitMutableBufferPass(ExportPass):
+    def __init__(self) -> None:
+        super().__init__()
+
+    def update_placeholder_tensor_specs(
+        self,
+        exported_program: torch.export.ExportedProgram,
+        graph_module: torch.fx.GraphModule,
+    ) -> None:
+        """
+        Update the tensor specs for all placeholder nodes such that
+        placeholders that are parameters are marked as constant.
+        """
+        for node in graph_module.graph.nodes:
+            if node.op != "placeholder":
+                continue
+            if "spec" not in node.meta:
+                raise RuntimeError(f"Placeholder node {node} missing meta['spec']")
+            # print(node)
+            spec = node.meta["spec"]
+            if (isinstance(node.target, str) and
+                node.target in exported_program.graph_signature.inputs_to_buffers and exported_program.graph_signature.inputs_to_buffers[node.target] in exported_program.state_dict):
+                # print(f"Setting {node.target}.const = True")
+                # breakpoint()
+                # print(exported_program.state_dict[exported_program.graph_signature.inputs_to_buffers[node.target]])
+                spec.const = True
+
+    # pyre-ignore
+    def placeholder(self, name: str, arg, meta):
+        # print(name)
+        meta["spec"] = make_spec(arg, const=meta.data['spec'].const)
+        # if name == "b_kv_cache_cache_pos":
+        #     print("breakpoint")
+        #     breakpoint()
+
+        return super().placeholder(name, arg, meta)
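
Note: the new pass drives its work through ExportPass's placeholder hook, which fires once per graph input while the pass re-traces the module; returning through super() keeps the default handling. A minimal sketch of the same hook shape (a hypothetical logging pass, not this commit's code):

    from executorch.exir.pass_base import ExportPass

    class LogPlaceholdersPass(ExportPass):
        """Hypothetical pass: log each placeholder as it is re-traced."""

        # Same hook the commit overrides; meta.data holds the node's metadata
        # dict, including the TensorSpec once SpecPropPass has run.
        def placeholder(self, name, arg, meta):
            print(f"placeholder {name}: spec={meta.data.get('spec')}")
            return super().placeholder(name, arg, meta)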

exir/passes/spec_prop_pass.py

Lines changed: 2 additions & 2 deletions

@@ -18,9 +18,9 @@


 # pyre-ignore
-def make_spec(x):
+def make_spec(x, const=False):
     if isinstance(x, torch.Tensor):
-        return TensorSpec.from_tensor(x)
+        return TensorSpec.from_tensor(x, const)
     elif isinstance(x, (int, bool, float)):
         return x
     else:
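
Note: a usage sketch for the widened make_spec signature (the call below is illustrative, not from the commit; the reading that const=True marks the resulting TensorSpec as constant data rather than a mutable runtime value is an assumption based on how InitMutableBufferPass uses it):

    import torch
    from executorch.exir.passes.spec_prop_pass import make_spec

    # Tensors get a TensorSpec; non-tensor values pass through unchanged.
    const_spec = make_spec(torch.zeros(2, 2), const=True)
    plain_spec = make_spec(torch.zeros(2, 2))  # const defaults to False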

exir/program/_program.py

Lines changed: 4 additions & 0 deletions

@@ -46,6 +46,7 @@
 from executorch.exir.passes.replace_view_copy_with_view_pass import (
     ReplaceViewCopyWithViewPass,
 )
+from executorch.exir.passes.init_mutable_buffer_pass import InitMutableBufferPass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
 from executorch.exir.passes.weights_to_outputs_pass import weights_to_outputs_pass
 from executorch.exir.print_program import pretty_print, print_program
@@ -706,6 +707,7 @@ def edge_to_executorch_passes(
     passes: List[PassType] = [
         *config.passes,
         SpecPropPass(),
+        InitMutableBufferPass(),
         # ExecuTorch backend ops are unable to handle unbacked symints. So after
         # this pass, passes cannot be Interpreter-based, because it will fail if
         # there exists an unbacked symint operation.
@@ -1352,6 +1354,8 @@ def to_executorch(
         gm, new_signature = insert_write_back_for_buffers_pass(program)
         new_gm = program.graph_module
         for p in edge_to_executorch_passes(config, name):
+            if isinstance(p, InitMutableBufferPass):
+                p.update_placeholder_tensor_specs(program, new_gm)
             new_gm_res = p(new_gm)
             assert new_gm_res is not None
             new_gm = new_gm_res.graph_module

extension/llm/export/builder.py

Lines changed: 1 addition & 0 deletions

@@ -414,6 +414,7 @@ def to_executorch(self) -> "LLMEdgeManager":
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
+        print(self.export_program.to_executorch_program(verbose=True))
        logging.info(
            "Required memory for activation in bytes: {}".format(
                self.export_program._emitter_output.program.execution_plan[

extension/llm/modules/kv_cache.py

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ def __init__(
             "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
         )
         self.register_buffer(
-            "cache_pos", torch.arange(0, self.max_seq_len), persistent=False
+            "cache_pos", torch.arange(0, self.max_seq_len), persistent=True
         )
         self.batch_size = batch_size


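Note: what the persistent flip above changes, in plain PyTorch terms: persistent buffers are included in state_dict() (and therefore in the exported program's state dict, which is what InitMutableBufferPass checks), while non-persistent buffers are not serialized. A standalone sketch with a toy module (not the commit's code):

    import torch

    class Cache(torch.nn.Module):
        def __init__(self, max_seq_len: int = 8) -> None:
            super().__init__()
            self.register_buffer("cache_pos", torch.arange(0, max_seq_len), persistent=True)
            self.register_buffer("scratch", torch.zeros(max_seq_len), persistent=False)

    m = Cache()
    print("cache_pos" in m.state_dict())  # True: persistent buffers are saved
    print("scratch" in m.state_dict())    # False: non-persistent buffers are skipped
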
extension/llm/modules/test/test_kv_cache.py

Lines changed: 19 additions & 1 deletion

@@ -67,10 +67,21 @@ def _test_kv_cache(self, et_cache_module: Callable):
             prefill_seq_len, self.batch_size, self.num_kv_heads, self.head_dim
         )

+        print()
+        print("Prefilling...")
+        print()
+
         et_res = et_cache_module(k_val, v_val)
         tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
         tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))

+        print()
+        print("Final tt kv_cache.cache_pos")
+        print(self.tt_kv_cache.cache_pos)
+        print("Final tt kv_cache.k_cache")
+        print(self.tt_kv_cache.k_cache)
+        print()
+
         # Check torchtune matches executorch.
         assert_close(et_res, tt_res_transposed)

@@ -89,17 +100,19 @@ def _test_kv_cache(self, et_cache_module: Callable):

         et_res = et_cache_module(k_val, v_val)
         tt_res = self.tt_kv_cache.update(k_val_trans, v_val_trans)
+        tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))

         # Check torchtune matches executorch.
-        tt_res_transposed = (tt_res[0].transpose(1, 2), tt_res[1].transpose(1, 2))
         assert_close(tt_res_transposed, et_res)

         # All rows should be filled with 1s up to 3 + 1th row.
         et_k_cache = et_res[0]
         for i in range(prefill_seq_len + 1):
             self.assertTrue(et_k_cache[0][i][0][0] == 1)
+
         self.assertTrue(et_k_cache[0][prefill_seq_len + 1][0][0] == 0)

+
     def export_kv_cache(
         self,
         kv_cache: torch.nn.Module,
@@ -165,6 +178,10 @@ def test_kv_cache_executorch(self):
             ),
         )
         et_program = edge_program.to_executorch()
+
+        """DEBUG the executorch program"""
+        et_program.dump_executorch_program(verbose=True)
+
         runtime = Runtime.get()
         program = runtime.load_program(et_program.buffer)
         method = program.load_method("forward")
@@ -174,3 +191,4 @@ def wrapped_callable(k_val: torch.Tensor, v_val: torch.Tensor) -> torch.Tensor:
             return method.execute((k_val, v_val))

         self._test_kv_cache(wrapped_callable)
+

runtime/executor/method.cpp

Lines changed: 6 additions & 0 deletions

@@ -11,7 +11,9 @@
 #include <cinttypes> // @donotremove
 #include <cstdint>
 #include <cstdio>
+#include <iostream>

+#include <executorch/extension/evalue_util/print_evalue.h>
 #include <executorch/runtime/backend/interface.h>
 #include <executorch/runtime/core/event_tracer_hooks.h>
 #include <executorch/runtime/core/exec_aten/util/tensor_util.h>
@@ -1179,6 +1181,10 @@ Error Method::execute_instruction() {
   if (err == Error::Ok) {
     step_state_.instr_idx = next_instr_idx;
   }
+
+  // TODO: Print an EValue.
+  std::cout << "(" << values_[1] << " ) Printing kv_cache k_cache: " << executorch::extension::evalue_edge_items(9216) << values_[2] << std::endl;
+
   return err;
 }