Skip to content

Mark call as deprecated #7968

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions exir/passes/memory_planning_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import Callable, List, Optional

import torch
from executorch.exir._warnings import deprecated
from executorch.exir.error import internal_assert
from executorch.exir.memory import alloc
from executorch.exir.memory_planning import (
Expand Down Expand Up @@ -83,6 +84,11 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
)
out_alloc_node.meta["spec"] = specs[i]

@deprecated(
"MemoryPlanningPass.call() is deprecated as it does not handle graphs \
with mutation, please use MemoryPlanningPass.run() instead",
category=FutureWarning,
)
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
return self.run(graph_module)

Expand Down
33 changes: 33 additions & 0 deletions exir/tests/test_memory_planning.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,39 @@ def test_multiple_pools(
idx += 1
self.assertEqual(graph_module.meta["non_const_buffer_sizes"], expected_bufsizes)

def test_mutation_not_double_allocated(self) -> None:
class Simple(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_buffer("constant", torch.ones(5, 5))

def forward(self, x: torch.Tensor) -> torch.Tensor:
self.constant.add_(1)
return x - self.constant

model = Simple()
inputs = (torch.ones(5, 5),)

et = to_edge(export(model, inputs, strict=True)).to_executorch()

# 0 and 11 should refer to the same tensor. 0 is the input, 11 is the output of copy_
self.assertEqual(
et.executorch_program.execution_plan[0]
.values[0]
.val.allocation_info.memory_offset_low,
et.executorch_program.execution_plan[0]
.values[11]
.val.allocation_info.memory_offset_low,
)
self.assertEqual(
et.executorch_program.execution_plan[0]
.values[0]
.val.allocation_info.memory_offset_high,
et.executorch_program.execution_plan[0]
.values[11]
.val.allocation_info.memory_offset_high,
)

def test_constants_not_memory_planned(self) -> None:
class Simple(torch.nn.Module):
def __init__(self) -> None:
Expand Down
Loading