
Commit 7c661d7

JacobSzwejbka authored and facebook-github-bot committed
Add XOR Model Example (#5397)
Summary:
Pull Request resolved: #5397

Add a basic model/use-case that showcases the full training loop, from model definition to optimizer.step().

Reviewed By: iseeyuan

Differential Revision: D62771102

fbshipit-source-id: 48dc01f680085e3192aa4b91396f80ca646d1640
1 parent 8a8e876 commit 7c661d7

File tree: 6 files changed, +278 -0 lines changed

extension/training/examples/XOR/TARGETS
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load("@fbcode_macros//build_defs:python_binary.bzl", "python_binary")
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
load(":targets.bzl", "define_common_targets")

oncall("executorch")

define_common_targets()

python_library(
    name = "model",
    srcs = ["model.py"],
    visibility = [],  # Private
    deps = [
        "//caffe2:torch",
    ],
)

python_library(
    name = "export_model_lib",
    srcs = ["export_model_lib.py"],
    visibility = [],
    deps = [
        ":model",
        "//caffe2:torch",
        "//executorch/exir:lib",
    ],
)

python_binary(
    name = "export_model",
    main_function = ".export_model.main",
    main_src = "export_model.py",
    deps = [
        ":export_model_lib",
        "//caffe2:torch",
    ],
)
extension/training/examples/XOR/export_model.py
Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import argparse

import torch

from .export_model_lib import export_model


def main() -> None:
    torch.manual_seed(0)
    parser = argparse.ArgumentParser(
        prog="export_model",
        description="Exports an nn.Module model to ExecuTorch .pte files",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="Path to the directory to write xor.pte files to",
    )
    args = parser.parse_args()
    export_model(args.outdir)


if __name__ == "__main__":
    main()
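For reference, the exporter can also be invoked directly from Python rather than through the export_model binary. A minimal sketch, assuming the executorch Python package is importable and the example lives at the package path implied by the imports in export_model_lib.py below:

import os

# Assumed module path, inferred from this example's own imports; it may
# differ depending on how executorch is installed in your environment.
from executorch.extension.training.examples.XOR.export_model_lib import export_model

export_model("/tmp/xor_example")  # writes /tmp/xor_example/xor.pte
print(os.path.getsize("/tmp/xor_example/xor.pte"), "bytes written")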
extension/training/examples/XOR/export_model_lib.py
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import os

import torch
from executorch.exir import to_edge
from executorch.extension.training.examples.XOR.model import TrainingNet
from torch.export._trace import _export
from torch.export.experimental import _export_forward_backward

from .model import Net


def export_model(outdir):
    net = TrainingNet(Net())
    x = torch.randn(1, 2)

    # Capture the forward graph. At this point the graph still closely mirrors
    # the model definition. This will move to export_for_training soon, which
    # is the API planned for long-term support.
    ep = _export(net, (x, torch.ones(1, dtype=torch.int64)), pre_dispatch=True)
    # Capture the backward graph. The exported program now contains the joint
    # forward and backward graph.
    ep = _export_forward_backward(ep)
    # Lower the graph to the edge dialect.
    ep = to_edge(ep)
    # Lower the graph to ExecuTorch.
    ep = ep.to_executorch()

    # Write out the .pte file.
    os.makedirs(outdir, exist_ok=True)
    outfile = os.path.join(outdir, "xor.pte")
    with open(outfile, "wb") as fp:
        fp.write(ep.buffer)
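Since the exported method keeps TrainingNet's (loss, prediction) output convention, a quick eager-mode sanity check of the wrapped model can help before lowering. A minimal sketch; the import path is assumed from the absolute import used above, and the sample input is arbitrary:

import torch

# Assumed import path, matching the absolute import in export_model_lib.py.
from executorch.extension.training.examples.XOR.model import Net, TrainingNet

net = TrainingNet(Net())
x = torch.randn(1, 2)  # one random 2-d input
label = torch.ones(1, dtype=torch.int64)  # arbitrary class label
loss, pred = net(x, label)  # loss comes first, then the detached argmax prediction
print(loss.item(), pred.item())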
extension/training/examples/XOR/model.py
Lines changed: 34 additions & 0 deletions
@@ -0,0 +1,34 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch.nn as nn
from torch.nn import functional as F


# Basic Net for XOR
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 10)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear2(F.sigmoid(self.linear(x)))


# On-device training requires the loss to be embedded in the model (and to be
# the first output). We wrap the original model here and add the loss
# calculation. This will be the model we export.
class TrainingNet(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net = net
        self.loss = nn.CrossEntropyLoss()

    def forward(self, input, label):
        pred = self.net(input)
        return self.loss(pred, label), pred.detach().argmax(dim=1)
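For intuition, the on-device loop in train.cpp below mirrors the following eager-mode PyTorch loop (a sketch, not part of the commit; the hyperparameters are copied from train.cpp: SGD with lr 0.1 and 5000 steps sampled from the four XOR points):

import random

import torch
import torch.nn as nn
from torch.nn import functional as F


class Net(nn.Module):  # identical to the Net above, inlined to keep the sketch self-contained
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 10)
        self.linear2 = nn.Linear(10, 2)

    def forward(self, x):
        return self.linear2(F.sigmoid(self.linear(x)))


net = Net()
loss_fn = nn.CrossEntropyLoss()  # the same loss TrainingNet embeds
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)  # matches SGDOptions{0.1}

# The four XOR points, mirroring the data set built in train.cpp.
data = [
    (torch.tensor([[1.0, 1.0]]), torch.tensor([0])),
    (torch.tensor([[0.0, 0.0]]), torch.tensor([0])),
    (torch.tensor([[1.0, 0.0]]), torch.tensor([1])),
    (torch.tensor([[0.0, 1.0]]), torch.tensor([1])),
]

for step in range(5000):
    x, label = random.choice(data)
    loss = loss_fn(net(x), label)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 500 == 0:
        print(f"step {step} loss {loss.item():.4f}")

On-device, the joint graph exported above stands in for loss.backward(), and optimizer.step() is driven by the gradients returned from execute_forward_backward.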
extension/training/examples/XOR/targets.bzl
Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    runtime.cxx_binary(
        name = "train_xor",
        srcs = ["train.cpp"],
        deps = [
            "//executorch/extension/training/module:training_module",
            "//executorch/extension/tensor:tensor",
            "//executorch/extension/training/optimizer:sgd",
            "//executorch/runtime/executor:program",
            "//executorch/extension/data_loader:file_data_loader",
            "//executorch/kernels/portable:generated_lib",
        ],
        external_deps = ["gflags"],
        define_static_target = True,
    )
extension/training/examples/XOR/train.cpp
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/tensor/tensor.h>
#include <executorch/extension/training/module/training_module.h>
#include <executorch/extension/training/optimizer/sgd.h>
#include <gflags/gflags.h>
#include <random>

#pragma clang diagnostic ignored \
    "-Wbraced-scalar-init" // {0} below upsets clang.

using executorch::extension::FileDataLoader;
using executorch::extension::training::optimizer::SGD;
using executorch::extension::training::optimizer::SGDOptions;
using executorch::runtime::Error;
using executorch::runtime::Result;

DEFINE_string(model_path, "xor.pte", "Model serialized in flatbuffer format.");

int main(int argc, char** argv) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (argc != 1) {
    std::string msg = "Extra commandline args: ";
    for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) {
      msg += argv[i];
    }
    ET_LOG(Error, "%s", msg.c_str());
    return 1;
  }

  // Load the model file.
  executorch::runtime::Result<executorch::extension::FileDataLoader>
      loader_res =
          executorch::extension::FileDataLoader::from(FLAGS_model_path.c_str());
  if (loader_res.error() != Error::Ok) {
    ET_LOG(Error, "Failed to open model file: %s", FLAGS_model_path.c_str());
    return 1;
  }
  auto loader = std::make_unique<executorch::extension::FileDataLoader>(
      std::move(loader_res.get()));

  auto mod = executorch::extension::training::TrainingModule(std::move(loader));

  // Create full data set of inputs and labels.
  std::vector<std::pair<
      executorch::extension::TensorPtr,
      executorch::extension::TensorPtr>>
      data_set;
  data_set.push_back( // XOR(1, 1) = 0
      {executorch::extension::make_tensor_ptr<float>({1, 2}, {1, 1}),
       executorch::extension::make_tensor_ptr<long>({1}, {0})});
  data_set.push_back( // XOR(0, 0) = 0
      {executorch::extension::make_tensor_ptr<float>({1, 2}, {0, 0}),
       executorch::extension::make_tensor_ptr<long>({1}, {0})});
  data_set.push_back( // XOR(1, 0) = 1
      {executorch::extension::make_tensor_ptr<float>({1, 2}, {1, 0}),
       executorch::extension::make_tensor_ptr<long>({1}, {1})});
  data_set.push_back( // XOR(0, 1) = 1
      {executorch::extension::make_tensor_ptr<float>({1, 2}, {0, 1}),
       executorch::extension::make_tensor_ptr<long>({1}, {1})});

  // Create the optimizer.
  // Get the params and names.
  auto param_res = mod.named_parameters("forward");
  if (param_res.error() != Error::Ok) {
    ET_LOG(Error, "Failed to get named parameters");
    return 1;
  }

  SGDOptions options{0.1};
  SGD optimizer(param_res.get(), options);

  // Randomness to sample the data set.
  std::default_random_engine URBG{std::random_device{}()};
  std::uniform_int_distribution<int> dist{
      0, static_cast<int>(data_set.size()) - 1};

  // Train the model.
  size_t num_epochs = 5000;
  for (int i = 0; i < num_epochs; i++) {
    int index = dist(URBG);
    auto& data = data_set[index];
    const auto& results = mod.execute_forward_backward(
        "forward", {*data.first.get(), *data.second.get()});
    if (results.error() != Error::Ok) {
      ET_LOG(Error, "Failed to execute forward_backward");
      return 1;
    }
    if (i % 500 == 0 || i == num_epochs - 1) {
      ET_LOG(
          Info,
          "Step %d, Loss %f, Input [%.0f, %.0f], Prediction %ld, Label %ld",
          i,
          results.get()[0].toTensor().const_data_ptr<float>()[0],
          data.first->const_data_ptr<float>()[0],
          data.first->const_data_ptr<float>()[1],
          results.get()[1].toTensor().const_data_ptr<int64_t>()[0],
          data.second->const_data_ptr<int64_t>()[0]);
    }
    optimizer.step(mod.named_gradients("forward").get());
  }
}
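With xor.pte exported and the train_xor binary built from targets.bzl above, the trainer can presumably be run as, e.g., ./train_xor --model_path=/path/to/xor.pte; per the gflags definition above, the model_path flag defaults to xor.pte in the working directory.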
