Skip to content

Commit 48f4eee

Browse files
authored
[exir] Enable dict for sym shape eval pass
Differential Revision: D61728068 Pull Request resolved: #4872
1 parent 26e921e commit 48f4eee

File tree

4 files changed

+75
-19
lines changed

4 files changed

+75
-19
lines changed

examples/models/llava/export_llava.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,15 @@
2323
replace_sdpa_with_custom_op,
2424
)
2525
from executorch.examples.models.llava.model import LlavaModel
26-
from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower
26+
from executorch.exir import (
27+
EdgeCompileConfig,
28+
ExecutorchBackendConfig,
29+
to_edge_transform_and_lower,
30+
)
31+
32+
from executorch.exir.passes import MemoryPlanningPass
33+
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
34+
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
2735

2836
from executorch.extension.llm.export.builder import DType, LLMEdgeManager
2937
from executorch.extension.llm.tokenizer.tokenizer import Tokenizer
@@ -199,7 +207,23 @@ def export_all(llava_model: LlavaModel):
199207
compile_config=EdgeCompileConfig(_check_ir_validity=False),
200208
)
201209

202-
executorch_program = lowered_and_edge.to_executorch()
210+
executorch_program = lowered_and_edge.to_executorch(
211+
ExecutorchBackendConfig(
212+
extract_constant_segment=True,
213+
extract_delegate_segments=True,
214+
passes=[
215+
QuantFusionPass(),
216+
],
217+
memory_planning_pass=MemoryPlanningPass("greedy", alloc_graph_input=False),
218+
sym_shape_eval_pass={
219+
"image_encoder": ConstraintBasedSymShapeEvalPass(),
220+
},
221+
)
222+
)
223+
for execution_plan in executorch_program._emitter_output.program.execution_plan:
224+
logging.info(
225+
f"Required memory for activation in bytes: {execution_plan.non_const_buffer_sizes}"
226+
)
203227
return executorch_program
204228

205229

@@ -253,13 +277,6 @@ def main():
253277

254278
with open(args.pte_name, "wb") as f:
255279
executorch_program.write_to_file(f)
256-
logging.info(
257-
"Required memory for activation in bytes: {}".format(
258-
executorch_program._emitter_output.program.execution_plan[
259-
0
260-
].non_const_buffer_sizes
261-
),
262-
)
263280
logging.info(f"Exported ExecuTorch program to {args.pte_name}")
264281

265282
# artifacts

examples/models/llava/test/test_pte.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
78
import sys
89

910
import torch
@@ -17,6 +18,10 @@
1718
from executorch.extension.llm.custom_ops import sdpa_with_kv_cache # noqa
1819

1920

21+
FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
22+
logging.basicConfig(level=logging.DEBUG, format=FORMAT)
23+
24+
2025
def main():
2126
args = sys.argv[1:]
2227
llava_module = _load_for_executorch(args[0])
@@ -41,26 +46,32 @@ def main():
4146
start_pos += pte_prefill_before_img.shape[1]
4247

4348
# pte prefill image
49+
logging.warning("Image encoder started")
4450
pte_embeds_img = llava_module.run_method("image_encoder", (resized,))[0]
51+
logging.warning("Image encoder finished")
52+
logging.warning("Image token prefill started")
4553
pte_prefill_img = llava_module.run_method(
4654
"text_model",
4755
(
4856
torch.tensor([start_pos], dtype=torch.int64),
4957
pte_embeds_img,
5058
),
5159
)[0]
60+
logging.warning("Image token prefill finished")
5261
print(pte_prefill_img)
5362

5463
start_pos += pte_prefill_img.shape[1]
5564

5665
# pte prefill prompt after img
66+
logging.warning("Text token prefill started")
5767
pte_embeds_after_img = llava_module.run_method(
5868
"token_embedding", (prompt_after_image,)
5969
)[0]
6070
pte_prefill_after_img = llava_module.run_method(
6171
"text_model",
6272
(torch.tensor([start_pos], dtype=torch.int64), pte_embeds_after_img),
6373
)[0]
74+
logging.warning("Text token prefill finished")
6475
print(pte_prefill_after_img)
6576

6677
# being tested, using llama_transformer

