 
 import torch
 from executorch.exir import ExecutorchBackendConfig, to_edge
+from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.memory_planning import (
     filter_nodes,
     get_node_tensor_specs,
     greedy,
     naive,
     Verifier,
 )
-from executorch.exir.pass_base import PassResult
+from executorch.exir.pass_base import ExportPass, PassResult
 from executorch.exir.pass_manager import PassManager
 from executorch.exir.passes import (  # noqa
     MemoryPlanningPass,
@@ -593,3 +594,69 @@ def count_planned_inputs(
             num_placeholders,
             5,
         )
+
+    def test_placeholder_lifetime(self) -> None:
+        class TestModel(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.linear = torch.nn.Linear(5, 5)
+
+            def forward(self, a, b, x):
+                a = a + b
+                b = a + b
+                y = self.linear(x)
+                return a, b, y
+
+        model = TestModel()
+        example_inputs = (torch.rand(1, 6, 2), torch.rand(1, 6, 2), torch.randn(5, 5))
+        exported_model = torch.export.export(model, example_inputs)
+        edge = to_edge(exported_model)
+
+        # Inserts a permute_copy right after placeholder "a" and redirects all of
+        # the placeholder's other users to the permuted value.
+        class TestPass(ExportPass):
+            def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
+                permute_dims = [1, 0, 2]
+                for node in graph_module.graph.nodes:
+                    if node.op == "placeholder" and str(node) == "a":
+                        inverse_dims = [
+                            permute_dims.index(x) for x in range(len(permute_dims))
+                        ]
+
+                        with graph_module.graph.inserting_after(node):
+                            permute = graph_module.graph.call_function(
+                                exir_ops.edge.aten.permute_copy.default,
+                                args=(node, inverse_dims),
+                            )
+                            permute.meta = node.meta.copy()
+                        node.meta["val"] = node.meta["val"].permute(permute_dims)
+                        node.replace_all_uses_with(
+                            permute, lambda x, permute=permute: x is not permute
+                        )
+                        break
+                return PassResult(graph_module, True)
+
+        edge = edge.transform([TestPass()])
+        et = edge.to_executorch()
+        et_program = et.executorch_program
+        inputs = et_program.execution_plan[0].inputs
+        # The first two inputs must not be planned at the same memory offset.
+        self.assertNotEqual(
+            et_program.execution_plan[0]  # pyre-ignore
+            .values[inputs[0]]
+            .val.allocation_info.memory_offset_low,
+            et_program.execution_plan[0]  # pyre-ignore
+            .values[inputs[1]]
+            .val.allocation_info.memory_offset_low,
+        )
+
+        # Constant placeholders (the linear weight and bias) must not be memory planned.
+        constants = 0
+        for node in et.exported_program().graph_module.graph.nodes:
+            if node.op == "placeholder" and node.meta.get("spec"):
+                meta_spec = node.meta["spec"]
+                if meta_spec.const is True:
+                    constants += 1
+                    self.assertIsNone(node.meta["spec"].mem_offset)
+                    self.assertIsNone(node.meta["spec"].mem_id)
+        self.assertEqual(constants, 2)