
Commit 26f9efa

sxu authored and facebook-github-bot committed
Support custom mem_id in MemoryPlanningPass algos
Summary: Allow customizing memory pools while still leveraging the existing memory planning algos. The algorithms still default to mem_id 1.

Reviewed By: ydwu4

Differential Revision: D48159257

fbshipit-source-id: 96ddc78f52f42e0de6f3b08feacdad188718c18a
1 parent: 5462df7

3 files changed: +124 −15 lines
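Before the diffs, the usage pattern this unlocks — condensed from the CustomPoolMemoryPlanningPass test added in exir/tests/test_memory_planning.py below, with the submodule walk and mul handling elided, and with the MemoryPlanningPass, GraphModule, and PassResult imports as in the test file — is a MemoryPlanningPass subclass that tags specs with a pool id and then hands off to the stock algorithm:

class CustomPoolMemoryPlanningPass(MemoryPlanningPass):
    def call(self, graph_module: GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if node.op == "placeholder":
                node.meta["spec"].mem_id = 1   # inputs go to pool 1
            elif node.op == "call_function" and node.target == torch.ops.aten.add.out:
                node.meta["spec"].mem_id = 3   # add outputs go to pool 3
        # Specs left untagged still default to mem_id 1 inside the algorithm.
        return super().call(graph_module)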

exir/memory_planning.py (20 additions, 15 deletions)

@@ -464,7 +464,7 @@ def greedy(
     alloc_graph_output: bool = True,
 ) -> List[int]:
     spec2obj = {}
-    shared_objects = []
+    shared_objects = defaultdict(list)
     # Don't do assertion in collect_specs_from_nodes if we have already encountered
     # and ignored some to_out_variant errors.
     do_assertion = not getattr(graph_module, "encounter_to_out_var_failure", False)
@@ -477,23 +477,29 @@ def greedy(
         ignore_graph_input=not alloc_graph_input,
         ignore_graph_output=not alloc_graph_output,
     ):
-        spec.mem_id = 1
+        if spec.mem_id is None:
+            spec.mem_id = 1
         spec.realign(alignment)
-        spec2obj[spec] = pick_shared_obj(shared_objects, spec)
-
-    input_total_size = 0
-    if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
-        input_total_size = bufsizes[1]
+        spec2obj[spec] = pick_shared_obj(shared_objects[spec.mem_id], spec)
+
+    total_sizes = [0] * (max(shared_objects.keys()) + 1)
+    for mem_id in shared_objects:
+        input_total_size = 0
+        if bufsizes := getattr(graph_module, "input_mem_buffer_sizes", None):
+            if len(bufsizes) > mem_id:
+                input_total_size = bufsizes[mem_id]
+        total_sizes[mem_id] = materialize_buffer(
+            shared_objects[mem_id], input_total_size
+        )

     # Since we now know the number of shared objects we need and the size of
     # each shared object, we can assign offset in the memory buffer for each
     # shared object.
-    total_size = materialize_buffer(shared_objects, input_total_size)
     for spec, sobj in spec2obj.items():
         spec.mem_offset = sobj.offset

-    logging.debug(f"greedy algorithm returns bufsizes: {total_size}")
-    return [0, total_size]
+    logging.debug(f"greedy algorithm returns bufsizes: {total_sizes}")
+    return total_sizes


 @register_algo
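In isolation, greedy's new bookkeeping is: one shared-object free list per mem_id, and one output size per mem_id, with total_sizes indexed by mem_id (hence the max(...) + 1 sizing). The following is a self-contained toy of that idea — SharedObj, pick_obj, and the offset loop are illustrative stand-ins, not executorch's SharedObject, pick_shared_obj, or materialize_buffer:

from collections import defaultdict
from typing import List

class SharedObj:
    def __init__(self, size: int) -> None:
        self.size, self.offset, self.free = size, 0, True

def pick_obj(pool: List[SharedObj], size: int) -> SharedObj:
    for obj in pool:  # reuse a freed object in this pool if one fits
        if obj.free and obj.size >= size:
            obj.free = False
            return obj
    obj = SharedObj(size)  # otherwise grow this pool
    obj.free = False
    pool.append(obj)
    return obj

pools = defaultdict(list)   # keyed by mem_id, as in the diff
a = pick_obj(pools[1], 4)   # lands in pool 1
b = pick_obj(pools[3], 4)   # lands in pool 3, independent of pool 1
a.free = True               # a's lifetime ends
e = pick_obj(pools[1], 4)   # reuses a's storage within pool 1
assert e is a

total_sizes = [0] * (max(pools.keys()) + 1)   # indexed by mem_id
for mem_id, pool in pools.items():
    offset = 0
    for obj in pool:        # stand-in for materialize_buffer
        obj.offset, offset = offset, offset + obj.size
    total_sizes[mem_id] = offset
print(total_sizes)          # [0, 4, 0, 4]

Because each pool keeps its own free list, a tensor in pool 3 can never alias storage with a tensor in pool 1, which is the point of the feature.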
@@ -506,10 +512,8 @@ def naive(
     # allocate 'allocated' bytes from buffer with id mem_id.
     # return the starting offset of the allocated buffer.
     def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
-        internal_assert(
-            mem_id >= 0 and mem_id < len(bufsizes),
-            f"Tensor mem_id should be between 0 and {len(bufsizes)}, but it was {mem_id}",
-        )
+        if mem_id >= len(bufsizes):
+            bufsizes.extend([0] * (mem_id - len(bufsizes) + 1))
         ret = bufsizes[mem_id]
         bufsizes[mem_id] += allocated
         return ret
@@ -525,7 +529,8 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
         ignore_graph_output=not alloc_graph_output,
     ):
         # assume a single memory layer which has mem_id 1
-        spec.mem_id = 1
+        if spec.mem_id is None:
+            spec.mem_id = 1
         # allocate spec.allocated_memory bytes in the buffer
         # with the corresponding mem_id
         spec.realign(alignment)
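The naive algorithm gets the same flexibility through _allocate_buf, which now grows bufsizes on demand instead of asserting that mem_id is already in range. Exercising that function standalone (same body as in the diff):

from typing import List

def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
    # Grow the list so mem_id is a valid index; new pools start at size 0.
    if mem_id >= len(bufsizes):
        bufsizes.extend([0] * (mem_id - len(bufsizes) + 1))
    ret = bufsizes[mem_id]
    bufsizes[mem_id] += allocated
    return ret

bufsizes = [0, 0]
assert _allocate_buf(bufsizes, 3, 16) == 0   # pool 3 created on demand
assert _allocate_buf(bufsizes, 3, 8) == 16   # bump-allocated within pool 3
assert bufsizes == [0, 0, 0, 24]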

exir/tests/TARGETS (4 additions, 0 deletions)

@@ -171,12 +171,16 @@ python_unittest(
     preload_deps = [
         "//executorch/kernels/portable:custom_ops_generated_lib",
     ],
+    # Static listing does not support tests generated with parameterized
+    supports_static_listing = False,
     deps = [
+        "fbsource//third-party/pypi/parameterized:parameterized",
         ":asr_joiner",
         "//caffe2:torch",
         "//executorch/backends/qnnpack/partition:qnnpack_partitioner",
         "//executorch/exir:lib",
         "//executorch/exir:memory_planning",
+        "//executorch/exir:pass_base",
         "//executorch/exir:pass_manager",
         "//executorch/exir:print_program",
         "//executorch/exir:schema",

exir/tests/test_memory_planning.py (100 additions, 0 deletions)

@@ -18,6 +18,7 @@
 from executorch.backends.qnnpack.partition.qnnpack_partitioner import QnnpackPartitioner
 from executorch.exir.backend.backend_api import to_backend, validation_disabled
 from executorch.exir.memory_planning import filter_nodes, Verifier
+from executorch.exir.pass_base import PassResult
 from executorch.exir.pass_manager import PassManager
 from executorch.exir.passes import ( # noqa
     ConstPropPass,
@@ -29,6 +30,7 @@
 )
 from executorch.exir.print_program import print_program
 from executorch.exir.tests.asr_joiner import ASRJoiner
+from parameterized import parameterized

 from torch import nn
 from torch.ao.quantization import ( # @manual=//caffe2:torch
@@ -157,6 +159,47 @@ def extra_check(
     testcase.assertTrue(getitem_spec.lifetime[1] >= cat_specs[0].lifetime[0])


+class CustomPoolMemoryPlanningPass(MemoryPlanningPass):
+    def call(self, graph_module: GraphModule) -> PassResult:
+        for subgm in graph_module.modules():
+            if not isinstance(subgm, GraphModule):
+                continue
+            for node in subgm.graph.nodes:
+                # mem_id = 1 for placeholders and outputs of mul
+                # mem_id = 3 for outputs of add
+                # the parent class will copy the specs to the alloc nodes
+                if node.op == "placeholder":
+                    node.meta["spec"].mem_id = 1
+                    continue
+
+                if node.op != "call_function":
+                    continue
+
+                if node.target == torch.ops.aten.add.out:
+                    node.meta["spec"].mem_id = 3
+                elif node.target == torch.ops.aten.mul.out:
+                    node.meta["spec"].mem_id = 1
+
+        return super().call(graph_module)
+
+
+class MultiplePoolsToyModel(torch.nn.Module):
+    def forward(self, a: torch.Tensor) -> torch.Tensor:
+        # a: mem_id = 1, offset = 0
+        # b: mem_id = 3, offset = 0
+        # c: mem_id = 1, offset = 4
+        # d: mem_id = 3, offset = 4
+        # greedy:
+        # e: mem_id = 1, offset = 0
+        # naive:
+        # e: mem_id = 1, offset = 8
+        b = a + a
+        c = a * b
+        d = c + b
+        e = c * d
+        return e
+
+
 def maketest(
     module_cls: Type[torch.nn.Module],
     criteria: Optional[List[Tuple[str, bool]]] = None,
@@ -463,3 +506,60 @@ def test_asr_joiner(self) -> None:
         )

         self.assertEqual(3, ncheck)
+
+    # pyre-ignore
+    @parameterized.expand(
+        [
+            (
+                "naive",
+                [(1, 0), (3, 0), (1, 4), (3, 4), (1, 8)],
+                [0, 12, 0, 8],
+            ),
+            (
+                "greedy",
+                [(1, 0), (3, 0), (1, 4), (3, 4), (1, 0)],
+                [0, 8, 0, 8],
+            ),
+        ]
+    )
+    def test_multiple_pools(
+        self,
+        algo: str,
+        expected_allocs: List[Tuple[int, int]],
+        expected_bufsizes: List[int],
+    ) -> None:
+        edge_program = exir.capture(
+            MultiplePoolsToyModel(),
+            (torch.ones(1),),
+            exir.CaptureConfig(pt2_mode=True),
+        ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))
+
+        program = edge_program.to_executorch(
+            exir.ExecutorchBackendConfig(
+                memory_planning_pass=CustomPoolMemoryPlanningPass(
+                    memory_planning_algo=algo,
+                    alignment=1,
+                )
+            )
+        )
+        graph_module = program.dump_graph_module()
+
+        verifier = Verifier(
+            graph_module,
+            alloc_graph_input=True,
+            alloc_graph_output=True,
+        )
+        verifier.verify_storage_reuse()
+        verifier.verify_graph_input_output()
+
+        idx = 0
+        for node in graph_module.graph.nodes:
+            if node.op == "placeholder" or (
+                node.op == "call_function"
+                and node.target in (torch.ops.aten.add.out, torch.ops.aten.mul.out)
+            ):
+                mem_id, mem_offset = expected_allocs[idx]
+                self.assertEqual(node.meta["spec"].mem_id, mem_id)
+                self.assertEqual(node.meta["spec"].mem_offset, mem_offset)
+                idx += 1
+        self.assertEqual(graph_module.meta["non_const_buffer_sizes"], expected_bufsizes)
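The expected numbers follow from each tensor in the toy model being a single float (4 bytes, given the torch.ones(1) input and alignment=1). Pool 1 holds a, c, e and pool 3 holds b, d. The naive algorithm bump-allocates every tensor, so pool 1 is a@0, c@4, e@8 (12 bytes) and pool 3 is b@0, d@4 (8 bytes), giving [0, 12, 0, 8]. Greedy sees that a is dead once c = a * b has run, so e reuses a's slot at offset 0 and pool 1 shrinks to 8 bytes, giving [0, 8, 0, 8]. Pools 0 and 2 are unused and stay at size 0.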
