Skip to content

Commit bc56a97

Browse files
Add op: convolution_backward
Differential Revision: D62028659. Pull Request resolved: #5032.
1 parent 324864d commit bc56a97

File tree

8 files changed

+528
-2
lines changed

8 files changed

+528
-2
lines changed

kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@
115115

116116
- op: convolution.out
117117

118+
- op: convolution_backward.out
119+
118120
- op: copy.out
119121

120122
- op: cos.out
Lines changed: 312 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,312 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <cstring>
10+
11+
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
12+
#include <executorch/runtime/kernel/kernel_includes.h>
13+
#include <tuple>
14+
15+
namespace torch {
16+
namespace executor {
17+
namespace native {
18+
19+
using Tensor = exec_aten::Tensor;
20+
using ScalarType = exec_aten::ScalarType;
21+
using IntArrayRef = exec_aten::ArrayRef<int64_t>;
22+
using OptIntArrayRef = exec_aten::OptionalArrayRef<int64_t>;
23+
24+
namespace {
25+
26+
/**
 * Validates the arguments of convolution_backward.out.
 *
 * Returns true when all checks pass; on the first failing check it logs a
 * diagnostic and returns false. Only the non-transposed, 2D (4-D weight)
 * case is supported.
 */
bool check_convolution_backward_args(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    ET_UNUSED const OptIntArrayRef bias_sizes_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    ET_UNUSED exec_aten::ArrayRef<bool> output_mask,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  // Only the regular (non-transposed) 2D convolution backward is implemented.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      transposed == false, "Transposed Convolution Backward not supported yet");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      weight.dim() == 4, "Only 2D Convolution Backward supported for now");

  // Every tensor involved must share the input's dtype. Note grad_bias is
  // checked unconditionally, even when output_mask[2] is false.
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(weight, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_output, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_input, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_weight, input));
  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(grad_bias, input));

  // Reuse the forward convolution argument checks, with grad_output standing
  // in for the forward output and no bias tensor supplied.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      check_convolution_args(
          input,
          weight,
          exec_aten::optional<Tensor>(),
          stride,
          padding,
          dilation,
          transposed,
          output_padding,
          groups,
          grad_output),
      "Invalid convolution arguments");

  // Recompute what the forward pass's output size would have been for these
  // inputs and hyper-parameters.
  size_t output_ndim = 0;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  get_convolution_out_target_size(
      input,
      weight,
      stride,
      padding,
      dilation,
      transposed,
      output_padding,
      groups,
      output_sizes,
      &output_ndim);

  ET_LOG_AND_RETURN_IF_FALSE(
      output_size_is_valid({output_sizes, output_ndim}, input.dim() - 2));

  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      grad_output.dim() == input.dim(),
      "grad_output should have same number of dimensions as input");

  // grad_output must match the recomputed forward output size exactly.
  ET_LOG_AND_RETURN_IF_FALSE(
      tensor_has_expected_size(grad_output, {output_sizes, output_ndim}));

  return true;
}
92+
93+
/**
 * Reference implementation of the 2D convolution backward pass.
 *
 * Walks every output element of grad_output and scatters its contribution
 * into whichever of grad_input / grad_weight / grad_bias is requested by
 * output_mask ([0] = input, [1] = weight, [2] = bias). Requested outputs
 * are zero-filled first and then accumulated with +=.
 *
 * Assumes 4-D NCHW-style layouts (batch, channel, height, width) — enforced
 * by check_convolution_backward_args via the weight.dim() == 4 check.
 */
