Skip to content

Commit 4ee95d3

Browse files
committed
PR review
1 parent 61101c2 commit 4ee95d3

File tree

3 files changed

+13
-9
lines changed

3 files changed

+13
-9
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,7 @@
2323
import torch
2424

2525
from executorch.devtools.etrecord import generate_etrecord
26-
from executorch.exir.passes.cache_pos_init_mutable_pass import (
27-
CachePosToInitializedMutableBufferPass,
28-
)
26+
from executorch.exir.passes.init_mutable_pass import InitializedMutableBufferPass
2927

3028
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
3129

@@ -765,7 +763,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
765763

766764
additional_passes = []
767765
if args.model in TORCHTUNE_DEFINED_MODELS:
768-
additional_passes = [CachePosToInitializedMutableBufferPass()]
766+
additional_passes = [InitializedMutableBufferPass(["cache_pos"])]
769767
if args.generate_etrecord:
770768
if not builder_exported_to_edge.edge_manager:
771769
raise ValueError("Unable to generate etrecord due to missing edge manager.")

exir/passes/cache_pos_init_mutable_pass.py renamed to exir/passes/init_mutable_pass.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@
55
# LICENSE file in the root directory of this source tree.
66

77

8+
from typing import List
9+
810
from executorch.exir.pass_base import ExportPass
911

1012

11-
class CachePosToInitializedMutableBufferPass(ExportPass):
13+
class InitializedMutableBufferPass(ExportPass):
1214
"""
1315
If the buffer has the name "cache_pos", such as in a kv_cache
1416
module with `self.register_buffer("cache_pos", torch.arange(10))`,
@@ -17,11 +19,13 @@ class CachePosToInitializedMutableBufferPass(ExportPass):
1719
an initialized state.
1820
"""
1921

20-
def __init__(self) -> None:
22+
def __init__(self, patterns: List[str]) -> None:
2123
super().__init__()
24+
self.patterns = patterns
2225

2326
def placeholder(self, name: str, arg, meta):
24-
if "cache_pos" in name:
25-
meta["et_init_buffer"] = True
27+
for pattern in self.patterns:
28+
if pattern in name:
29+
meta["et_init_buffer"] = True
2630

2731
return super().placeholder(name, arg, meta)

extension/llm/export/builder.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,9 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag
396396

397397
return self
398398

399-
def to_executorch(self, passes: Optional[List[PassType]]) -> "LLMEdgeManager":
399+
def to_executorch(
400+
self, passes: Optional[List[PassType]] = None
401+
) -> "LLMEdgeManager":
400402
"""
401403
Lower the model to executorch and get an ExecutorchProgram.
402404
"""

0 commit comments

Comments
 (0)