Skip to content

Commit be86a2c

Browse files
manuelcandales authored and
facebook-github-bot committed
Reduce build size of op_convolution (#6017)
Summary: Pull Request resolved: #6017 500 K -> 52 K ghstack-source-id: 246985132 exported-using-ghexport Reviewed By: swolchok Differential Revision: D63994876 fbshipit-source-id: ca04e87a87b149b925efad6c6168a848f006fa41
1 parent b308744 commit be86a2c

File tree

2 files changed

+36
-21
lines changed

2 files changed

+36
-21
lines changed

kernels/portable/cpu/op_convolution.cpp

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <cstring>
1010

11+
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
1112
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
1213
#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
1314
#include <executorch/runtime/kernel/kernel_includes.h>
@@ -32,15 +33,17 @@ namespace {
3233
* in_C_per_group x in_H x in_W, to compute an out channel of size 1 x out_H x
3334
* out_W.
3435
*/
35-
template <typename CTYPE, typename CTYPE_BIAS>
36+
template <typename CTYPE, typename LoadFn = CTYPE (*)(const void*)>
3637
void conv2d_impl(
3738
const CTYPE* const in_ptr,
3839
SizesArrayRef in_sizes,
3940
StridesArrayRef in_strides,
4041
const CTYPE* const w_ptr,
4142
SizesArrayRef w_sizes,
4243
StridesArrayRef w_strides,
43-
const CTYPE_BIAS* const bias_ptr,
44+
const exec_aten::optional<Tensor>& bias,
45+
const char* const bias_ptr,
46+
LoadFn load_bias,
4447
IntArrayRef stride,
4548
IntArrayRef padding,
4649
IntArrayRef dilation,
@@ -128,7 +131,7 @@ void conv2d_impl(
128131
}
129132

130133
if (bias_ptr != nullptr) {
131-
accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
134+
accum += load_bias(&bias_ptr[out_c * bias.value().element_size()]);
132135
}
133136
size_t out_idx =
134137
calculate_linear_index(out_coord, out_strides.data(), 4);
@@ -185,11 +188,12 @@ void conv2d_impl(
185188
}
186189
}
187190

188-
template <typename CTYPE, typename CTYPE_BIAS>
191+
template <typename CTYPE, typename LoadFn = CTYPE (*)(const void*)>
189192
void convolution_wrapper(
190193
const Tensor& in,
191194
const Tensor& weight,
192195
const exec_aten::optional<Tensor>& bias,
196+
LoadFn load_bias,
193197
IntArrayRef stride,
194198
IntArrayRef padding,
195199
IntArrayRef dilation,
@@ -280,8 +284,9 @@ void convolution_wrapper(
280284
CTYPE* const out_ptr = out.mutable_data_ptr<CTYPE>();
281285
const CTYPE* const in_ptr = in.const_data_ptr<CTYPE>();
282286
const CTYPE* const w_ptr = weight.const_data_ptr<CTYPE>();
283-
const CTYPE_BIAS* const bias_ptr =
284-
bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;
287+
const char* const bias_ptr = bias.has_value()
288+
? reinterpret_cast<const char*>(bias.value().const_data_ptr())
289+
: nullptr;
285290

286291
size_t out_N = out.size(0);
287292
size_t out_C = out.size(1);
@@ -296,8 +301,9 @@ void convolution_wrapper(
296301
} else {
297302
// If bias is present, we initialize the output to the bias value
298303
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
299-
out_ptr[out_ix] = convert<CTYPE, CTYPE_BIAS>(
300-
bias_ptr[(out_ix / out_strides[1]) % out_C]);
304+
out_ptr[out_ix] = load_bias(&bias_ptr
305+
[((out_ix / out_strides[1]) % out_C) *
306+
bias.value().element_size()]);
301307
}
302308
}
303309
}
@@ -316,7 +322,9 @@ void convolution_wrapper(
316322
w_ptr,
317323
weight_sizes,
318324
{weight_strides, 4},
325+
bias,
319326
bias_ptr,
327+
load_bias,
320328
stride_,
321329
padding_,
322330
dilation_,
@@ -398,19 +406,25 @@ Tensor& convolution_out(
398406
return out;
399407
}
400408

401-
ScalarType in_type = in.scalar_type();
402-
ScalarType bias_type = in_type;
403-
if (bias.has_value()) {
404-
bias_type = bias.value().scalar_type();
405-
}
406-
407-
constexpr auto name = "convolution.out";
408-
409-
ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
410-
ET_SWITCH_REALHB_TYPES(bias_type, ctx, name, CTYPE_BIAS, [&]() {
411-
convolution_wrapper<CTYPE, CTYPE_BIAS>(
412-
in, weight, bias, stride, padding, dilation, transposed, groups, out);
413-
});
409+
// @lint-ignore CLANGTIDY facebook-hte-CArray
410+
static constexpr const char name[] = "convolution.out";
411+
412+
ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
413+
const auto load_bias = bias.has_value()
414+
? utils::internal::get_load_to_common_fn<CTYPE, name>(
415+
bias.value(), utils::SupportedTensorDtypes::REALHBF16)
416+
: nullptr;
417+
convolution_wrapper<CTYPE>(
418+
in,
419+
weight,
420+
bias,
421+
load_bias,
422+
stride,
423+
padding,
424+
dilation,
425+
transposed,
426+
groups,
427+
out);
414428
});
415429

416430
return out;

shim/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,7 @@ ATEN_OPS = (
408408
op_target(
409409
name = "op_convolution",
410410
deps = [
411+
"//executorch/kernels/portable/cpu/util:dtype_util",
411412
"//executorch/kernels/portable/cpu/util:kernel_ops_util",
412413
":vec_ops",
413414
],

0 commit comments

Comments (0)