exir/capture/_config.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,12 @@ class ExecutorchBackendConfig:
8282
# If provided, the minimum alignment of delegate data in the program. Must
8383
# be a power of 2. If not provided, uses the value in the schema file.
8484
delegate_alignment: Optional[int] = None
85-
sym_shape_eval_pass: PassType = HintBasedSymShapeEvalPass()
85+
86+
# A single sym shape eval pass can be defined for all the programs in the
87+
# EdgeProgramManager or can be defined per program.
88+
sym_shape_eval_pass: Union[PassType, Dict[str, PassType]] = (
89+
HintBasedSymShapeEvalPass()
90+
)
8691

8792
# If set to true, view_copy operations will be converted to lightweight
8893
# view operations in the ET runtime

exir/program/_program.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
import torch
1515
import torch._export
16-
1716
from executorch.exir._serialize import _serialize_pte_binary
1817
from executorch.exir._serialize._cord import Cord
1918
from executorch.exir.backend.backend_api import to_backend
@@ -23,6 +22,7 @@
2322
from executorch.exir.emit._emitter import _DelegateDebugIdentifierMap
2423
from executorch.exir.error import ExportError
2524
from executorch.exir.graph_module import get_control_flow_submodules
25+
from executorch.exir.pass_base import PassBase
2626
from executorch.exir.pass_manager import PassType
2727
from executorch.exir.passes import (
2828
base_post_op_replace_passes,
@@ -641,25 +641,48 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram":
641641
return new_ep
642642

643643

644-
def pre_memory_planning_passes(config: ExecutorchBackendConfig) -> List[PassType]:
644+
def pre_memory_planning_passes(
645+
config: ExecutorchBackendConfig, name: Optional[str] = None
646+
) -> List[PassType]:
647+
"""
648+
Returns a list of passes to run before memory planning.
649+
Get the sym shape eval pass based on the method name, if the pass is not in the dict, use the default pass.
650+
"""
651+
# Handle symbolic shape eval pass
652+
if isinstance(config.sym_shape_eval_pass, dict):
653+
default_pass = ExecutorchBackendConfig().sym_shape_eval_pass
654+
if not name:
655+
sym_shape_eval_pass = default_pass
656+
# pyre-ignore: Undefined attribute [16]
657+
sym_shape_eval_pass = config.sym_shape_eval_pass.get(name, default_pass)
658+
elif isinstance(config.sym_shape_eval_pass, PassBase):
659+
sym_shape_eval_pass = config.sym_shape_eval_pass
660+
else:
661+
raise RuntimeError(
662+
f"sym_shape_eval_pass must be a dict or a PassBase, got {config.sym_shape_eval_pass}"
663+
)
645664
if config.remove_view_copy:
646-
# pyre-ignore
647665
return [
648666
NormalizeViewCopyBasePass(),
649667
dead_code_elimination_pass,
650668
ReplaceViewCopyWithViewPass(),
651-
config.sym_shape_eval_pass,
669+
sym_shape_eval_pass,
652670
config.to_out_var_pass,
653671
]
654672
else:
655-
# pyre-ignore
656673
return [
657-
config.sym_shape_eval_pass,
674+
sym_shape_eval_pass,
658675
config.to_out_var_pass,
659676
]
660677

661678

662-
def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]:
679+
def edge_to_executorch_passes(
680+
config: ExecutorchBackendConfig, name: Optional[str] = None
681+
) -> List[PassType]:
682+
"""
683+
Returns a list of passes to lower from edge to executorch.
684+
Get the pre memory planning passes based on the method name, if the pass is not in the dict, use the default pass.
685+
"""
663686
passes: List[PassType] = [
664687
*config.passes,
665688
SpecPropPass(),
@@ -668,7 +691,7 @@ def edge_to_executorch_passes(config: ExecutorchBackendConfig) -> List[PassType]
668691
# there exists an unbacked symint operation.
669692
EdgeToBackendOpsPass(),
670693
RemoveGraphAssertsPass(),
671-
] + pre_memory_planning_passes(config)
694+
] + pre_memory_planning_passes(config, name)
672695

673696
return passes
674697

@@ -1234,7 +1257,7 @@ def to_executorch(
12341257
program = unsafe_remove_auto_functionalized_pass(program)
12351258
gm, new_signature = insert_write_back_for_buffers_pass(program)
12361259
new_gm = program.graph_module
1237-
for p in edge_to_executorch_passes(config):
1260+
for p in edge_to_executorch_passes(config, name):
12381261
new_gm_res = p(new_gm)
12391262
assert new_gm_res is not None
12401263
new_gm = new_gm_res.graph_module

0 commit comments

Comments
 (0)