Skip to content

Commit c9d7b6e

Browse files
authored
Fix placeholder lifetime bug in memory planning
Differential Revision: D66184849 Pull Request resolved: #6971
1 parent ddec0c7 commit c9d7b6e

File tree

2 files changed

+72
-4
lines changed

2 files changed

+72
-4
lines changed

exir/memory_planning.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,9 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
268268
)
269269

270270

271-
def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None:
271+
def update_tensor_lifetime(
272+
node: torch.fx.Node, spec: TensorSpec, node_idx: int
273+
) -> None:
272274
r"""
273275
Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
274276
are represented by the index of the first and last node referring
@@ -279,7 +281,10 @@ def update_tensor_lifetime(spec: TensorSpec, node_idx: int) -> None:
279281
node_idx: extend the tensor's lifetime to cover node_idx
280282
"""
281283
start, end = spec.lifetime
282-
start = node_idx if start is None or start > node_idx else start
284+
if node.op == "placeholder":
285+
start = 0
286+
else:
287+
start = node_idx if start is None or start > node_idx else start
283288
end = node_idx if end is None or end < node_idx else end
284289
spec.lifetime = [start, end]
285290

@@ -444,7 +449,7 @@ def update_all_tensors_lifetime(
444449
do_assertion=False,
445450
ignore_dynamic_unbound_tensor=False,
446451
):
447-
update_tensor_lifetime(spec, node_idx)
452+
update_tensor_lifetime(node, spec, node_idx)
448453
specs.add(spec)
449454
return specs
450455

exir/tests/test_memory_planning.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414

1515
import torch
1616
from executorch.exir import ExecutorchBackendConfig, to_edge
17+
from executorch.exir.dialects._ops import ops as exir_ops
1718
from executorch.exir.memory_planning import (
1819
filter_nodes,
1920
get_node_tensor_specs,
2021
greedy,
2122
naive,
2223
Verifier,
2324
)
24-
from executorch.exir.pass_base import PassResult
25+
from executorch.exir.pass_base import ExportPass, PassResult
2526
from executorch.exir.pass_manager import PassManager
2627
from executorch.exir.passes import ( # noqa
2728
MemoryPlanningPass,
@@ -593,3 +594,65 @@ def count_planned_inputs(
593594
num_placeholders,
594595
5,
595596
)
597+
598+
def test_placeholder_lifetime(self) -> None:
599+
class TestModel(torch.nn.Module):
600+
def __init__(self) -> None:
601+
super().__init__()
602+
self.linear = torch.nn.Linear(5, 5)
603+
604+
def forward(self, a, b, x):
605+
a = a + b
606+
b = a + b
607+
y = self.linear(x)
608+
return a, b, y
609+
610+
model = TestModel()
611+
example_inputs = (torch.rand(1, 6, 2), torch.rand(1, 6, 2), torch.randn(5, 5))
612+
exported_model = torch.export.export(model, example_inputs)
613+
edge = to_edge(exported_model)
614+
615+
class TestPass(ExportPass):
616+
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
617+
permute_dims = [1, 0, 2]
618+
for node in graph_module.graph.nodes:
619+
if node.op == "placeholder" and str(node) == "a":
620+
inverse_dims = [
621+
permute_dims.index(x) for x in range(len(permute_dims))
622+
]
623+
624+
with graph_module.graph.inserting_after(node):
625+
permute = graph_module.graph.call_function(
626+
exir_ops.edge.aten.permute_copy.default,
627+
args=(node, inverse_dims),
628+
)
629+
permute.meta = node.meta.copy()
630+
node.meta["val"] = node.meta["val"].permute(permute_dims)
631+
node.replace_all_uses_with(
632+
permute, lambda x, permute=permute: x is not permute
633+
)
634+
break
635+
return PassResult(graph_module, True)
636+
637+
edge = edge.transform([TestPass()])
638+
et = edge.to_executorch()
639+
et_program = et.executorch_program
640+
inputs = et_program.execution_plan[0].inputs
641+
self.assertNotEqual(
642+
et_program.execution_plan[0] # pyre-ignore
643+
.values[inputs[0]]
644+
.val.allocation_info.memory_offset_low,
645+
et_program.execution_plan[0] # pyre-ignore
646+
.values[inputs[1]]
647+
.val.allocation_info.memory_offset_low,
648+
)
649+
650+
constants = 0
651+
for node in et.exported_program().graph_module.graph.nodes:
652+
if node.op == "placeholder" and node.meta.get("spec"):
653+
meta_spec = node.meta["spec"]
654+
if meta_spec.const is True:
655+
constants += 1
656+
self.assertIsNone(node.meta["spec"].mem_offset)
657+
self.assertIsNone(node.meta["spec"].mem_id)
658+
self.assertEqual(constants, 2)

0 commit comments

Comments
 (0)