Skip to content

Commit 263a1c6

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
XOR model export in CI
Summary: Add model to CI since it got broken recently Differential Revision: D65768331
1 parent bec0625 commit 263a1c6

File tree

4 files changed

+54
-13
lines changed

4 files changed

+54
-13
lines changed

extension/training/examples/XOR/export_model.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,19 @@
1717
from torch.export import export
1818
from torch.export.experimental import _export_forward_backward
1919

20+
def _export_model():
21+
net = TrainingNet(Net())
22+
x = torch.randn(1, 2)
23+
24+
# Captures the forward graph. The graph will look similar to the model definition now.
25+
# Will move to export_for_training soon which is the api planned to be supported in the long term.
26+
ep = export(net, (x, torch.ones(1, dtype=torch.int64)))
27+
# Captures the backward graph. The exported_program now contains the joint forward and backward graph.
28+
ep = _export_forward_backward(ep)
29+
# Lower the graph to edge dialect.
30+
ep = to_edge(ep)
31+
# Lower the graph to executorch.
32+
ep = ep.to_executorch()
2033

2134
def main() -> None:
2235
torch.manual_seed(0)
@@ -32,18 +45,7 @@ def main() -> None:
3245
)
3346
args = parser.parse_args()
3447

35-
net = TrainingNet(Net())
36-
x = torch.randn(1, 2)
37-
38-
# Captures the forward graph. The graph will look similar to the model definition now.
39-
# Will move to export_for_training soon which is the api planned to be supported in the long term.
40-
ep = export(net, (x, torch.ones(1, dtype=torch.int64)))
41-
# Captures the backward graph. The exported_program now contains the joint forward and backward graph.
42-
ep = _export_forward_backward(ep)
43-
# Lower the graph to edge dialect.
44-
ep = to_edge(ep)
45-
# Lower the graph to executorch.
46-
ep = ep.to_executorch()
48+
ep = _export_model()
4749

4850
# Write out the .pte file.
4951
os.makedirs(args.outdir, exist_ok=True)

extension/training/examples/XOR/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def define_common_targets():
3434
runtime.python_library(
3535
name = "export_model_lib",
3636
srcs = ["export_model.py"],
37-
visibility = [],
37+
visibility = ["//executorch/extension/training/examples/XOR/..."],
3838
deps = [
3939
":model",
4040
"//caffe2:torch",
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_test(
6+
name = "test",
7+
srcs = ["test_export.py"],
8+
visibility = ["//executorch/extension/training/examples/XOR/test/..."],
9+
deps = [
10+
"//caffe2:torch",
11+
"//executorch/exir:lib",
12+
"//executorch/extension/training:lib",
13+
"//executorch/extension/training/examples/XOR:export_model_lib",
14+
],
15+
)
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
# Copyright (c) Meta Platforms, Inc. and affiliates.
10+
# All rights reserved.
11+
#
12+
# This source code is licensed under the BSD-style license found in the
13+
# LICENSE file in the root directory of this source tree.
14+
15+
import unittest
16+
17+
from executorch.extension.training.examples.XOR.export_model import _export_model
18+
19+
20+
class TestXORExport(unittest.TestCase):
21+
def test(self):
22+
_ = _export_model()
23+
# Expect that we reach this far without an exception being thrown.
24+
self.assertTrue(True)

0 commit comments

Comments
 (0)