|
49 | 49 | prepare_fx,
|
50 | 50 | )
|
51 | 51 | from torch.export import export
|
| 52 | +from torch.export.experimental import _export_forward_backward |
52 | 53 | from torch.export.exported_program import ExportGraphSignature
|
53 | 54 | from torch.fx import Graph, GraphModule, Node
|
54 | 55 | from torch.nn import functional as F
|
@@ -724,3 +725,49 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
|
724 | 725 | self.assertIsNone(node.meta["spec"].mem_offset)
|
725 | 726 | self.assertIsNone(node.meta["spec"].mem_id)
|
726 | 727 | self.assertEqual(constants, 2)
|
| 728 | + |
| 729 | + def test_none_output(self) -> None: |
| 730 | + class Net(nn.Module): |
| 731 | + def __init__(self): |
| 732 | + super().__init__() |
| 733 | + self.conv1 = nn.Conv2d(6, 6, 5) |
| 734 | + self.linear = nn.Linear(6, 2) |
| 735 | + |
| 736 | + def forward(self, x): |
| 737 | + return self.linear(self.conv1(x).flatten(1)) |
| 738 | + |
| 739 | + class TrainingNet(nn.Module): |
| 740 | + def __init__(self, net): |
| 741 | + super().__init__() |
| 742 | + self.net = net |
| 743 | + self.loss = nn.CrossEntropyLoss() |
| 744 | + |
| 745 | + def forward(self, input, label): |
| 746 | + pred = self.net(input) |
| 747 | + return self.loss(pred, label) |
| 748 | + |
| 749 | + net = TrainingNet(Net()) |
| 750 | + inputs = (torch.randn(1, 6, 5, 5), torch.ones(1, dtype=torch.int64)) |
| 751 | + |
| 752 | + ep = export(net, inputs) |
| 753 | + ep = _export_forward_backward(ep) |
| 754 | + ep = to_edge(ep) |
| 755 | + ep = ep.to_executorch() |
| 756 | + |
| 757 | + ep.dump_executorch_program(True) |
| 758 | + |
| 759 | + # 155 just so happens to be the index of the user_grad output arg of |
| 760 | + # convolution_backward.out. This is fairly fragile. |
| 761 | + # Check that the None output is not memory planned. |
| 762 | + self.assertEqual( |
| 763 | + ep.executorch_program.execution_plan[0] |
| 764 | + .values[155] |
| 765 | + .val.data_buffer_idx, # pyright: ignore |
| 766 | + 0, |
| 767 | + ) |
| 768 | + self.assertEqual( |
| 769 | + ep.executorch_program.execution_plan[0] |
| 770 | + .values[155] |
| 771 | + .val.allocation_info, # pyright: ignore |
| 772 | + None, |
| 773 | + ) |
0 commit comments