Skip to content

Commit 6ec5342 — TrainingModule

Differential Revision: D62140852. Pull Request resolved: #5077.
Parent commit: 1a4cf51.

File tree

8 files changed: +425 / -1 lines

extension/module/module.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace extension {
2222
/**
2323
* A facade class for loading programs and executing methods within them.
2424
*/
25-
class Module final {
25+
class Module {
2626
public:
2727
/**
2828
* Enum to define loading behavior.
@@ -337,6 +337,8 @@ class Module final {
337337
std::unique_ptr<runtime::MemoryAllocator> memory_allocator_;
338338
std::unique_ptr<runtime::MemoryAllocator> temp_allocator_;
339339
std::unique_ptr<runtime::EventTracer> event_tracer_;
340+
341+
protected:
340342
std::unordered_map<std::string, MethodHolder> methods_;
341343

342344
friend class ExecuTorchJni;

extension/training/module/TARGETS

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

# Instantiate the shared targets (the training_module libraries) declared in
# the sibling targets.bzl.
define_common_targets()

extension/training/module/targets.bzl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    # Declare one library per ATen mode: "training_module_aten" (ATen types)
    # and "training_module" (portable types).
    for aten_suffix in ("_aten", ""):
        runtime.cxx_library(
            name = "training_module" + aten_suffix,
            srcs = [
                "training_module.cpp",
            ],
            exported_headers = [
                "training_module.h",
            ],
            visibility = [
                "@EXECUTORCH_CLIENTS",
            ],
            exported_deps = [
                "//executorch/extension/module:module" + aten_suffix,
                "//executorch/runtime/core:evalue" + aten_suffix,
            ],
        )
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

# is_fbcode = True enables the test target, which references model files
# produced by fbcode-only export tooling (see targets.bzl).
define_common_targets(is_fbcode = True)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets(is_fbcode = False):
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    # TODO(dbort): Find a way to make these run for ANDROID/APPLE in xplat. The
    # android and ios test determinators don't like the reference to the model
    # file in fbcode. See https://fburl.com/9esapdmd
    if runtime.is_oss or not is_fbcode:
        return

    # The tests use these vars to find the program files to load. This uses
    # fbcode target paths because the authoring/export tools intentionally
    # don't work in xplat (since they're host-only tools).
    test_env = {
        "ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
        "ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])",
    }

    runtime.cxx_test(
        name = "training_module_test",
        srcs = [
            "training_module_test.cpp",
        ],
        deps = [
            "//executorch/extension/training/module:training_module",
            "//executorch/extension/data_loader:file_data_loader",
            "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
            "//executorch/kernels/portable:generated_lib",
        ],
        env = test_env,
    )
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/data_loader/file_data_loader.h>
10+
#include <executorch/extension/training/module/training_module.h>
11+
12+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
13+
#include <executorch/runtime/platform/runtime.h>
14+
#include <gtest/gtest.h>
15+
16+
// @lint-ignore-every CLANGTIDY facebook-hte-CArray
17+
18+
using namespace ::testing;
19+
using exec_aten::ScalarType;
20+
using exec_aten::Tensor;
21+
using torch::executor::Error;
22+
using torch::executor::Span;
23+
using torch::executor::testing::TensorFactory;
24+
25+
// Shared test fixture: initializes the ExecuTorch runtime before each test.
class TrainingModuleTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // Must run before any program is loaded or executed.
    torch::executor::runtime_init();
  }
};
31+
32+
// Loads the joint-graph ModuleSimpleTrain program, runs one forward/backward
// pass, and checks the shapes of the reported gradients and parameters.
TEST_F(TrainingModuleTest, JointGraphTest) {
  // Create a loader for the serialized ModuleSimpleTrain program.
  const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
  executorch::runtime::Result<torch::executor::util::FileDataLoader>
      loader_res = torch::executor::util::FileDataLoader::from(path);
  ASSERT_EQ(loader_res.error(), Error::Ok);
  auto loader = std::make_unique<torch::executor::util::FileDataLoader>(
      std::move(loader_res.get()));

  auto mod = executorch::extension::training::TrainingModule(std::move(loader));

  TensorFactory<ScalarType::Float> tf;
  Tensor input = tf.make({3}, {1.0, 1.0, 1.0});
  Tensor label = tf.make({3}, {1.0, 0.0, 0.0});

  std::vector<executorch::runtime::EValue> inputs;
  inputs.push_back(input);
  inputs.push_back(label);

  // One user-facing output (the loss); gradients/params are split off
  // internally by TrainingModule.
  auto res = mod.execute_forward_backward("forward", inputs);
  ASSERT_EQ(res.error(), Error::Ok);
  ASSERT_EQ(res.get().size(), 1);

  // Test Gradients
  auto grad_res = mod.named_gradients("forward");
  ASSERT_EQ(grad_res.error(), Error::Ok);
  auto& grad = grad_res.get();
  ASSERT_EQ(grad.size(), 2);
  ASSERT_NE(grad.find("linear.weight"), grad.end());
  ASSERT_NE(grad.find("linear.bias"), grad.end());

  ASSERT_EQ(grad.find("linear.weight")->second.sizes()[0], 3);
  ASSERT_EQ(grad.find("linear.weight")->second.sizes()[1], 3);
  ASSERT_EQ(grad.find("linear.weight")->second.dim(), 2);
  ASSERT_EQ(grad.find("linear.bias")->second.sizes()[0], 3);
  ASSERT_EQ(grad.find("linear.bias")->second.dim(), 1);

  // Test Parameters
  auto param_res = mod.named_parameters("forward");
  ASSERT_EQ(param_res.error(), Error::Ok);
  // Fix: bind the parameter map from param_res (this previously aliased
  // grad_res.get(), so the assertions below re-tested the gradients).
  auto& param = param_res.get();
  ASSERT_EQ(param.size(), 2);
  // Fix: compare against param.end(), not grad.end() — comparing iterators
  // from different containers is undefined behavior.
  ASSERT_NE(param.find("linear.weight"), param.end());
  ASSERT_NE(param.find("linear.bias"), param.end());

  ASSERT_EQ(param.find("linear.weight")->second.sizes()[0], 3);
  ASSERT_EQ(param.find("linear.weight")->second.sizes()[1], 3);
  ASSERT_EQ(param.find("linear.weight")->second.dim(), 2);
  ASSERT_EQ(param.find("linear.bias")->second.sizes()[0], 3);
  ASSERT_EQ(param.find("linear.bias")->second.dim(), 1);
}
83+
84+
// A plain inference-only program (ModuleAdd) must be rejected by the
// training entry point, since it carries no gradient/parameter metadata.
TEST_F(TrainingModuleTest, NonTrainingModuleTest) {
  // Build a data loader over the serialized ModuleAdd program.
  const char* model_path = std::getenv("ET_MODULE_ADD_PATH");
  executorch::runtime::Result<torch::executor::util::FileDataLoader>
      loader_result = torch::executor::util::FileDataLoader::from(model_path);
  ASSERT_EQ(loader_result.error(), Error::Ok);
  auto data_loader = std::make_unique<torch::executor::util::FileDataLoader>(
      std::move(loader_result.get()));

  auto training_mod =
      executorch::extension::training::TrainingModule(std::move(data_loader));

  TensorFactory<ScalarType::Float> tf;
  Tensor lhs = tf.make({2, 2}, {1.0, 1.0, 1.0, 1.0});
  Tensor rhs = tf.make({2, 2}, {1.0, 0.0, 0.0, 0.0});

  std::vector<executorch::runtime::EValue> method_inputs;
  method_inputs.push_back(lhs);
  method_inputs.push_back(rhs);

  // Non-training module should fail to execute forward/backward as it can't
  // find the gradients or params.
  auto result = training_mod.execute_forward_backward("forward", method_inputs);
  ASSERT_EQ(result.error(), Error::InvalidArgument);
}
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/extension/training/module/training_module.h>
10+
11+
namespace executorch {
12+
namespace extension {
13+
namespace training {
14+
15+
namespace {
// Name prefixes of the synthetic helper methods that accompany each
// trainable method in the exported program (see usages below: each is
// concatenated with the user method name and executed for its metadata).
// Declared const: they are file-scope lookup keys and must never mutate.
const std::string gradients_method_prefix = "__et_training_gradients_index_";
const std::string parameters_method_prefix = "__et_training_parameters_index_";
const std::string fqn_method_prefix = "__et_training_fqn_";
} // namespace
20+
21+
// Executes the joint forward/backward graph for `method_name`.
//
// The exported method's outputs are laid out as:
//   [0, grad_start)            user-facing outputs
//   [grad_start, param_start)  gradients
//   [param_start, ...)         parameters
// where grad_start/param_start come from the synthetic index methods.
// Returns only the user-facing outputs; gradients are cached in
// method_named_gradients_ keyed by parameter fqn.
runtime::Result<std::vector<runtime::EValue>>
TrainingModule::execute_forward_backward(
    const std::string& method_name,
    const std::vector<runtime::EValue>& input) {
  // Find where the user outputs end.
  const std::string gradients_method_name =
      gradients_method_prefix + method_name;
  auto res = executorch::extension::Module::execute(gradients_method_name);
  if (!res.ok()) {
    return res.error();
  }
  uint64_t grad_start = res.get()[0].toInt();

  const std::string parameters_method_name =
      parameters_method_prefix + method_name;
  // get params start.
  auto param_res =
      executorch::extension::Module::execute(parameters_method_name);
  if (!param_res.ok()) {
    return param_res.error();
  }

  uint64_t param_start = param_res.get()[0].toInt();

  // Execute the forward and backward pass.

  auto outputs = torch::executor::Module::execute(method_name, input);
  if (!outputs.ok()) {
    return outputs.error();
  }

  // Extract the user outputs.
  std::vector<runtime::EValue> user_outputs;
  user_outputs.reserve(grad_start);
  for (size_t i = 0; i < grad_start; ++i) {
    user_outputs.push_back(outputs.get().at(i));
  }

  // Extract and store the gradients (first run for this method only).
  if (method_named_gradients_.find(method_name) ==
      method_named_gradients_.end()) {
    method_named_gradients_.insert({method_name, {}});

    auto& gradients_map = method_named_gradients_.at(method_name);
    // Get names.
    const std::string fqn_method_name = fqn_method_prefix + method_name;
    auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
    if (!fqn_res.ok()) {
      return fqn_res.error();
    }
    const auto& fqn_list = fqn_res.get();

    // Only have to initialize the dict once because the tensors in the dict and
    // the tensors in the method alias the same TensorImpl, so updating one will
    // update the other.
    // NOTE(review): assumes fqn_list is ordered to match the gradient outputs
    // starting at grad_start — confirm against the training exporter.
    size_t name_index = 0;
    for (size_t grad_index = grad_start; grad_index < param_start;
         ++grad_index, ++name_index) {
      exec_aten::string_view fqn = fqn_list.at(name_index).toString();
      gradients_map.insert({fqn, outputs.get().at(grad_index).toTensor()});
    }
  }

  return user_outputs;
}
86+
87+
// Builds a map from parameter fully-qualified name (fqn) to parameter tensor
// for `method_name`, by pairing the fqn list from the synthetic fqn method
// with the method outputs at indices >= param_start.
runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>
TrainingModule::named_parameters(const std::string& method_name) {
  std::map<exec_aten::string_view, exec_aten::Tensor> named_parameters;
  const std::string fqn_method_name = fqn_method_prefix + method_name;
  const std::string parameters_method_name =
      parameters_method_prefix + method_name;

  // get names.
  auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
  if (!fqn_res.ok()) {
    return fqn_res.error();
  }
  const auto& fqn_list = fqn_res.get();

  // get params start.
  auto param_res =
      executorch::extension::Module::execute(parameters_method_name);
  if (!param_res.ok()) {
    return param_res.error();
  }

  uint64_t param_start = param_res.get()[0].toInt();

  // NOTE(review): assumes `method_name` is already present in methods_
  // (e.g. loaded by a prior execute_forward_backward); otherwise .at()
  // throws. Consider loading the method explicitly here — confirm intent.
  auto& method = methods_.at(method_name).method;

  // create dict
  // NOTE(review): assumes fqn_list order matches the parameter outputs
  // starting at param_start (same assumption execute_forward_backward makes
  // for gradients) — confirm against the training exporter.
  size_t name_index = 0;
  for (size_t param_index = param_start; param_index < method->outputs_size();
       ++param_index, ++name_index) {
    exec_aten::string_view fqn = fqn_list.at(name_index).toString();
    exec_aten::Tensor param = method->get_output(param_index).toTensor();
    named_parameters.insert({fqn, param});
  }
  return named_parameters;
}
122+
123+
// Returns the gradient map cached by a prior execute_forward_backward() call
// for `method_name`. Fails with InvalidArgument if forward/backward has not
// been run for this method (that is the only place the cache is populated).
runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>
TrainingModule::named_gradients(const std::string& method_name) {
  if (method_named_gradients_.find(method_name) ==
      method_named_gradients_.end()) {
    ET_LOG(Error, "No gradients found for method %s", method_name.c_str());
    return executorch::runtime::Error::InvalidArgument;
  }
  return method_named_gradients_.at(method_name);
}
132+
133+
} // namespace training
134+
} // namespace extension
135+
} // namespace executorch

0 commit comments

Comments
 (0)