Skip to content

Commit 55c76a4

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
XOR model export in CI (#6756)
Summary: Add model to CI since it got broken recently Reviewed By: malfet, georgehong Differential Revision: D65768331
1 parent 047fd37 commit 55c76a4

File tree

4 files changed

+50
-13
lines changed

4 files changed

+50
-13
lines changed

extension/training/examples/XOR/export_model.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,21 @@
1818
from torch.export.experimental import _export_forward_backward
1919

2020

21+
def _export_model():
22+
net = TrainingNet(Net())
23+
x = torch.randn(1, 2)
24+
25+
# Captures the forward graph. The graph will look similar to the model definition now.
26+
# Will move to export_for_training soon which is the api planned to be supported in the long term.
27+
ep = export(net, (x, torch.ones(1, dtype=torch.int64)))
28+
# Captures the backward graph. The exported_program now contains the joint forward and backward graph.
29+
ep = _export_forward_backward(ep)
30+
# Lower the graph to edge dialect.
31+
ep = to_edge(ep)
32+
# Lower the graph to executorch.
33+
ep = ep.to_executorch()
34+
35+
2136
def main() -> None:
2237
torch.manual_seed(0)
2338
parser = argparse.ArgumentParser(
@@ -32,18 +47,7 @@ def main() -> None:
3247
)
3348
args = parser.parse_args()
3449

35-
net = TrainingNet(Net())
36-
x = torch.randn(1, 2)
37-
38-
# Captures the forward graph. The graph will look similar to the model definition now.
39-
# Will move to export_for_training soon which is the api planned to be supported in the long term.
40-
ep = export(net, (x, torch.ones(1, dtype=torch.int64)))
41-
# Captures the backward graph. The exported_program now contains the joint forward and backward graph.
42-
ep = _export_forward_backward(ep)
43-
# Lower the graph to edge dialect.
44-
ep = to_edge(ep)
45-
# Lower the graph to executorch.
46-
ep = ep.to_executorch()
50+
ep = _export_model()
4751

4852
# Write out the .pte file.
4953
os.makedirs(args.outdir, exist_ok=True)

extension/training/examples/XOR/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def define_common_targets():
3434
runtime.python_library(
3535
name = "export_model_lib",
3636
srcs = ["export_model.py"],
37-
visibility = [],
37+
visibility = ["//executorch/extension/training/examples/XOR/..."],
3838
deps = [
3939
":model",
4040
"//caffe2:torch",
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_test(
6+
name = "test",
7+
srcs = ["test_export.py"],
8+
visibility = ["//executorch/extension/training/examples/XOR/test/..."],
9+
deps = [
10+
"//caffe2:torch",
11+
"//executorch/exir:lib",
12+
"//executorch/extension/training:lib",
13+
"//executorch/extension/training/examples/XOR:export_model_lib",
14+
],
15+
)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import unittest
10+
11+
from executorch.extension.training.examples.XOR.export_model import _export_model
12+
13+
14+
class TestXORExport(unittest.TestCase):
15+
def test(self):
16+
_ = _export_model()
17+
# Expect that we reach this far without an exception being thrown.
18+
self.assertTrue(True)

0 commit comments

Comments
 (0)