
Commit a2bc6bd

pytorchbot and Eashan Garg authored
Change memory planning API to accept full algorithm as argument as opposed to string name (#6130)
Change memory planning API to accept full algorithm as argument as opposed to string name (#4727)

Summary:
Pull Request resolved: #4727

ExecuTorch memory planning currently accepts a string identifier to select the desired algorithm. However, this makes it difficult to pass custom arguments to more customized memory planning algorithms. This change allows users to pass the full memory planning function as an argument instead of just the string identifier.

Core changes in:
- fbcode/executorch/exir/passes/memory_planning_pass.py
- fbcode/executorch/exir/tests/test_memory_planning.py

The remaining changes just enforce compliance with the new API at all call sites in the codebase.

NOTE: A less intrusive change would have been to allow the argument to be either a string or a custom function. I opted for passing only functions to simplify the API and avoid confusion.

Reviewed By: zonglinpeng, hsharma35, mcremon-meta

Differential Revision: D60433641

fbshipit-source-id: 0fe3677b7c3f4c3763cb1b4fe6d28ef814f2ecf9
(cherry picked from commit 618466e)

Co-authored-by: Eashan Garg <[email protected]>
1 parent 67c959a commit a2bc6bd
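As a rough before/after sketch of the API change (import paths taken from the diffs below; shown for orientation only, not as authoritative usage):

    from executorch.exir.memory_planning import greedy
    from executorch.exir.passes import MemoryPlanningPass

    # Before this commit: the algorithm was selected by string name.
    #   MemoryPlanningPass(memory_planning_algo="greedy")
    # After: the algorithm function itself is passed; greedy is the default,
    # so MemoryPlanningPass() is equivalent.
    planning_pass = MemoryPlanningPass(memory_planning_algo=greedy)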

File tree

22 files changed: +38 −91 lines changed


backends/qualcomm/tests/utils.py

Lines changed: 0 additions & 1 deletion
@@ -350,7 +350,6 @@ def lower_module_and_test_output(
             # Therefore, won't want to pre-allocate
             # by memory manager in runtime.
             memory_planning_pass=MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 alloc_graph_input=not self.shared_buffer,
                 alloc_graph_output=not self.shared_buffer,
             ),

backends/vulkan/vulkan_preprocess.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def preprocess(  # noqa: C901
             MeanToSumDiv(),
             SpecPropPass(),
             ConstraintBasedSymShapeEvalPass(),
-            MemoryPlanningPass("greedy"),
+            MemoryPlanningPass(),
         ]

         new_gm = program.graph_module

docs/source/compiler-memory-planning.md

Lines changed: 1 addition & 2 deletions
@@ -32,7 +32,6 @@ The `MemoryPlanningPass` exposes the option to not memory plan program inputs an
 program = edge_program.to_executorch(
     exir.ExecutorchBackendConfig(
         memory_planning_pass=MemoryPlanningPass(
-            memory_planning_algo="greedy",
             alloc_graph_input=False,  # Inputs will not be memory planned, the data_ptr for input tensors after model load will be nullptr
             alloc_graph_output=True,  # Outputs will be memory planned, the data_ptr for output tensors after model load will be in the `planned_memory`.
         )
@@ -77,7 +76,7 @@ Then later when lowering to ExecuTorch you can use your custom plan in the follo
 program = edge_program.to_executorch(
     exir.ExecutorchBackendConfig(
         memory_planning_pass=CustomPoolMemoryPlanningPass(
-            memory_planning_algo="greedy",
+            memory_planning_algo=greedy,
         )
     )
 )
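Since the argument is now a function rather than a name, extra configuration can be bound onto a custom planner before it is handed to the pass. A minimal sketch, assuming the `Callable[..., List[int]]` contract visible in `exir/memory_planning.py` below; `my_planner` and its `verbose` flag are hypothetical, not part of the real API:

    from functools import partial
    from typing import List

    import torch
    from executorch.exir.passes import MemoryPlanningPass

    def my_planner(
        graph_module: torch.fx.GraphModule,
        alignment: int,
        *args,
        verbose: bool = False,  # hypothetical custom argument
        **kwargs,
    ) -> List[int]:
        # Hypothetical custom algorithm; like greedy/naive it returns
        # planned buffer sizes per memory pool.
        if verbose:
            print(f"planning {len(graph_module.graph.nodes)} nodes")
        return [0]

    # Extra arguments are bound up front, then the callable is passed whole.
    planning_pass = MemoryPlanningPass(
        memory_planning_algo=partial(my_planner, verbose=True)
    )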

docs/source/tutorials_source/export-to-executorch-tutorial.py

Lines changed: 1 addition & 3 deletions
@@ -523,9 +523,7 @@ def forward(self, a, x, b):
 executorch_program: ExecutorchProgramManager = edge_program.to_executorch(
     ExecutorchBackendConfig(
         passes=[],  # User-defined passes
-        memory_planning_pass=MemoryPlanningPass(
-            "greedy"
-        ),  # Default memory planning pass
+        memory_planning_pass=MemoryPlanningPass(),  # Default memory planning pass
     )
 )

examples/mediatek/model_export_scripts/llama.py

Lines changed: 0 additions & 1 deletion
@@ -365,7 +365,6 @@ def export_to_et_ir(
     executorch_program = delegated_program.to_executorch(
         config=exir.ExecutorchBackendConfig(
             memory_planning_pass=exir.passes.MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 alloc_graph_input=False,
                 alloc_graph_output=False,
             ),

examples/models/llava/export_llava.py

Lines changed: 1 addition & 1 deletion
@@ -233,7 +233,7 @@ def export_all(llava_model: LlavaModel):
         passes=[
             QuantFusionPass(),
         ],
-        memory_planning_pass=MemoryPlanningPass("greedy", alloc_graph_input=False),
+        memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
         sym_shape_eval_pass={
             "image_encoder": ConstraintBasedSymShapeEvalPass(),
             "text_model": ConstraintBasedSymShapeEvalPass(),

examples/qualcomm/oss_scripts/llama2/llama.py

Lines changed: 0 additions & 1 deletion
@@ -311,7 +311,6 @@ def lowering_modules(
             # Therefore, won't want to pre-allocate
             # by memory manager in runtime.
             memory_planning_pass=MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 alloc_graph_input=False,
                 alloc_graph_output=False,
             ),

examples/qualcomm/qaihub_scripts/utils/export.py

Lines changed: 0 additions & 1 deletion
@@ -220,7 +220,6 @@ def compile(args):
     )
     # setup memory planning
     memory_planning_pass = MemoryPlanningPass(
-        memory_planning_algo="greedy",
         alloc_graph_input=args.allocate_graph_io,
         alloc_graph_output=args.allocate_graph_io,
     )

examples/qualcomm/utils.py

Lines changed: 0 additions & 1 deletion
@@ -285,7 +285,6 @@ def build_executorch_binary(
             # Therefore, won't want to pre-allocate
             # by memory manager in runtime.
             memory_planning_pass=MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 alloc_graph_input=not shared_buffer,
                 alloc_graph_output=not shared_buffer,
             ),

exir/capture/_config.py

Lines changed: 1 addition & 3 deletions
@@ -56,9 +56,7 @@ class ExecutorchBackendConfig:

     # A single memory planning pass can be defined for all the programs in the
     # EdgeProgramManager or can be defined per program.
-    memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass(
-        "greedy"
-    )
+    memory_planning_pass: Union[PassType, Dict[str, PassType]] = MemoryPlanningPass()
     to_out_var_pass: PassType = ToOutVarPass(ignore_to_out_var_failure=False)
     dynamic_memory_planning_mode: DynamicMemoryPlanningMode = (
         DynamicMemoryPlanningMode.UPPER_BOUND
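For reference, the `Union[PassType, Dict[str, PassType]]` annotation means a single pass can cover every entry point, or passes can be keyed per method name. A sketch of the dict form, mirroring the pattern in `exir/program/test/test_program.py` below (the method names here are illustrative):

    from executorch.exir import ExecutorchBackendConfig
    from executorch.exir.passes import MemoryPlanningPass

    config = ExecutorchBackendConfig(
        memory_planning_pass={
            # One MemoryPlanningPass per exported method name.
            "forward": MemoryPlanningPass(alloc_graph_input=True),
            "encode": MemoryPlanningPass(alloc_graph_input=False),
        }
    )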

exir/emit/test/test_emit.py

Lines changed: 1 addition & 4 deletions
@@ -1145,7 +1145,6 @@ def forward(self, k: torch.Tensor) -> torch.Tensor:
         config = exir.ExecutorchBackendConfig(
             sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             memory_planning_pass=MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 # allow_lifetime_and_storage_overlap: bool = False,
                 alloc_graph_input=True,
                 alloc_graph_output=False,
@@ -1606,9 +1605,7 @@ def forward(self, x):
         )
         model = model.to_executorch(
             config=ExecutorchBackendConfig(
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )

exir/lowered_backend_module.py

Lines changed: 1 addition & 1 deletion
@@ -326,7 +326,7 @@ def program(
         verifiers=[lowered_exported_program.verifier],
     )
     if memory_planning is None:
-        memory_planning = MemoryPlanningPass("greedy")
+        memory_planning = MemoryPlanningPass()
     exported_program = _transform(exported_program, SpecPropPass(), memory_planning)
     emitted_program = emit_program(
         exported_program, emit_stacktrace=emit_stacktrace

exir/memory_planning.py

Lines changed: 1 addition & 28 deletions
@@ -18,12 +18,7 @@
 from executorch.exir import memory
 from executorch.exir.control_flow import while_loop as exir_while
 from executorch.exir.delegate import executorch_call_delegate
-from executorch.exir.error import (
-    ExportError,
-    ExportErrorType,
-    internal_assert,
-    InternalError,
-)
+from executorch.exir.error import internal_assert, InternalError
 from executorch.exir.operator.convert import is_inplace_variant, is_out_variant
 from executorch.exir.schema import TensorShapeDynamism
 from executorch.exir.tensor import TensorSpec
@@ -255,17 +250,6 @@ def verify_graph_input_output(self) -> None:
         ), f"Misallocate graph output {graph_output_allocated} v.s. {self.alloc_graph_output}"


-def register_algo(fn: Callable[..., List[int]]) -> Callable[..., List[int]]:
-    algo_name = fn.__name__
-    if algo_name in REGISTERED_ALGOS:
-        raise ExportError(
-            ExportErrorType.VIOLATION_OF_SPEC,
-            f"Re-registering memory planning algorithm {algo_name}",
-        )
-    REGISTERED_ALGOS[algo_name] = fn
-    return fn
-
-
 def _is_out_var_node(node: torch.fx.Node) -> bool:
     return (
         node.op == "call_function"
@@ -561,7 +545,6 @@ def get_node_tensor_specs(
     ]


-@register_algo
 def greedy(
     graph_module: torch.fx.GraphModule,
     alignment: int,
@@ -615,7 +598,6 @@ def greedy(
     return total_sizes


-@register_algo
 def naive(
     graph_module: torch.fx.GraphModule,
     alignment: int,
@@ -656,15 +638,6 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
     return bufsizes


-def get_algo(algo_name: str) -> Callable[..., List[int]]:
-    if algo_name not in REGISTERED_ALGOS:
-        raise ExportError(
-            ExportErrorType.NOT_SUPPORTED,
-            f"Memory planning algorithm '{algo_name}' not found",
-        )
-    return REGISTERED_ALGOS[algo_name]
-
-
 def get_cond_nodes(graph_module: torch.fx.GraphModule) -> Iterable[Node]:
     for nd in graph_module.graph.nodes:
         if nd.target is torch.ops.higher_order.cond:
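With `register_algo` and `get_algo` removed, there is no string registry left to populate: any callable with the same shape as `greedy`/`naive` can be handed to the pass directly. A minimal stand-in, assuming the `Callable[..., List[int]]` contract from this diff (parameters beyond `graph_module` and `alignment` are elided behind `**kwargs`):

    from typing import List

    import torch

    def noop_algo(
        graph_module: torch.fx.GraphModule, alignment: int, *args, **kwargs
    ) -> List[int]:
        # Like greedy/naive, returns the planned buffer size per memory pool;
        # this stand-in plans nothing and exists only for illustration.
        return [0]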

exir/passes/memory_planning_pass.py

Lines changed: 5 additions & 6 deletions
@@ -6,16 +6,16 @@

 import logging
 import warnings
-from typing import Optional
+from typing import Callable, List, Optional

 import torch
 from executorch.exir.error import internal_assert
 from executorch.exir.memory import alloc
 from executorch.exir.memory_planning import (
     _is_out_var_node,
     apply_algo,
-    get_algo,
     get_node_tensor_specs,
+    greedy,
     Verifier,
 )
 from executorch.exir.operator.convert import get_out_args_from_opoverload
@@ -27,7 +27,7 @@
 class MemoryPlanningPass(PassBase):
     def __init__(
         self,
-        memory_planning_algo: str = "greedy",
+        memory_planning_algo: Callable[..., List[int]] = greedy,
         allow_lifetime_and_storage_overlap: bool = False,
         alloc_graph_input: bool = True,
         alloc_graph_output: bool = True,
@@ -96,14 +96,13 @@ def run(
         memory_planning_algo
         """
         self._set_alloc_node_spec(graph_module)
-        algo = get_algo(self.memory_planning_algo)
         # TODO(shunting) if people have concern of adding a field to GraphModule
         # directly, we should define a GraphModule subclass that we can add our
         # customized fields. Using the graph_module object to convey information across
         # passes/stages is quite natural and avoid yet another 'context' data structure
         # to do the job.
         _ = apply_algo(
-            algo,
+            self.memory_planning_algo,
             graph_module,
             self.alignment,
             graph_signature,
@@ -125,7 +124,7 @@ def run(
             self.allow_lifetime_and_storage_overlap
         )
         logging.debug(
-            f"The {self.memory_planning_algo} algorithm reuses storage for {num_reuse_pairs} pair of tensors"
+            f"The {getattr(self.memory_planning_algo, '__name__', repr(self.memory_planning_algo))} algorithm reuses storage for {num_reuse_pairs} pair of tensors"
        )
         verifier.verify_graph_input_output()
         return PassResult(graph_module, True)
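One detail worth noting in the log line above: it uses `getattr(..., '__name__', repr(...))` because not every callable carries a `__name__`. For example, a `functools.partial`-wrapped algorithm would otherwise raise `AttributeError`:

    from functools import partial

    def greedy(*args, **kwargs):
        ...

    bound = partial(greedy)
    print(getattr(greedy, "__name__", repr(greedy)))  # -> greedy
    print(getattr(bound, "__name__", repr(bound)))    # -> functools.partial(<function greedy at 0x...>)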

exir/program/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ python_library(
         "//caffe2:torch",
         "//executorch/exir:error",
         "//executorch/exir:graph_module",
+        "//executorch/exir:pass_base",
         "//executorch/exir:pass_manager",
         "//executorch/exir:print_program",
         "//executorch/exir:schema",

exir/program/test/test_program.py

Lines changed: 0 additions & 2 deletions
@@ -250,12 +250,10 @@ def test_executorch_manager_multi_config(self):
     def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]:
         return {
             "forward": MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 alloc_graph_input=True,
                 alloc_graph_output=False,
             ),
             "foo": MemoryPlanningPass(
-                memory_planning_algo="greedy",
                 alloc_graph_input=False,
                 alloc_graph_output=True,
             ),

exir/tests/test_memory_planning.py

Lines changed: 17 additions & 17 deletions
@@ -17,6 +17,8 @@
 from executorch.exir.memory_planning import (
     filter_nodes,
     get_node_tensor_specs,
+    greedy,
+    naive,
     Verifier,
 )
 from executorch.exir.pass_base import PassResult
@@ -208,7 +210,7 @@ def forward(self, a: torch.Tensor) -> torch.Tensor:

 def maketest(
     module_cls: Type[torch.nn.Module],
-    criteria: Optional[List[Tuple[str, bool]]] = None,
+    criteria: Optional[List[Tuple[Callable[..., List[int]], bool]]] = None,
     extra_check: Optional[Callable[..., None]] = None,
     use_functionalization: bool = True,
     alloc_graph_input: bool = True,
@@ -222,13 +224,15 @@ def wrapper(self: "TestMemoryPlanning") -> None:
         if not criteria:
             criteria = [
                 # naive algorithm does not reuse tensor storages
-                ("naive", False),
+                (naive, False),
                 # greedy algorithm should reuse tensor storages in the testing model
-                ("greedy", True),
+                (greedy, True),
             ]

         for algo, expect_reuse in criteria:
-            print(f"algo {algo}, expect_reuse {expect_reuse}")
+            print(
+                f"algo {getattr(algo, '__name__', repr(algo))}, expect_reuse {expect_reuse}"
+            )
             eager_module = module_cls().eval()
             inputs = eager_module.get_random_inputs()
             graph_module = (
@@ -353,8 +357,8 @@ def verify_overlap_placeholders(
     test_return_two: Callable[..., None] = maketest(
         ModuleReturnTwo,
         criteria=[
-            ("naive", False),
-            ("greedy", True),
+            (naive, False),
+            (greedy, True),
         ],
     )

@@ -363,8 +367,8 @@ def verify_overlap_placeholders(
     test_list_arg: Callable[..., None] = maketest(
         ModuleListArg,
         criteria=[
-            ("naive", False),
-            ("greedy", True),
+            (naive, False),
+            (greedy, True),
         ],
         extra_check=ModuleListArg.extra_check,
     )
@@ -466,20 +470,20 @@ def quantize(self, eager_model: nn.Module) -> nn.Module:
     @parameterized.expand(
         [
             (
-                "naive",
+                naive,
                 [(1, 0), (3, 0), (1, 4), (3, 4), (1, 8)],
                 [0, 12, 0, 8],
             ),
             (
-                "greedy",
+                greedy,
                 [(1, 0), (3, 0), (1, 4), (3, 4), (1, 0)],
                 [0, 8, 0, 8],
             ),
         ]
     )
     def test_multiple_pools(
         self,
-        algo: str,
+        algo: Callable[..., List[int]],
         expected_allocs: List[Tuple[int, int]],
         expected_bufsizes: List[int],
     ) -> None:
@@ -550,9 +554,7 @@ def count_planned_inputs(

         ep_no_input_planning = to_edge(export(model, inputs)).to_executorch(
             config=ExecutorchBackendConfig(
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=False
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=False),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )
@@ -572,9 +574,7 @@ def count_planned_inputs(

         ep_input_planning = to_edge(export(model, inputs)).to_executorch(
             config=ExecutorchBackendConfig(
-                memory_planning_pass=MemoryPlanningPass(
-                    "greedy", alloc_graph_input=True
-                ),
+                memory_planning_pass=MemoryPlanningPass(alloc_graph_input=True),
                 sym_shape_eval_pass=ConstraintBasedSymShapeEvalPass(),
             )
         )

exir/tests/test_passes.py

Lines changed: 1 addition & 1 deletion
@@ -713,7 +713,7 @@ def test_alloc_node_spec(self) -> None:
         self.assertIsNotNone(new_gm_res)
         new_gm = new_gm_res.graph_module

-        new_gm_res = MemoryPlanningPass("greedy")(new_gm)
+        new_gm_res = MemoryPlanningPass()(new_gm)
         self.assertIsNotNone(new_gm_res)
         new_gm = new_gm_res.graph_module