@@ -8,6 +8,7 @@
 
 #include <cstring>
 
+#include <executorch/kernels/portable/cpu/util/dtype_util.h>
 #include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
 #include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
@@ -32,15 +33,17 @@ namespace {
  * in_C_per_group x in_H x in_W, to compute an out channel of size 1 x out_H x
  * out_W.
  */
-template <typename CTYPE, typename CTYPE_BIAS>
+template <typename CTYPE, typename LoadFn = CTYPE (*)(const void*)>
 void conv2d_impl(
     const CTYPE* const in_ptr,
     SizesArrayRef in_sizes,
     StridesArrayRef in_strides,
     const CTYPE* const w_ptr,
     SizesArrayRef w_sizes,
     StridesArrayRef w_strides,
-    const CTYPE_BIAS* const bias_ptr,
+    const exec_aten::optional<Tensor>& bias,
+    const char* const bias_ptr,
+    LoadFn load_bias,
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
@@ -128,7 +131,7 @@ void conv2d_impl(
           }
 
           if (bias_ptr != nullptr) {
-            accum += convert<CTYPE, CTYPE_BIAS>(bias_ptr[out_c]);
+            accum += load_bias(&bias_ptr[out_c * bias.value().element_size()]);
           }
           size_t out_idx =
               calculate_linear_index(out_coord, out_strides.data(), 4);
@@ -185,11 +188,12 @@ void conv2d_impl(
   }
 }
 
-template <typename CTYPE, typename CTYPE_BIAS>
+template <typename CTYPE, typename LoadFn = CTYPE (*)(const void*)>
 void convolution_wrapper(
     const Tensor& in,
     const Tensor& weight,
     const exec_aten::optional<Tensor>& bias,
+    LoadFn load_bias,
     IntArrayRef stride,
     IntArrayRef padding,
     IntArrayRef dilation,
@@ -280,8 +284,9 @@ void convolution_wrapper(
   CTYPE* const out_ptr = out.mutable_data_ptr<CTYPE>();
   const CTYPE* const in_ptr = in.const_data_ptr<CTYPE>();
   const CTYPE* const w_ptr = weight.const_data_ptr<CTYPE>();
-  const CTYPE_BIAS* const bias_ptr =
-      bias.has_value() ? bias.value().const_data_ptr<CTYPE_BIAS>() : nullptr;
+  const char* const bias_ptr = bias.has_value()
+      ? reinterpret_cast<const char*>(bias.value().const_data_ptr())
+      : nullptr;
 
   size_t out_N = out.size(0);
   size_t out_C = out.size(1);
@@ -296,8 +301,9 @@ void convolution_wrapper(
   } else {
     // If bias is present, we initialize the output to the bias value
     for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
-      out_ptr[out_ix] = convert<CTYPE, CTYPE_BIAS>(
-          bias_ptr[(out_ix / out_strides[1]) % out_C]);
+      out_ptr[out_ix] = load_bias(&bias_ptr
+          [((out_ix / out_strides[1]) % out_C) *
+           bias.value().element_size()]);
     }
   }
 }
@@ -316,7 +322,9 @@ void convolution_wrapper(
       w_ptr,
       weight_sizes,
       {weight_strides, 4},
+      bias,
       bias_ptr,
+      load_bias,
       stride_,
       padding_,
       dilation_,
@@ -398,19 +406,25 @@ Tensor& convolution_out(
     return out;
   }
 
-  ScalarType in_type = in.scalar_type();
-  ScalarType bias_type = in_type;
-  if (bias.has_value()) {
-    bias_type = bias.value().scalar_type();
-  }
-
-  constexpr auto name = "convolution.out";
-
-  ET_SWITCH_REALH_TYPES(in_type, ctx, name, CTYPE, [&]() {
-    ET_SWITCH_REALHB_TYPES(bias_type, ctx, name, CTYPE_BIAS, [&]() {
-      convolution_wrapper<CTYPE, CTYPE_BIAS>(
-          in, weight, bias, stride, padding, dilation, transposed, groups, out);
-    });
+  // @lint-ignore CLANGTIDY facebook-hte-CArray
+  static constexpr const char name[] = "convolution.out";
+
+  ET_SWITCH_REALH_TYPES(in.scalar_type(), ctx, name, CTYPE, [&]() {
+    const auto load_bias = bias.has_value()
+        ? utils::internal::get_load_to_common_fn<CTYPE, name>(
+              bias.value(), utils::SupportedTensorDtypes::REALHBF16)
+        : nullptr;
+    convolution_wrapper<CTYPE>(
+        in,
+        weight,
+        bias,
+        load_bias,
+        stride,
+        padding,
+        dilation,
+        transposed,
+        groups,
+        out);
   });
 
   return out;
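
For context on the pattern: the diff swaps the compile-time CTYPE_BIAS template parameter for a runtime, type-erased load function with signature CTYPE (*)(const void*). The kernel now indexes the bias by raw byte offset and lets the function pointer do the dtype conversion, which is what lets it drop the nested ET_SWITCH_REALHB_TYPES over the bias dtype. A minimal self-contained sketch of the idea, assuming a hypothetical load_as helper (illustrative only, not ExecuTorch's get_load_to_common_fn):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Type-erased element load: read one element of source type SRC from a raw
// pointer and convert it to the compute type CTYPE. Mirrors the
// LoadFn = CTYPE (*)(const void*) signature introduced in the diff.
// (Hypothetical helper for illustration, not ExecuTorch's dtype_util API.)
template <typename CTYPE, typename SRC>
CTYPE load_as(const void* p) {
  return static_cast<CTYPE>(*static_cast<const SRC*>(p));
}

int main() {
  // Hypothetical setup: bias stored as int32, kernel computing in float.
  const int32_t bias[] = {10, 20, 30};
  using LoadFn = float (*)(const void*);
  LoadFn load_bias = &load_as<float, int32_t>;

  // Like the patched kernel, keep a char* and index by byte offset
  // (index * element_size), then let load_bias convert the element.
  const char* bias_ptr = reinterpret_cast<const char*>(bias);
  const std::size_t element_size = sizeof(int32_t);
  float b = load_bias(&bias_ptr[1 * element_size]);
  std::printf("%f\n", b);  // prints 20.000000
}

With this shape, only the compute dtype remains a template parameter: one instantiation of the kernel serves every supported bias dtype, and the per-dtype work reduces to selecting a function pointer once, which is the usual motivation for this kind of type erasure.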