Commit 6fb6332

fix memory planning not skipping None values
Differential Revision: D69211071
Pull Request resolved: #8276
Parent: d01810e

2 files changed: 51 additions & 1 deletion


exir/passes/memory_planning_pass.py

Lines changed: 4 additions & 1 deletion
@@ -77,14 +77,16 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
                 out_alloc_node.meta["spec"] = node.meta["spec"]
                 continue
             specs = get_node_tensor_specs(node)
-            for i, out_arg in enumerate(out_arg_names):
+            i = 0
+            for out_arg in out_arg_names:
                 out_alloc_node = node.kwargs[out_arg]
                 if out_alloc_node is None:
                     warnings.warn(
                         f"Function {node.target}'s {out_arg} kwarg value is None",
                         stacklevel=1,
                     )
                     continue
+                    # dont increment i as we dont have a spec for this node
                 internal_assert(
                     out_alloc_node.op == "call_function"
                     and out_alloc_node.target == alloc,
@@ -95,6 +97,7 @@ def _set_alloc_node_spec(self, graph_module: torch.fx.GraphModule) -> None:
                     f"Out-var's allocation node {out_alloc_node} already has a spec assigned",
                 )
                 out_alloc_node.meta["spec"] = specs[i]
+                i += 1
 
     @deprecated(
         "MemoryPlanningPass.call() is deprecated as it does not handle graphs \

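For context, here is a minimal, self-contained sketch of the indexing problem the change above fixes. It is plain Python with made-up names (out_args, out_nodes, specs), not ExecuTorch code: get_node_tensor_specs has no entry for an out-arg whose value is None, so an enumerate index drifts off the end of the spec list (or onto the wrong spec) once a None out-arg is skipped, whereas the manually incremented counter only advances when a spec is actually consumed.

# Toy illustration only; names are hypothetical stand-ins, not ExecuTorch APIs.
out_args = ["out0", "out1", "out2"]                     # out-var kwarg names
out_nodes = {"out0": "alloc_a", "out1": None, "out2": "alloc_c"}
specs = ["spec_for_out0", "spec_for_out2"]              # no spec for the None out-arg

# Buggy shape (the removed `enumerate` loop): after skipping the None out-arg,
# i == 2 and specs[2] raises IndexError, or grabs the wrong spec when more
# outputs follow.

# Fixed shape (the added manual counter):
assigned = {}
i = 0
for out_arg in out_args:
    node = out_nodes[out_arg]
    if node is None:
        continue  # skip; do not advance i, there is no spec for this output
    assigned[node] = specs[i]
    i += 1

assert assigned == {"alloc_a": "spec_for_out0", "alloc_c": "spec_for_out2"}
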
exir/tests/test_memory_planning.py

Lines changed: 47 additions & 0 deletions
@@ -49,6 +49,7 @@
     prepare_fx,
 )
 from torch.export import export
+from torch.export.experimental import _export_forward_backward
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx import Graph, GraphModule, Node
 from torch.nn import functional as F
@@ -724,3 +725,49 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
             self.assertIsNone(node.meta["spec"].mem_offset)
             self.assertIsNone(node.meta["spec"].mem_id)
         self.assertEqual(constants, 2)
+
+    def test_none_output(self) -> None:
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv2d(6, 6, 5)
+                self.linear = nn.Linear(6, 2)
+
+            def forward(self, x):
+                return self.linear(self.conv1(x).flatten(1))
+
+        class TrainingNet(nn.Module):
+            def __init__(self, net):
+                super().__init__()
+                self.net = net
+                self.loss = nn.CrossEntropyLoss()
+
+            def forward(self, input, label):
+                pred = self.net(input)
+                return self.loss(pred, label)
+
+        net = TrainingNet(Net())
+        inputs = (torch.randn(1, 6, 5, 5), torch.ones(1, dtype=torch.int64))
+
+        ep = export(net, inputs)
+        ep = _export_forward_backward(ep)
+        ep = to_edge(ep)
+        ep = ep.to_executorch()
+
+        ep.dump_executorch_program(True)
+
+        # 155 just so happens to be the index of the user_grad output arg of
+        # convolution_backward.out. This is fairly fragile.
+        # Check that the None output is not memory planned.
+        self.assertEqual(
+            ep.executorch_program.execution_plan[0]
+            .values[155]
+            .val.data_buffer_idx,  # pyright: ignore
+            0,
+        )
+        self.assertEqual(
+            ep.executorch_program.execution_plan[0]
+            .values[155]
+            .val.allocation_info,  # pyright: ignore
+            None,
+        )
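As the in-test comment notes, hard-coding value index 155 is fragile. A less brittle check is sketched below; it assumes it is appended inside the same test (so ep and self are in scope) and that non-tensor values simply lack the data_buffer_idx / allocation_info attributes the test already reads. It scans every value in the plan for an unplanned tensor instead of pinning a specific index; this is only a sketch, not part of the commit.

# Hedged sketch: find values left unplanned (no allocation_info, no constant
# buffer entry) rather than asserting on a hard-coded index.
plan = ep.executorch_program.execution_plan[0]
unplanned = [
    idx
    for idx, v in enumerate(plan.values)
    if getattr(v.val, "data_buffer_idx", None) == 0        # not a constant tensor
    and getattr(v.val, "allocation_info", -1) is None      # not memory planned
]
self.assertGreater(len(unplanned), 0)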
