Skip to content

Commit 36abcee

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
XOR model export in CI (#6756)
Summary: Add model to CI since it got broken recently Reviewed By: malfet, georgehong Differential Revision: D65768331
1 parent 71612a6 commit 36abcee

File tree

6 files changed

+56
-16
lines changed

6 files changed

+56
-16
lines changed

extension/training/examples/XOR/export_model.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,21 @@
1818
from torch.export.experimental import _export_forward_backward
1919

2020

21+
def _export_model():
22+
net = TrainingNet(Net())
23+
x = torch.randn(1, 2)
24+
25+
# Captures the forward graph. The graph will look similar to the model definition now.
26+
# Will move to export_for_training soon which is the api planned to be supported in the long term.
27+
ep = export(net, (x, torch.ones(1, dtype=torch.int64)))
28+
# Captures the backward graph. The exported_program now contains the joint forward and backward graph.
29+
ep = _export_forward_backward(ep)
30+
# Lower the graph to edge dialect.
31+
ep = to_edge(ep)
32+
# Lower the graph to executorch.
33+
ep = ep.to_executorch()
34+
35+
2136
def main() -> None:
2237
torch.manual_seed(0)
2338
parser = argparse.ArgumentParser(
@@ -32,18 +47,7 @@ def main() -> None:
3247
)
3348
args = parser.parse_args()
3449

35-
net = TrainingNet(Net())
36-
x = torch.randn(1, 2)
37-
38-
# Captures the forward graph. The graph will look similar to the model definition now.
39-
# Will move to export_for_training soon which is the api planned to be supported in the long term.
40-
ep = export(net, (x, torch.ones(1, dtype=torch.int64)))
41-
# Captures the backward graph. The exported_program now contains the joint forward and backward graph.
42-
ep = _export_forward_backward(ep)
43-
# Lower the graph to edge dialect.
44-
ep = to_edge(ep)
45-
# Lower the graph to executorch.
46-
ep = ep.to_executorch()
50+
ep = _export_model()
4751

4852
# Write out the .pte file.
4953
os.makedirs(args.outdir, exist_ok=True)

extension/training/examples/XOR/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,6 @@ def __init__(self, net):
3131

3232
def forward(self, input, label):
3333
pred = self.net(input)
34-
return self.loss(pred, label), pred.detach().argmax(dim=1)
34+
return self.loss(
35+
pred, label
36+
) # , pred.detach().argmax(dim=1) TODO(jakeszwe) uncomment this return

extension/training/examples/XOR/targets.bzl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def define_common_targets():
3434
runtime.python_library(
3535
name = "export_model_lib",
3636
srcs = ["export_model.py"],
37-
visibility = [],
37+
visibility = ["//executorch/extension/training/examples/XOR/..."],
3838
deps = [
3939
":model",
4040
"//caffe2:torch",
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
2+
3+
oncall("executorch")
4+
5+
runtime.python_test(
6+
name = "test",
7+
srcs = ["test_export.py"],
8+
visibility = ["//executorch/extension/training/examples/XOR/test/..."],
9+
deps = [
10+
"//caffe2:torch",
11+
"//executorch/exir:lib",
12+
"//executorch/extension/training:lib",
13+
"//executorch/extension/training/examples/XOR:export_model_lib",
14+
],
15+
)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-unsafe
8+
9+
import unittest
10+
11+
from executorch.extension.training.examples.XOR.export_model import _export_model
12+
13+
14+
class TestXORExport(unittest.TestCase):
15+
def test(self):
16+
_ = _export_model()
17+
# Expect that we reach this far without an exception being thrown.
18+
self.assertTrue(True)

extension/training/examples/XOR/train.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,13 @@ int main(int argc, char** argv) {
9595
if (i % 500 == 0 || i == num_epochs - 1) {
9696
ET_LOG(
9797
Info,
98-
"Step %d, Loss %f, Input [%.0f, %.0f], Prediction %ld, Label %ld",
98+
"Step %d, Loss %f, Input [%.0f, %.0f], Label %ld",
9999
i,
100100
results.get()[0].toTensor().const_data_ptr<float>()[0],
101101
data.first->const_data_ptr<float>()[0],
102102
data.first->const_data_ptr<float>()[1],
103-
results.get()[1].toTensor().const_data_ptr<int64_t>()[0],
103+
// results.get()[1].toTensor().const_data_ptr<int64_t>()[0],
104+
// TODO(jakeszwe) turn back on non loss output
104105
data.second->const_data_ptr<int64_t>()[0]);
105106
}
106107
optimizer.step(mod.named_gradients("forward").get());

0 commit comments

Comments
 (0)