Skip to content

Commit 07e7250

Browse files
JacobSzwejbka
authored and facebook-github-bot committed
Mark call as deprecated (#7968)
Summary: call is deprecated since it cant handle mutation. This is a no op for people using the default memory planning stuff today, but want to call out louder to people implementing their own not to do call. Reviewed By: hsharma35 Differential Revision: D68726718
1 parent 5ed191a commit 07e7250

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

exir/passes/memory_planning_pass.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Callable, List, Optional
1010

1111
import torch
12+
from executorch.exir._warnings import deprecated
1213
from executorch.exir.error import internal_assert
1314
from executorch.exir.memory import alloc
1415
from executorch.exir.memory_planning import (
@@ -83,6 +84,11 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
8384
)
8485
out_alloc_node.meta["spec"] = specs[i]
8586

87+
@deprecated(
88+
"MemoryPlanningPass.call() is deprecated as it does not handle graphs \
89+
with mutation, please use MemoryPlanningPass.run() instead",
90+
category=FutureWarning,
91+
)
8692
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
8793
return self.run(graph_module)
8894

exir/tests/test_memory_planning.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,39 @@ def test_multiple_pools(
519519
idx += 1
520520
self.assertEqual(graph_module.meta["non_const_buffer_sizes"], expected_bufsizes)
521521

522+
def test_mutation_not_double_allocated(self) -> None:
523+
class Simple(torch.nn.Module):
524+
def __init__(self) -> None:
525+
super().__init__()
526+
self.register_buffer("constant", torch.ones(5, 5))
527+
528+
def forward(self, x: torch.Tensor) -> torch.Tensor:
529+
self.constant.add_(1)
530+
return x - self.constant
531+
532+
model = Simple()
533+
inputs = (torch.ones(5, 5),)
534+
535+
et = to_edge(export(model, inputs, strict=True)).to_executorch()
536+
537+
# 0 and 11 should refer to the same tensor. 0 is the input, 11 is the output of copy_
538+
self.assertEqual(
539+
et.executorch_program.execution_plan[0]
540+
.values[0]
541+
.val.allocation_info.memory_offset_low,
542+
et.executorch_program.execution_plan[0]
543+
.values[11]
544+
.val.allocation_info.memory_offset_low,
545+
)
546+
self.assertEqual(
547+
et.executorch_program.execution_plan[0]
548+
.values[0]
549+
.val.allocation_info.memory_offset_high,
550+
et.executorch_program.execution_plan[0]
551+
.values[11]
552+
.val.allocation_info.memory_offset_high,
553+
)
554+
522555
def test_constants_not_memory_planned(self) -> None:
523556
class Simple(torch.nn.Module):
524557
def __init__(self) -> None:

0 commit comments

Comments
 (0)