
Commit d44877b

David Lin authored and facebook-github-bot committed
Added optimizer implementation (#3699)
Summary:
Pull Request resolved: #3699

This adds the optimizer logic, reusing much of the logic from [LiteInterpreter](https://fburl.com/code/t5dqeyje). The main differences are:

1. SGDParamGroup takes a Span<const char*> and a Span<Tensor> which together represent named parameters. Unlike LiteInterpreter or core PyTorch, portable tensors don't use the autograd framework, and we won't be supporting it either; instead, we're likely to use the backwards graph to calculate the gradients of the parameters, so we need a way to map each gradient to its parameter. We expect the two spans to be the same size, with a given parameter at the same index in both spans.
2. SGD::step() takes a Span<const char*> and a Span<Tensor> which represent the named gradients, and uses them to match each gradient to the appropriate parameter. As above, we expect the spans to be the same size, with a gradient at the same index as its parameter name.
3. It uses the out-variant operations rather than the in-place or functional variants, since those are already implemented. I *believe* that since we're only using clone, add (same-sized tensors), and mul_scalar, there is no harm in overwriting the data.
4. For the momentum buffer, memory is allocated for the underlying data and the TensorImpl. This gets cleaned up when the SGD destructor is called.

Reviewed By: iseeyuan

Differential Revision: D57216865

fbshipit-source-id: 5ab49b6f584debc15976982a2e9eb964515e5c54
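
To make the named-span convention concrete, here is a minimal usage sketch of the API this commit adds. It is an illustration, not code from the diff: the parameter/gradient names and tensors are placeholders, and it assumes SGDOptions can be constructed from a learning rate (mirroring the LiteInterpreter/core PyTorch options).

#include <executorch/extension/training/optimizer/sgd.h>

using exec_aten::Tensor;
using torch::executor::Error;
using torch::executor::Span;
using torch::executor::training::optimizer::SGD;
using torch::executor::training::optimizer::SGDOptions;

// Hypothetical example: `weight`/`bias` are existing parameter tensors and
// `weight_grad`/`bias_grad` are gradients produced by the backwards graph.
Error run_one_step(
    Tensor weight,
    Tensor bias,
    Tensor weight_grad,
    Tensor bias_grad) {
  // Index i of param_names must correspond to index i of param_data.
  const char* param_name_storage[] = {"linear.weight", "linear.bias"};
  Tensor param_data_storage[] = {weight, bias};
  Span<const char*> param_names(param_name_storage, 2);
  Span<Tensor> param_data(param_data_storage, 2);

  // SGDOptions(lr) is assumed here; see sgd.h for the full set of options.
  SGD optimizer(param_names, param_data, SGDOptions(/*lr=*/0.1));

  // Gradients are matched to parameters by name, so their order may differ
  // from the parameter spans, but the two gradient spans must line up.
  const char* grad_name_storage[] = {"linear.bias", "linear.weight"};
  Tensor grad_data_storage[] = {bias_grad, weight_grad};
  Span<const char*> gradient_names(grad_name_storage, 2);
  Span<Tensor> gradient_data(grad_data_storage, 2);

  return optimizer.step(gradient_names, gradient_data);
}
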
1 parent 1343224 commit d44877b

File tree

5 files changed: +390 -15 lines changed

extension/training/optimizer/sgd.cpp

Lines changed: 176 additions & 0 deletions
@@ -0,0 +1,176 @@
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/extension/training/optimizer/sgd.h>
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator

#include <executorch/runtime/core/error.h>
#include <executorch/runtime/kernel/kernel_runtime_context.h>

namespace torch {
namespace executor {
namespace training {
namespace optimizer {

bool SGDParamGroup::has_options() const {
  return options_ != nullptr;
}

SGDOptions& SGDParamGroup::options() {
  return *options_.get();
}

const SGDOptions& SGDParamGroup::options() const {
  return *options_.get();
}

void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
  options_ = std::move(options);
}

Span<const char*> SGDParamGroup::param_names() {
  return param_names_;
}

const Span<const char*> SGDParamGroup::param_names() const {
  return param_names_;
}

Span<Tensor> SGDParamGroup::param_data() {
  return param_data_;
}

const Span<Tensor> SGDParamGroup::param_data() const {
  return param_data_;
}

void SGD::add_param_group(const SGDParamGroup& param_group) {
  SGDParamGroup param_group_(
      param_group.param_names(), param_group.param_data());
  if (!param_group.has_options()) {
    param_group_.set_options(defaults_->clone());
  } else {
    param_group_.set_options(param_group.options().clone());
  }
  param_groups_.emplace_back(std::move(param_group_));
}

Error SGD::step(Span<const char*> gradient_names, Span<Tensor> gradient_data) {
  // check that the number of gradient names matches the number of gradients
  ET_CHECK_OR_RETURN_ERROR(
      gradient_names.size() == gradient_data.size(),
      InvalidState,
      "Gradient names and gradients must have the same length.");

  RuntimeContext context;
  for (auto& group : param_groups_) {
    auto& options = static_cast<SGDOptions&>(group.options());
    auto weight_decay = options.weight_decay();
    auto momentum = options.momentum();
    auto dampening = options.dampening();
    auto nesterov = options.nesterov();

    for (int i = 0; i < group.param_names().size(); i++) {
      for (int j = 0; j < gradient_names.size(); j++) {
        // if param name and gradient name match, run the optimizer step
        if (strcmp(group.param_names()[i], gradient_names[j]) == 0) {
          auto d_p = gradient_data[j];
          auto p = group.param_data()[i];
          if (weight_decay != 0) {
            // uses weight_decay specified and adds it to the gradient
            torch::executor::aten::add_outf(context, d_p, p, weight_decay, d_p);
            if (context.failure_state() != Error::Ok) {
              return context.failure_state();
            }
          }
          if (momentum != 0) {
            Tensor buf(nullptr);
            auto param_state = state_.find(p.unsafeGetTensorImpl());
            // look for the momentum buffer for the given parameter. this is the
            // momentum as of the previous epoch
            if (param_state == state_.end()) {
              // create a new momentum buffer if it doesn't exist. this memory
              // needs to be freed when the optimizer is destroyed
              void* buf_ptr = malloc(d_p.nbytes());

#ifdef USE_ATEN_LIB
              std::vector<int64_t> sizes(
                  d_p.sizes().begin(), d_p.sizes().end());
              buf = torch::from_blob(buf_ptr, sizes, d_p.scalar_type());
#else
              TensorImpl* buf_impl = new TensorImpl(
                  d_p.scalar_type(),
                  d_p.sizes().size(),
                  const_cast<TensorImpl::SizesType*>(d_p.sizes().data()),
                  buf_ptr,
                  const_cast<TensorImpl::DimOrderType*>(
                      d_p.dim_order().data()));
              buf = Tensor(buf_impl);
#endif
              torch::executor::aten::clone_outf(
                  context, d_p, exec_aten::MemoryFormat::Contiguous, buf);
              if (context.failure_state() != Error::Ok) {
                return context.failure_state();
              }

              // save the state of the momentum buffer to be reused in later
              // epochs
              auto state = std::make_unique<SGDParamState>(buf);
              state_[p.unsafeGetTensorImpl()] = std::move(state);
            } else {
              buf = static_cast<SGDParamState&>(*param_state->second)
                        .momentum_buffer();

              // update the momentum buffer and apply dampening
              torch::executor::aten::mul_outf(context, buf, momentum, buf);
              if (context.failure_state() != Error::Ok) {
                return context.failure_state();
              }
              torch::executor::aten::add_outf(
                  context, buf, d_p, 1 - dampening, buf);
              if (context.failure_state() != Error::Ok) {
                return context.failure_state();
              }
            }
            if (nesterov) {
              // apply nesterov momentum
              torch::executor::aten::add_outf(context, d_p, buf, momentum, d_p);
              if (context.failure_state() != Error::Ok) {
                return context.failure_state();
              }
            } else {
              d_p = buf;
            }
          }
          // update the parameter using the gradient and learning rate
          torch::executor::aten::add_outf(
              context, p, d_p, -1 * options.lr(), p);
          if (context.failure_state() != Error::Ok) {
            return context.failure_state();
          }
          break;
        }
      }
    }
  }
  return Error::Ok;
}

SGD::~SGD() {
  for (const auto& state_kv : state_) {
    auto state_tensor = static_cast<SGDParamState&>(*state_kv.second);
    free(state_tensor.momentum_buffer().unsafeGetTensorImpl()->mutable_data());
#ifndef USE_ATEN_LIB
    delete state_tensor.momentum_buffer().unsafeGetTensorImpl();
#endif
  }
}
} // namespace optimizer
} // namespace training
} // namespace executor
} // namespace torch
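
For reference, the loop above is the standard SGD-with-momentum recurrence, expressed with the out-variant kernels (add_outf, mul_outf, clone_outf) so that d_p, buf, and p are overwritten in place. A scalar-form sketch of the same update, with a hypothetical helper name purely for illustration:

// Hypothetical scalar view of one step() update for a single parameter
// element. `buf` is the momentum buffer entry carried across steps;
// `has_buf` is false the first time a parameter is seen.
float sgd_step_scalar(
    float p, // parameter value
    float g, // gradient value
    float& buf, // momentum buffer (state_)
    bool has_buf,
    float lr,
    float weight_decay,
    float momentum,
    float dampening,
    bool nesterov) {
  if (weight_decay != 0) {
    g = g + weight_decay * p; // add_outf(d_p, p, weight_decay, d_p)
  }
  if (momentum != 0) {
    if (!has_buf) {
      buf = g; // first time: clone_outf(d_p, ..., buf)
    } else {
      buf = momentum * buf + (1 - dampening) * g; // mul_outf + add_outf
    }
    g = nesterov ? g + momentum * buf // add_outf(d_p, buf, momentum, d_p)
                 : buf; // d_p = buf
  }
  return p - lr * g; // add_outf(p, d_p, -1 * lr, p)
}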

extension/training/optimizer/sgd.h

Lines changed: 104 additions & 3 deletions
@@ -16,14 +16,21 @@
  */
 #pragma once
 
+#include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/exec_aten/exec_aten.h>
+#include <executorch/runtime/core/span.h>
 #include <memory>
+#include <unordered_map>
+#include <vector>
 
 namespace torch {
 namespace executor {
+namespace training {
 namespace optimizer {
 
 using Tensor = exec_aten::Tensor;
+using TensorImpl = exec_aten::TensorImpl;
+using ScalarType = exec_aten::ScalarType;
 
 /**
  * SGD optimizer state. This keeps track of the state of a given parameter to
@@ -123,16 +130,110 @@ class SGDOptions {
 
 /**
  * SGD optimizer param group. This contains the parameters and
- * the OptimizerOptions associated to it.
+ * the SGDOptions associated to it.
  */
-class SGDParamGroup {};
+class SGDParamGroup {
+ public:
+  // NOTE: In order to store `SGDParamGroup` in a `std::vector`, it has
+  // to be copy-constructible.
+  SGDParamGroup(const SGDParamGroup& param_group)
+      : param_data_(param_group.param_data()),
+        param_names_(param_group.param_names()),
+        options_(
+            param_group.has_options() ? param_group.options().clone()
+                                      : nullptr) {}
+  SGDParamGroup& operator=(const SGDParamGroup& param_group) {
+    this->param_data_ = param_group.param_data();
+    this->param_names_ = param_group.param_names();
+    this->options_ =
+        param_group.has_options() ? param_group.options().clone() : nullptr;
+    return *this;
+  }
+
+  /**
+   * This constructs a SGD param group. We expect that the two spans are of the
+   * same size, and that for a given param data, its index in param_data is the
+   * same as its param name in param_name.
+   *
+   * @param[in] param_names The names of the params for this group.
+   * @param[in] param_data The tensors representing the param data.
+   */
+  /* implicit */ SGDParamGroup(
+      Span<const char*> param_names,
+      Span<Tensor> param_data)
+      : param_data_(std::move(param_data)),
+        param_names_(std::move(param_names)) {}
+  SGDParamGroup(
+      Span<const char*> param_names,
+      Span<Tensor> param_data,
+      std::unique_ptr<SGDOptions> options)
+      : param_data_(std::move(param_data)),
+        param_names_(std::move(param_names)),
+        options_(std::move(options)) {}
+
+  bool has_options() const;
+  SGDOptions& options();
+  const SGDOptions& options() const;
+  void set_options(std::unique_ptr<SGDOptions> options);
+  Span<const char*> param_names();
+  const Span<const char*> param_names() const;
+  Span<Tensor> param_data();
+  const Span<Tensor> param_data() const;
+
+ private:
+  Span<Tensor> param_data_;
+  Span<const char*> param_names_;
+  std::unique_ptr<SGDOptions> options_;
+};
 
 /**
  * SGD optimizer class. This is responsible for performing the optimization
  * step.
  */
-class SGD {};
+class SGD {
+ public:
+  explicit SGD(
+      const std::vector<SGDParamGroup>& param_groups,
+      SGDOptions defaults)
+      : defaults_(std::make_unique<SGDOptions>(defaults)) {
+    for (const auto& param_group : param_groups) {
+      add_param_group(param_group);
+    }
+  }
+
+  explicit SGD(
+      Span<const char*> param_names,
+      Span<Tensor> param_data,
+      SGDOptions defaults)
+      : SGD({SGDParamGroup(std::move(param_names), std::move(param_data))},
+            defaults) {}
+
+  // Adds the given param_group to the optimizer's param_group list.
+  void add_param_group(const SGDParamGroup& param_group);
+
+  ~SGD();
+
+  /**
+   * Performs the optimization step.
+   *
+   * The two spans must be of the same size. It is expected that the gradient in
+   * 'gradient_data' at index 'i' represents the gradient calculated in the loss
+   * function for the parameter with the name in 'gradient_names' at index 'i'.
+   *
+   * @param[in] gradient_names The names of the params that matches the gradient
+   * in 'gradient_data' at the same index.
+   * @param[in] gradient_data The gradient tensors to be used for optimization
+   * step.
+   */
+  Error step(Span<const char*> gradient_names, Span<Tensor> gradient_data);
+
+ private:
+  std::vector<SGDParamGroup> param_groups_;
+  std::unordered_map<void*, std::unique_ptr<SGDParamState>> state_;
+  std::unique_ptr<SGDOptions> defaults_;
+};
 
 } // namespace optimizer
+} // namespace training
 } // namespace executor
 } // namespace torch
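
As a usage note, the std::vector<SGDParamGroup> constructor above allows per-group hyperparameters: groups constructed without options fall back to the optimizer-wide defaults (see add_param_group in sgd.cpp). A hedged sketch, again assuming a hypothetical SGDOptions(lr) constructor and pre-existing name/tensor spans of matching sizes:

#include <executorch/extension/training/optimizer/sgd.h>

#include <memory>
#include <vector>

using exec_aten::Tensor;
using torch::executor::Error;
using torch::executor::Span;
using torch::executor::training::optimizer::SGD;
using torch::executor::training::optimizer::SGDOptions;
using torch::executor::training::optimizer::SGDParamGroup;

// Hypothetical example: one group uses the optimizer-wide defaults and a
// second group (e.g. the model head) carries its own, smaller learning rate.
Error step_with_groups(
    Span<const char*> embed_names,
    Span<Tensor> embed_data,
    Span<const char*> head_names,
    Span<Tensor> head_data,
    Span<const char*> gradient_names,
    Span<Tensor> gradient_data) {
  std::vector<SGDParamGroup> groups;
  groups.emplace_back(embed_names, embed_data); // uses the defaults below
  groups.emplace_back(
      head_names,
      head_data,
      std::make_unique<SGDOptions>(/*lr=*/0.01)); // group-specific options

  SGD optimizer(groups, /*defaults=*/SGDOptions(/*lr=*/0.1));
  return optimizer.step(gradient_names, gradient_data);
}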

extension/training/optimizer/targets.bzl

Lines changed: 20 additions & 1 deletion
@@ -10,14 +10,33 @@ def define_common_targets():
     for aten_mode in (True, False):
         aten_suffix = "_aten" if aten_mode else ""
 
+        if aten_mode:
+            kernel_deps = [
+                "//executorch/kernels/aten:generated_lib",
+                "//executorch/kernels/aten:generated_lib_headers",
+                "//executorch/kernels/test:function_header_wrapper_aten",
+            ]
+        else:
+            kernel_deps = [
+                "//executorch/kernels/portable/cpu:op_add",
+                "//executorch/kernels/portable/cpu:op_mul",
+                "//executorch/kernels/portable/cpu:op_clone",
+                "//executorch/kernels/portable:generated_lib_headers",
+                "//executorch/kernels/test:function_header_wrapper_portable",
+            ]
+
         runtime.cxx_library(
             name = "optimizer" + aten_suffix,
+            srcs = [
+                "sgd.cpp",
+            ],
             exported_headers = [
                 "sgd.h",
             ],
             exported_deps = [
+                "//executorch/runtime/kernel:kernel_runtime_context" + aten_suffix,
                 "//executorch/runtime/core/exec_aten:lib" + aten_suffix,
-            ],
+            ] + kernel_deps,
             visibility = [
                 "@EXECUTORCH_CLIENTS",
             ],
