|
18 | 18 | from executorch.backends.qnnpack.partition.qnnpack_partitioner import QnnpackPartitioner
|
19 | 19 | from executorch.exir.backend.backend_api import to_backend, validation_disabled
|
20 | 20 | from executorch.exir.memory_planning import filter_nodes, Verifier
|
| 21 | +from executorch.exir.pass_base import PassResult |
21 | 22 | from executorch.exir.pass_manager import PassManager
|
22 | 23 | from executorch.exir.passes import ( # noqa
|
23 | 24 | ConstPropPass,
|
|
29 | 30 | )
|
30 | 31 | from executorch.exir.print_program import print_program
|
31 | 32 | from executorch.exir.tests.asr_joiner import ASRJoiner
|
| 33 | +from parameterized import parameterized |
32 | 34 |
|
33 | 35 | from torch import nn
|
34 | 36 | from torch.ao.quantization import ( # @manual=//caffe2:torch
|
@@ -157,6 +159,47 @@ def extra_check(
|
157 | 159 | testcase.assertTrue(getitem_spec.lifetime[1] >= cat_specs[0].lifetime[0])
|
158 | 160 |
|
159 | 161 |
|
class CustomPoolMemoryPlanningPass(MemoryPlanningPass):
    """Memory planning pass that pins specs to fixed memory pools.

    Pool assignment before delegating to the stock planner:
      * mem_id = 1 for placeholders and outputs of ``mul``
      * mem_id = 3 for outputs of ``add``
    The parent class then copies each spec to the corresponding alloc node.
    """

    def call(self, graph_module: GraphModule) -> PassResult:
        # Only fx submodules carry graphs; skip everything else.
        submodules = (
            m for m in graph_module.modules() if isinstance(m, GraphModule)
        )
        for subgm in submodules:
            for node in subgm.graph.nodes:
                if node.op == "placeholder":
                    node.meta["spec"].mem_id = 1
                elif node.op == "call_function":
                    if node.target == torch.ops.aten.add.out:
                        node.meta["spec"].mem_id = 3
                    elif node.target == torch.ops.aten.mul.out:
                        node.meta["spec"].mem_id = 1
        # Run the regular memory planning with the pools fixed above.
        return super().call(graph_module)
| 184 | + |
| 185 | + |
class MultiplePoolsToyModel(torch.nn.Module):
    """Tiny four-op model used to exercise multi-pool memory planning.

    Expected plan (see CustomPoolMemoryPlanningPass for pool assignment):
      a: mem_id = 1, offset = 0
      b: mem_id = 3, offset = 0
      c: mem_id = 1, offset = 4
      d: mem_id = 3, offset = 4
      e: mem_id = 1, offset = 0 (greedy) / offset = 8 (naive)
    """

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        doubled = a + a
        scaled = a * doubled
        summed = scaled + doubled
        return scaled * summed
| 201 | + |
| 202 | + |
160 | 203 | def maketest(
|
161 | 204 | module_cls: Type[torch.nn.Module],
|
162 | 205 | criteria: Optional[List[Tuple[str, bool]]] = None,
|
@@ -463,3 +506,60 @@ def test_asr_joiner(self) -> None:
|
463 | 506 | )
|
464 | 507 |
|
465 | 508 | self.assertEqual(3, ncheck)
|
| 509 | + |
| 510 | + # pyre-ignore |
| 511 | + @parameterized.expand( |
| 512 | + [ |
| 513 | + ( |
| 514 | + "naive", |
| 515 | + [(1, 0), (3, 0), (1, 4), (3, 4), (1, 8)], |
| 516 | + [0, 12, 0, 8], |
| 517 | + ), |
| 518 | + ( |
| 519 | + "greedy", |
| 520 | + [(1, 0), (3, 0), (1, 4), (3, 4), (1, 0)], |
| 521 | + [0, 8, 0, 8], |
| 522 | + ), |
| 523 | + ] |
| 524 | + ) |
| 525 | + def test_multiple_pools( |
| 526 | + self, |
| 527 | + algo: str, |
| 528 | + expected_allocs: List[Tuple[int, int]], |
| 529 | + expected_bufsizes: List[int], |
| 530 | + ) -> None: |
| 531 | + edge_program = exir.capture( |
| 532 | + MultiplePoolsToyModel(), |
| 533 | + (torch.ones(1),), |
| 534 | + exir.CaptureConfig(pt2_mode=True), |
| 535 | + ).to_edge(exir.EdgeCompileConfig(_check_ir_validity=False)) |
| 536 | + |
| 537 | + program = edge_program.to_executorch( |
| 538 | + exir.ExecutorchBackendConfig( |
| 539 | + memory_planning_pass=CustomPoolMemoryPlanningPass( |
| 540 | + memory_planning_algo=algo, |
| 541 | + alignment=1, |
| 542 | + ) |
| 543 | + ) |
| 544 | + ) |
| 545 | + graph_module = program.dump_graph_module() |
| 546 | + |
| 547 | + verifier = Verifier( |
| 548 | + graph_module, |
| 549 | + alloc_graph_input=True, |
| 550 | + alloc_graph_output=True, |
| 551 | + ) |
| 552 | + verifier.verify_storage_reuse() |
| 553 | + verifier.verify_graph_input_output() |
| 554 | + |
| 555 | + idx = 0 |
| 556 | + for node in graph_module.graph.nodes: |
| 557 | + if node.op == "placeholder" or ( |
| 558 | + node.op == "call_function" |
| 559 | + and node.target in (torch.ops.aten.add.out, torch.ops.aten.mul.out) |
| 560 | + ): |
| 561 | + mem_id, mem_offset = expected_allocs[idx] |
| 562 | + self.assertEqual(node.meta["spec"].mem_id, mem_id) |
| 563 | + self.assertEqual(node.meta["spec"].mem_offset, mem_offset) |
| 564 | + idx += 1 |
| 565 | + self.assertEqual(graph_module.meta["non_const_buffer_sizes"], expected_bufsizes) |
0 commit comments