
Commit d8bb5d6

SS-JIA authored and facebook-github-bot committed
Refactor + Migrate convolution to non-fatal failure pattern
Summary: Factors out utilities used in `op_convolution` into `kernel_ops_util`, which can be shared by all ops that use a 2D kernel pattern (i.e., ops that accept `stride`, `padding`, and `dilation` arguments).

Reviewed By: guangy10

Differential Revision: D48405102

fbshipit-source-id: aa5d6d43aaa6ed5655f0fca99b4d0f59a00b060c
1 parent 8a38922 commit d8bb5d6
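
The diff below contains call sites like `val_at(padding, 0, /*default_value=*/0)` and `val_at(dilation, 0)` on a possibly-empty array, implying the factored-out helper gained an optional default for empty arrays (replacing the old fatal check and the per-site special-casing of empty `dilation`). A minimal sketch of what the moved helper might look like, assuming the signature implied by these call sites; the actual definition in `kernel_ops_util.h` is authoritative:

    // Sketch only; the real helper lives in
    // executorch/kernels/portable/cpu/util/kernel_ops_util.h.
    // Assumes IntArrayRef = exec_aten::ArrayRef<int64_t>, as in
    // op_convolution.cpp.
    inline int64_t val_at(IntArrayRef array, size_t i, int64_t default_value = 1) {
      if (array.size() == 1) {
        return array[0]; // a length-1 array broadcasts to every index
      } else if (array.size() > 1) {
        return array[i];
      }
      return default_value; // empty array: fall back instead of a fatal check
    }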

File tree

9 files changed: +545, -161 lines changed

kernels/portable/cpu/op_convolution.cpp

Lines changed: 41 additions & 154 deletions
@@ -8,6 +8,7 @@
 
 #include <cstring>
 
+#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
@@ -24,54 +25,6 @@ using StridesArrayRef = exec_aten::ArrayRef<exec_aten::StridesType>;
 
 namespace {
 
-/**
- * Extracts a value at index i from an int array. If the array length is 1, then
- * the first element will be returned regardless of what i is requested to
- * simulate broadcasting.
- */
-inline int64_t val_at(IntArrayRef array, size_t i) {
-  if (array.size() == 1) {
-    return array[0];
-  } else if (array.size() > 1) {
-    return array[i];
-  } else {
-    ET_CHECK_MSG(false, "Attempted to retrieve from an empty array!");
-  }
-}
-
-inline void get_unsqueezed_sizes(
-    const Tensor& t,
-    int64_t unsqueeze_dim,
-    exec_aten::SizesType* sizes_arr,
-    size_t& ndim) {
-  ndim = t.dim() + 1;
-  for (int d = 0; d < unsqueeze_dim; ++d) {
-    sizes_arr[d] = t.size(d);
-  }
-  sizes_arr[unsqueeze_dim] = 1;
-  for (int d = (unsqueeze_dim + 1); d < ndim; d++) {
-    sizes_arr[d] = t.size(d - 1);
-  }
-}
-
-inline void get_unsqueezed_dim_order(
-    const Tensor& t,
-    exec_aten::DimOrderType unsqueeze_dim,
-    exec_aten::DimOrderType* dim_order_arr) {
-  int offset = 0;
-  for (int i = 0; i < t.dim(); ++i) {
-    exec_aten::DimOrderType dim = t.dim_order()[i];
-    if (dim == unsqueeze_dim) {
-      dim_order_arr[i] = dim;
-      dim_order_arr[i + 1] = dim + 1;
-      offset = 1;
-    } else {
-      dim_order_arr[i + offset] = dim > unsqueeze_dim ? dim + 1 : dim;
-    }
-  }
-  return;
-}
-
 /**
  * Computes 2D convolution out results for a given group and channel. The
  * computation can be thought of as a stencil computation: we iterate over an
@@ -134,25 +87,19 @@ void conv2d_impl(
   for (size_t w_y = 0; w_y < w_H; ++w_y) {
     w_coord[2] = w_y;
 
-    int64_t dilation_y = 1;
-    if (dilation.size() > 0) {
-      dilation_y = val_at(dilation, 0);
-    }
     int64_t stride_y = val_at(stride, 0);
-    int64_t padding_y = val_at(padding, 0);
+    int64_t padding_y = val_at(padding, 0, /*default_value=*/0);
+    int64_t dilation_y = val_at(dilation, 0);
     size_t in_y = stride_y * out_y + dilation_y * w_y - padding_y;
     in_coord[2] = in_y;
     // Only proceed if input y coordinate is within bounds
     if (in_y >= 0 && in_y < in_H) {
       for (size_t w_x = 0; w_x < w_W; ++w_x) {
         w_coord[3] = w_x;
 
-        int64_t dilation_x = 1;
-        if (dilation.size() > 0) {
-          dilation_x = val_at(dilation, 0);
-        }
         int64_t stride_x = val_at(stride, 1);
-        int64_t padding_x = val_at(padding, 1);
+        int64_t padding_x = val_at(padding, 1, /*default_value=*/0);
+        int64_t dilation_x = val_at(dilation, 1);
         size_t in_x = stride_x * out_x + dilation_x * w_x - padding_x;
         in_coord[3] = in_x;
 
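To make the stencil index arithmetic in this hunk concrete, a worked example with illustrative values (not taken from the source):

    // in_y = stride_y * out_y + dilation_y * w_y - padding_y
    // With stride_y = 2, padding_y = 1, dilation_y = 1, out_y = 0, w_y = 0:
    //   in_y = 2 * 0 + 1 * 0 - 1 = -1
    // Since in_y is a size_t, -1 wraps around to SIZE_MAX; `in_y >= 0` is
    // trivially true for an unsigned value, but the `in_y < in_H` bound
    // rejects the wrapped value, so the out-of-bounds tap is skipped.
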
@@ -265,26 +212,19 @@ void convolution_wrapper(
   }
 
   exec_aten::StridesType in_strides[kTensorDimensionLimit];
-  ET_CHECK(
-      dim_order_to_stride(
-          in_sizes.data(), in_dim_order.data(), in_sizes.size(), in_strides) ==
-      Error::Ok);
+  dim_order_to_stride_nocheck(
+      in_sizes.data(), in_dim_order.data(), in_sizes.size(), in_strides);
 
   exec_aten::StridesType weight_strides[kTensorDimensionLimit];
-  ET_CHECK(
-      dim_order_to_stride(
-          weight_sizes.data(),
-          weight_dim_order.data(),
-          weight_sizes.size(),
-          weight_strides) == Error::Ok);
+  dim_order_to_stride_nocheck(
+      weight_sizes.data(),
+      weight_dim_order.data(),
+      weight_sizes.size(),
+      weight_strides);
 
   exec_aten::StridesType out_strides[kTensorDimensionLimit];
-  ET_CHECK(
-      dim_order_to_stride(
-          out_sizes.data(),
-          out_dim_order.data(),
-          out_sizes.size(),
-          out_strides) == Error::Ok);
+  dim_order_to_stride_nocheck(
+      out_sizes.data(), out_dim_order.data(), out_sizes.size(), out_strides);
 
   CTYPE* const out_ptr = out.mutable_data_ptr<CTYPE>();
   const CTYPE* const in_ptr = in.const_data_ptr<CTYPE>();
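
Dropping the fatal `ET_CHECK` here is safe because argument validation now happens up front in `check_convolution_args` before `convolution_wrapper` runs, so the `_nocheck` variant can compute strides unconditionally. The change in miniature (argument names abbreviated for illustration):

    // Before: a fatal assertion wrapped the checked variant.
    ET_CHECK(
        dim_order_to_stride(sizes, dim_order, ndim, strides) == Error::Ok);

    // After: inputs were already validated non-fatally, so just compute.
    dim_order_to_stride_nocheck(sizes, dim_order, ndim, strides);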
@@ -322,72 +262,6 @@ void convolution_wrapper(
   }
 }
 
-void get_conv_output_size(
-    const Tensor& in,
-    const Tensor& weight,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    exec_aten::SizesType* sizes_arr,
-    size_t& dim) {
-  dim = in.dim();
-
-  sizes_arr[0] = in.size(0);
-  sizes_arr[1] = weight.size(0);
-  for (size_t d = 2; d < in.dim(); ++d) {
-    int64_t dilation_val = 1;
-    if (dilation.size() > 1) {
-      dilation_val = val_at(dilation, d - 2);
-    }
-    int64_t padding_val = val_at(padding, d - 2);
-    int64_t stride_val = val_at(stride, d - 2);
-
-    int64_t kernel_len = dilation_val * (weight.size(d) - 1) + 1;
-    sizes_arr[d] =
-        (in.size(d) + (2 * padding_val) - kernel_len) / stride_val + 1;
-  }
-}
-
-void check_preconditions(
-    const Tensor& in,
-    const Tensor& weight,
-    const exec_aten::optional<Tensor>& bias,
-    IntArrayRef stride,
-    IntArrayRef padding,
-    IntArrayRef dilation,
-    bool transposed,
-    IntArrayRef output_padding,
-    int64_t groups,
-    Tensor& out) {
-  ET_CHECK_SAME_DTYPE3(in, weight, out);
-
-  ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(in);
-  ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(weight);
-  ET_CHECK_DEFAULT_OR_CHANNELSLAST_DIMORDER(out);
-
-  ET_CHECK(in.dim() >= 3 && in.dim() < 5);
-  ET_CHECK(in.dim() == weight.dim());
-  ET_CHECK(in.dim() == out.dim());
-
-  if (bias.has_value()) {
-    ET_CHECK(bias.value().dim() == 1);
-    ET_CHECK(bias.value().size(0) == weight.size(0));
-  }
-
-  ET_CHECK(padding.size() > 0 && padding.size() <= in.dim() - 2);
-  ET_CHECK(stride.size() > 0 && stride.size() <= in.dim() - 2);
-  if (dilation.size() > 0) {
-    ET_CHECK(dilation.size() <= in.dim() - 2);
-  }
-  // input channels must be evenly divisible by groups
-  ET_CHECK(in.size(1) % groups == 0);
-
-  ET_CHECK_MSG(!transposed, "transposed convolution not supported yet!");
-  if (output_padding.size() > 0) {
-    ET_CHECK(dilation.size() <= in.dim() - 2);
-  }
-}
-
 } // namespace
 
 Tensor& convolution_out(
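
The removed `get_conv_output_size` (superseded by `get_convolution_out_target_size` from `kernel_ops_util`) applies the standard convolution output-size formula per spatial dimension. A worked example with illustrative numbers:

    // kernel_len = dilation * (kernel_size - 1) + 1
    // out_size   = (in_size + 2 * padding - kernel_len) / stride + 1
    // For in_size = 32, kernel_size = 3, stride = 2, padding = 1, dilation = 1:
    //   kernel_len = 1 * (3 - 1) + 1 = 3
    //   out_size   = (32 + 2 - 3) / 2 + 1 = 15 + 1 = 16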
@@ -404,25 +278,38 @@ Tensor& convolution_out(
     Tensor& out) {
   (void)ctx;
 
-  check_preconditions(
-      in,
-      weight,
-      bias,
-      stride,
-      padding,
-      dilation,
-      transposed,
-      output_padding,
-      groups,
+  ET_KERNEL_CHECK(
+      ctx,
+      check_convolution_args(
+          in,
+          weight,
+          bias,
+          stride,
+          padding,
+          dilation,
+          transposed,
+          output_padding,
+          groups,
+          out),
+      InvalidArgument,
       out);
 
   size_t output_ndim = 0;
   exec_aten::SizesType output_sizes[kTensorDimensionLimit];
-  get_conv_output_size(
-      in, weight, stride, padding, dilation, output_sizes, output_ndim);
+  get_convolution_out_target_size(
+      in, weight, stride, padding, dilation, output_sizes, &output_ndim);
 
-  Error err = resize_tensor(out, {output_sizes, output_ndim});
-  ET_CHECK_MSG(err == Error::Ok, "Could not resize output");
+  ET_KERNEL_CHECK(
+      ctx,
+      output_size_is_valid({output_sizes, output_ndim}),
+      InvalidArgument,
+      out);
+
+  ET_KERNEL_CHECK(
+      ctx,
+      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
+      InvalidArgument,
+      out);
 
   ScalarType in_type = in.scalar_type();
   ScalarType bias_type = in_type;
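
This hunk is the heart of the migration: each fatal `ET_CHECK`/`ET_CHECK_MSG` becomes `ET_KERNEL_CHECK(ctx, cond, error_code, retval)`, which on failure reports the error through the kernel runtime context and returns `retval` from the enclosing function instead of aborting the process. A minimal sketch of the pattern in a hypothetical op (`my_example_op_out` is invented for illustration; the macro itself comes from the ExecuTorch runtime headers):

    Tensor& my_example_op_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
      // On failure: flag InvalidArgument and `return out;` without
      // taking down the process, unlike the old fatal ET_CHECK path.
      ET_KERNEL_CHECK(
          ctx, in.scalar_type() == out.scalar_type(), InvalidArgument, out);
      // ... kernel body runs only if the check passed ...
      return out;
    }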

kernels/portable/cpu/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -201,6 +201,7 @@ _ATEN_OPS = (
     op_target(
         name = "op_convolution",
         deps = [
+            "//executorch/kernels/portable/cpu/util:kernel_ops_util",
             ":vec_ops",
         ],
     ),

0 commit comments