Skip to content

Commit 5dcb8f7

Browse files
committed
Only add pass when vision model
1 parent 9cdfb43 commit 5dcb8f7

File tree

6 files changed

+23
-37
lines changed

6 files changed

+23
-37
lines changed

examples/models/llama/export_llama_lib.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
import torch
2424

2525
from executorch.devtools.etrecord import generate_etrecord
26+
from executorch.exir.passes.cache_pos_init_mutable_pass import (
27+
CachePosToInitializedMutableBufferPass,
28+
)
2629

2730
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
2831

@@ -760,6 +763,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
760763
for partitioner in partitioners:
761764
logging.info(f"--> {partitioner.__class__.__name__}")
762765

766+
additional_passes = []
767+
if args.model in TORCHTUNE_DEFINED_MODELS:
768+
additional_passes = [CachePosToInitializedMutableBufferPass()]
763769
if args.generate_etrecord:
764770
if not builder_exported_to_edge.edge_manager:
765771
raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -774,7 +780,9 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
774780
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
775781
canonicalize_program(builder.edge_manager.exported_program())
776782

777-
builder = builder.to_executorch()
783+
builder = builder.to_executorch(
784+
passes=additional_passes,
785+
)
778786

779787
# Generate ETRecord
780788
if edge_manager_copy:
@@ -792,7 +800,7 @@ def _export_llama(args) -> LLMEdgeManager: # noqa: C901
792800
# pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
793801
canonicalize_program(builder.edge_manager.exported_program())
794802

795-
builder = builder.to_executorch()
803+
builder = builder.to_executorch(passes=additional_passes)
796804

797805
if args.profile_memory:
798806
generate_memory_trace(builder.export_program, "memory_profile.json")

examples/models/llama3_2_vision/runner/native.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
)
2020

2121
from executorch.extension.pybindings.portable_lib import (
22-
_load_for_executorch,
2322
_load_for_executorch_from_buffer,
2423
)
2524

@@ -50,7 +49,6 @@ def __init__(self, args):
5049
with open(args.pte, "rb") as f:
5150
self.model_bytes = f.read()
5251
self.model = _load_for_executorch_from_buffer(self.model_bytes)
53-
# self.model = _load_for_executorch(args.pte)
5452
self.use_kv_cache = args.kv_cache
5553

5654
def forward(

exir/emit/_emitter.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1607,7 +1607,6 @@ def placeholder(
16071607

16081608
if isinstance(target, str) and isinstance(spec, TensorSpec):
16091609
fqn, is_mutable_buffer = self._find_fqn_for_placeholder(target, spec)
1610-
print(f"fqn: {fqn}, is_mutable_buffer: {is_mutable_buffer}")
16111610

16121611
# If the placeholder has a constant_tag, it is external to the PTE file
16131612
# and requires a fqn and location=TensorDataLocation.EXTERNAL

exir/passes/init_mutable_buffer_pass.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

exir/program/_program.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@
3434
OpReplacePass,
3535
)
3636
from executorch.exir.passes.external_constants_pass import external_constants_pass
37-
from executorch.exir.passes.init_mutable_buffer_pass import InitMutableBufferPass
3837
from executorch.exir.passes.insert_write_back_for_buffers_pass import (
3938
insert_write_back_for_buffers_pass,
4039
)
@@ -707,7 +706,6 @@ def edge_to_executorch_passes(
707706
passes: List[PassType] = [
708707
*config.passes,
709708
SpecPropPass(),
710-
InitMutableBufferPass(),
711709
# ExecuTorch backend ops are unable to handle unbacked symints. So after
712710
# this pass, passes cannot be Interpreter-based, because it will fail if
713711
# there exists an unbacked symint operation.

extension/llm/export/builder.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from executorch.exir.backend.utils import format_delegated_graph
2626
from executorch.exir.capture._config import EdgeCompileConfig, ExecutorchBackendConfig
2727

28+
from executorch.exir.pass_manager import PassType
2829
from executorch.exir.passes import MemoryPlanningPass
2930
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
3031
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
@@ -395,26 +396,29 @@ def to_backend(self, partitioners: Optional[List[Partitioner]]) -> "LLMEdgeManag
395396

396397
return self
397398

398-
def to_executorch(self) -> "LLMEdgeManager":
399+
def to_executorch(self, passes: Optional[List[PassType]]) -> "LLMEdgeManager":
399400
"""
400401
Lower the model to executorch and get an ExecutorchProgram.
401402
"""
402403
assert self.edge_manager, "Need to run export_to_edge() first"
404+
to_executorch_passes = [
405+
# If there are Linear operations left in the graph, let's execute
406+
# them with the optimized op_linear rather than materializing a
407+
# transpose followed by a regular op_mm.
408+
ConvertToLinearPass(),
409+
QuantFusionPass(),
410+
]
411+
if passes:
412+
to_executorch_passes.extend(passes)
413+
403414
self.export_program = self.edge_manager.to_executorch(
404415
ExecutorchBackendConfig(
405416
extract_delegate_segments=True,
406-
passes=[
407-
# If there are Linear operations left in the graph, let's execute
408-
# them with the optimized op_linear rather than materializing a
409-
# transpose followed by a regular op_mm.
410-
ConvertToLinearPass(),
411-
QuantFusionPass(),
412-
],
417+
passes=passes,
413418
memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
414419
sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
415420
)
416421
)
417-
print(self.export_program.dump_executorch_program(verbose=True))
418422
logging.info(
419423
"Required memory for activation in bytes: {}".format(
420424
self.export_program._emitter_output.program.execution_plan[

0 commit comments

Comments (0)