Commit d538d43

Updated pass
1 parent 917fb0d commit d538d43

6 files changed: +17 -37 lines changed


examples/models/llama3_2_vision/runner/native.py

Lines changed: 5 additions & 1 deletion

@@ -19,6 +19,7 @@
 )

 from executorch.extension.pybindings.portable_lib import _load_for_executorch
+from executorch.extension.pybindings.portable_lib import _load_for_executorch_from_buffer

 # Load custom ops and quantized ops.
 from executorch.extension.pybindings import portable_lib  # noqa # usort: skip
@@ -43,7 +44,10 @@ def __init__(self, args):
             use_kv_cache=args.kv_cache,
             vocab_size=params["vocab_size"],
         )
-        self.model = _load_for_executorch(args.pte)
+        with open(args.pte, "rb") as f:
+            model_bytes = f.read()
+        self.model = _load_for_executorch_from_buffer(model_bytes)
+        # self.model = _load_for_executorch(args.pte)
         self.use_kv_cache = args.kv_cache

     def forward(
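
For context, _load_for_executorch loads the program from a file path, while _load_for_executorch_from_buffer takes the serialized program as bytes the caller has already read. A minimal sketch of the two paths ("model.pte" is a placeholder path):

    from executorch.extension.pybindings.portable_lib import (
        _load_for_executorch,
        _load_for_executorch_from_buffer,
    )

    # Path-based: the runtime opens the .pte file itself.
    module = _load_for_executorch("model.pte")

    # Buffer-based: the caller reads the bytes and hands them over.
    with open("model.pte", "rb") as f:
        module = _load_for_executorch_from_buffer(f.read())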

exir/emit/_emitter.py

Lines changed: 6 additions & 2 deletions

@@ -1566,7 +1566,6 @@ def _find_fqn_for_placeholder(
             fqn = self.exported_program.graph_signature.inputs_to_parameters[target]

         elif target in self.exported_program.graph_signature.inputs_to_buffers:
-            breakpoint()
             fqn = self.exported_program.graph_signature.inputs_to_buffers[target]

             # if the buffer is mutated then record that
@@ -1603,6 +1602,7 @@ def placeholder(
         """
         spec = self.node.meta["spec"]
         constant_tag = self.node.meta.get("constant_tag", None)
+        initialize_buffer = self.node.meta.get("et_init_buffer", None)
         is_user_input = True

         if isinstance(target, str) and isinstance(spec, TensorSpec):
@@ -1657,7 +1657,11 @@ def placeholder(
             spec.storage = real_tensor.untyped_storage()

         # User inputs and mutable buffers are not constants, other buffers or parameters are.
-        spec.const = not is_user_input
+        if initialize_buffer:
+            assert is_mutable_buffer
+            spec.const = True
+        else:
+            spec.const = not (is_user_input or is_mutable_buffer)

         evalue = (
             self._tensor_spec_to_evalue(spec, constant_tag)
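
Taken on its own, the new constness rule says: a mutable buffer explicitly tagged with et_init_buffer is emitted as constant, so its initial value is serialized into the program; untagged mutable buffers and user inputs stay non-constant. A standalone sketch of that decision (the booleans stand in for values the emitter computes):

    def spec_is_const(is_user_input: bool, is_mutable_buffer: bool,
                      initialize_buffer: bool) -> bool:
        # Tagged mutable buffers keep their initial value in the .pte.
        if initialize_buffer:
            assert is_mutable_buffer
            return True
        # Parameters and non-mutable buffers are constant; user inputs
        # and ordinary mutable buffers are not.
        return not (is_user_input or is_mutable_buffer)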

exir/passes/init_mutable_buffer_pass.py

Lines changed: 4 additions & 30 deletions

@@ -13,35 +13,9 @@ class InitMutableBufferPass(ExportPass):
     def __init__(self) -> None:
         super().__init__()

-    def update_placeholder_tensor_specs(
-        self,
-        exported_program: torch.export.ExportedProgram,
-        graph_module: torch.fx.GraphModule,
-    ) -> None:
-        """
-        Update the tensor specs for all placeholder nodes such that
-        placeholders that are parameters are marked as constant.
-        """
-        for node in graph_module.graph.nodes:
-            if node.op != "placeholder":
-                continue
-            if "spec" not in node.meta:
-                raise RuntimeError(f"Placeholder node {node} missing meta['spec']")
-            # print(node)
-            spec = node.meta["spec"]
-            if (isinstance(node.target, str) and
-                node.target in exported_program.graph_signature.inputs_to_buffers and exported_program.graph_signature.inputs_to_buffers[node.target] in exported_program.state_dict):
-                # print(f"Setting {node.target}.const = True")
-                # breakpoint()
-                # print(exported_program.state_dict[exported_program.graph_signature.inputs_to_buffers[node.target]])
-                spec.const = True
-
-    # pyre-ignore
     def placeholder(self, name: str, arg, meta):
-        # print(name)
-        meta["spec"] = make_spec(arg, const=meta.data['spec'].const)
-        # if name == "b_kv_cache_cache_pos":
-        #     print("breakpoint")
-        #     breakpoint()
-
+        if "cache_pos" in name:
+            meta["et_init_buffer"] = True
+
         return super().placeholder(name, arg, meta)
+
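
The slimmed-down pass no longer rewrites tensor specs itself; it only tags placeholders whose name contains "cache_pos" so the emitter can pick the flag up later. A hedged usage sketch, assuming the standard ExportPass protocol (calling a pass instance on a GraphModule returns a PassResult):

    result = InitMutableBufferPass()(graph_module)
    for node in result.graph_module.graph.nodes:
        if node.op == "placeholder" and "cache_pos" in node.name:
            # The tag that exir/emit/_emitter.py consumes.
            assert node.meta.get("et_init_buffer") is True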

exir/program/_program.py

Lines changed: 0 additions & 2 deletions

@@ -1354,8 +1354,6 @@ def to_executorch(
         gm, new_signature = insert_write_back_for_buffers_pass(program)
         new_gm = program.graph_module
         for p in edge_to_executorch_passes(config, name):
-            if isinstance(p, InitMutableBufferPass):
-                p.update_placeholder_tensor_specs(program, new_gm)
             new_gm_res = p(new_gm)
             assert new_gm_res is not None
             new_gm = new_gm_res.graph_module
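
With the hook gone, the pipeline needs no type checks: InitMutableBufferPass carries its effect in node metadata, so it composes like any other pass. A generic sketch of that contract (run_passes is a hypothetical helper, not part of the diff):

    def run_passes(gm, passes):
        # Uniform contract: each pass maps a GraphModule to a PassResult.
        for p in passes:
            result = p(gm)
            assert result is not None
            gm = result.graph_module
        return gm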

extension/llm/export/builder.py

Lines changed: 1 addition & 1 deletion

@@ -414,7 +414,7 @@ def to_executorch(self) -> "LLMEdgeManager":
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
-        print(self.export_program.to_executorch_program(verbose=True))
+        print(self.export_program.dump_executorch_program(verbose=True))
         logging.info(
             "Required memory for activation in bytes: {}".format(
                 self.export_program._emitter_output.program.execution_plan[
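
For reference, dump_executorch_program is the debug-dump entry point on the program manager; the sketch below mirrors the fixed call in isolation (assuming manager is the ExecutorchProgramManager held in self.export_program):

    # Dump the emitted program (execution plan, operators, values);
    # useful when checking which buffers were serialized as constants.
    print(manager.dump_executorch_program(verbose=True))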

extension/llm/modules/kv_cache.py

Lines changed: 1 addition & 1 deletion

@@ -56,7 +56,7 @@ def __init__(
             "v_cache", torch.zeros(cache_shape, dtype=dtype), persistent=False
         )
         self.register_buffer(
-            "cache_pos", torch.arange(0, self.max_seq_len), persistent=True
+            "cache_pos", torch.arange(0, self.max_seq_len), persistent=False
         )
         self.batch_size = batch_size

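persistent=False is standard torch.nn.Module behavior: the buffer still lives on the module and moves with .to(), but it is excluded from state_dict(), so checkpoints neither save nor expect cache_pos. A quick self-contained illustration:

    import torch
    from torch import nn

    class Demo(nn.Module):
        def __init__(self, max_seq_len: int = 8):
            super().__init__()
            self.register_buffer(
                "cache_pos", torch.arange(0, max_seq_len), persistent=False
            )

    m = Demo()
    assert m.cache_pos.shape == (8,)          # buffer exists on the module
    assert "cache_pos" not in m.state_dict()  # but is not serialized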