template <typename CTYPE>
void conv2d_backward_impl(
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    int64_t groups,
    exec_aten::ArrayRef<bool> output_mask,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  auto batch_size = input.size(0);
  auto in_channels = input.size(1);
  auto out_channels = weight.size(0);
  auto in_height = input.size(2);
  auto in_width = input.size(3);
  auto out_height = grad_output.size(2);
  auto out_width = grad_output.size(3);
  auto kernel_height = weight.size(2);
  auto kernel_width = weight.size(3);

  // Per-dimension hyper-parameters; padding explicitly defaults to 0 when
  // the array does not provide an entry.
  const int64_t stride_h = val_at(stride, 0);
  const int64_t padding_h = val_at(padding, 0, /*default_value=*/0);
  const int64_t dilation_h = val_at(dilation, 0);
  const int64_t stride_w = val_at(stride, 1);
  const int64_t padding_w = val_at(padding, 1, /*default_value=*/0);
  const int64_t dilation_w = val_at(dilation, 1);

  auto in_channels_per_group = in_channels / groups;
  auto out_channels_per_group = out_channels / groups;

  const CTYPE* grad_output_data = grad_output.const_data_ptr<CTYPE>();
  const CTYPE* input_data = input.const_data_ptr<CTYPE>();
  const CTYPE* weight_data = weight.const_data_ptr<CTYPE>();

  CTYPE* grad_input_data = nullptr;
  CTYPE* grad_weight_data = nullptr;
  CTYPE* grad_bias_data = nullptr;

  // Gradients are accumulated with +=, so each requested output buffer must
  // start at zero. Unrequested outputs keep null data pointers and are never
  // touched (guarded by output_mask below).
  if (output_mask[0]) {
    grad_input_data = grad_input.mutable_data_ptr<CTYPE>();
    memset(grad_input_data, 0, grad_input.nbytes());
  }

  if (output_mask[1]) {
    grad_weight_data = grad_weight.mutable_data_ptr<CTYPE>();
    memset(grad_weight_data, 0, grad_weight.nbytes());
  }

  if (output_mask[2]) {
    grad_bias_data = grad_bias.mutable_data_ptr<CTYPE>();
    memset(grad_bias_data, 0, grad_bias.nbytes());
  }

  // Scratch coordinate buffers (batch, channel, h, w) used to translate
  // logical positions into linear indices via each tensor's strides.
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  exec_aten::SizesType out_coord[kTensorDimensionLimit];
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  exec_aten::SizesType in_coord[kTensorDimensionLimit];
  // @lint-ignore CLANGTIDY facebook-hte-CArray
  exec_aten::SizesType weight_coord[kTensorDimensionLimit];

  // Compute gradients
  for (int64_t b = 0; b < batch_size; ++b) { // Loop over each batch
    in_coord[0] = b;
    out_coord[0] = b;
    for (int64_t g = 0; g < groups; ++g) { // Loop over each group
      for (int64_t h = 0; h < out_height; ++h) { // Loop over each output row
        out_coord[2] = h;
        for (int64_t w = 0; w < out_width; ++w) { // Loop over each output col
          out_coord[3] = w;

          // Loop over each output channel in the group
          for (int64_t oc = 0; oc < out_channels_per_group; ++oc) {
            int64_t oc_global = oc + g * out_channels_per_group;
            weight_coord[0] = oc_global;
            out_coord[1] = oc_global;

            int64_t out_idx = calculate_linear_index(
                out_coord, grad_output.strides().data(), 4);

            // Accumulate the gradient with respect to the bias if required.
            // Done once per (b, h, w, oc), outside the input-channel loop.
            if (output_mask[2]) {
              grad_bias_data[oc_global] += grad_output_data[out_idx];
            }

            // Loop over each input channel in the group
            for (int64_t ic = 0; ic < in_channels_per_group; ++ic) {
              int64_t ic_global = ic + g * in_channels_per_group;
              in_coord[1] = ic_global;
              // Weight's channel dim is group-local, unlike input's.
              weight_coord[1] = ic;

              // Loop over each element
              for (int64_t kh = 0; kh < kernel_height; ++kh) {
                // Input row this kernel tap read in the forward pass;
                // skip taps that fell into the padding region.
                int64_t in_h = h * stride_h - padding_h + kh * dilation_h;
                if (in_h >= 0 && in_h < in_height) {
                  in_coord[2] = in_h;
                  weight_coord[2] = kh;

                  for (int64_t kw = 0; kw < kernel_width; ++kw) {
                    int64_t in_w = w * stride_w - padding_w + kw * dilation_w;
                    if (in_w >= 0 && in_w < in_width) {
                      in_coord[3] = in_w;
                      weight_coord[3] = kw;

                      int64_t in_idx = calculate_linear_index(
                          in_coord, input.strides().data(), 4);

                      int64_t weight_idx = calculate_linear_index(
                          weight_coord, weight.strides().data(), 4);

                      // Gradient with respect to the input if required
                      if (output_mask[0]) {
                        grad_input_data[in_idx] +=
                            grad_output_data[out_idx] * weight_data[weight_idx];
                      }
                      // Gradient with respect to the weight if required
                      if (output_mask[1]) {
                        grad_weight_data[weight_idx] +=
                            grad_output_data[out_idx] * input_data[in_idx];
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
226+
227+
} // namespace
228+
229+
/**
 * convolution_backward.out kernel entry point.
 *
 * Validates arguments, resizes the gradient outputs, and dispatches to the
 * typed conv2d_backward_impl. grad_input / grad_weight / grad_bias are
 * returned as a tuple of references; on a failed check they are returned
 * unmodified.
 */
std::tuple<Tensor&, Tensor&, Tensor&> convolution_backward_out(
    RuntimeContext& ctx,
    const Tensor& grad_output,
    const Tensor& input,
    const Tensor& weight,
    const OptIntArrayRef bias_sizes_opt,
    IntArrayRef stride,
    IntArrayRef padding,
    IntArrayRef dilation,
    bool transposed,
    IntArrayRef output_padding,
    int64_t groups,
    exec_aten::ArrayRef<bool> output_mask,
    Tensor& grad_input,
    Tensor& grad_weight,
    Tensor& grad_bias) {
  // The three gradient outputs are always handed back together, even when a
  // check below bails out early.
  auto ret = std::tie(grad_input, grad_weight, grad_bias);

  // Validate dtypes, shapes, and convolution hyper-parameters up front.
  ET_KERNEL_CHECK(
      ctx,
      check_convolution_backward_args(
          grad_output,
          input,
          weight,
          bias_sizes_opt,
          stride,
          padding,
          dilation,
          transposed,
          output_padding,
          groups,
          output_mask,
          grad_input,
          grad_weight,
          grad_bias),
      InvalidArgument,
      ret);

  // grad_input mirrors the input's shape; grad_weight mirrors the weight's.
  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(grad_input, input.sizes()) == Error::Ok,
      InvalidArgument,
      ret);

  ET_KERNEL_CHECK(
      ctx,
      resize_tensor(grad_weight, weight.sizes()) == Error::Ok,
      InvalidArgument,
      ret);

  // grad_bias is resized only when the caller supplied the forward bias size.
  if (bias_sizes_opt.has_value()) {
    ET_KERNEL_CHECK(
        ctx,
        resize_tensor(grad_bias, bias_sizes_opt.value()) == Error::Ok,
        InvalidArgument,
        ret);
  }

  constexpr auto name = "convolution_backward.out";

  // Dispatch on the input dtype to the templated reference kernel.
  ET_SWITCH_FLOATH_TYPES(input.scalar_type(), ctx, name, CTYPE, [&]() {
    conv2d_backward_impl<CTYPE>(
        grad_output,
        input,
        weight,
        stride,
        padding,
        dilation,
        groups,
        output_mask,
        grad_input,
        grad_weight,
        grad_bias);
  });

  return ret;
}
309+
310+
} // namespace native
311+
} // namespace executor
312+
} // namespace torch

kernels/portable/cpu/util/kernel_ops_util.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,7 @@ bool check_convolution_args(
326326
bool transposed,
327327
IntArrayRef output_padding,
328328
int64_t groups,
329-
Tensor& out) {
329+
const Tensor& out) {
330330
ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight, out));
331331

332332
ET_LOG_AND_RETURN_IF_FALSE(tensor_is_default_or_channels_last_dim_order(in));

kernels/portable/cpu/util/kernel_ops_util.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ bool check_convolution_args(
411411
bool transposed,
412412
IntArrayRef output_padding,
413413
int64_t groups,
414-
Tensor& out);
414+
const Tensor& out);
415415

416416
void get_convolution_out_target_size(
417417
const Tensor& in,

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,11 @@
248248
- arg_meta: null
249249
kernel_name: torch::executor::convolution_out
250250

251+
- op: convolution_backward.out
252+
kernels:
253+
- arg_meta: null
254+
kernel_name: torch::executor::convolution_backward_out
255+
251256
- op: copy.out
252257
kernels:
253258
- arg_meta: null

0 commit comments

Comments
 (0)