// Converts `f` to type `To` and writes it through the untyped pointer
// `dst`, which must point to storage suitably aligned for a `To`.
template <typename To, typename From>
void convert_and_store(From f, void* dst) {
  // Narrow/widen first, then write through a typed view of dst.
  To converted = static_cast<To>(f);
  *reinterpret_cast<To*>(dst) = converted;
}
283
- } // namespace internal
284
283
285
284
// Signature of a loader: reads one element from an untyped tensor
// element pointer and converts it to the common compute type.
template <typename CTYPE_COMMON>
using load_to_common_fn = CTYPE_COMMON (*)(const void*);
@@ -296,6 +295,17 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
296
295
return result;
297
296
}
298
297
298
+ template <typename CTYPE_COMMON, const char * op_name>
299
+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte (
300
+ const Tensor& t) {
301
+ CTYPE_COMMON (*result)(const void *) = nullptr ;
302
+ ET_SWITCH_TWO_TYPES (
303
+ Bool, Byte, t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
304
+ result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
305
+ });
306
+ return result;
307
+ }
308
+
299
309
// Signature of a storer: converts a value of the common compute type
// and writes it through an untyped pointer into the output tensor.
template <typename CTYPE_COMMON>
using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
301
311
@@ -310,6 +320,75 @@ get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
310
320
return result;
311
321
}
312
322
323
+ template <typename CTYPE_COMMON, const char * op_name>
324
+ store_common_to_tensor_fn<CTYPE_COMMON>
325
+ get_store_common_to_tensor_fn_bool_or_byte (const Tensor& t) {
326
+ void (*result)(CTYPE_COMMON, void *) = nullptr ;
327
+ ET_SWITCH_TWO_TYPES (
328
+ Bool, Byte, t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
329
+ result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
330
+ });
331
+ return result;
332
+ }
333
+ } // namespace internal
334
+
335
// Describes which set of dtypes an input/output tensor of an
// elementwise op is allowed to have; used to pick the right
// load/store conversion function at runtime.
enum class SupportedTensorDtypes {
  REALHBBF16, // real types + Half + Bool + BFloat16
  BOOL_OR_BYTE, // only Bool or Byte
  SAME_AS_COMMON, // must match CTYPE_COMMON exactly
};
340
+
341
+ namespace internal {
342
+ template <typename CTYPE_COMMON, const char * op_name>
343
+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn (
344
+ const Tensor& t,
345
+ SupportedTensorDtypes dtypes) {
346
+ switch (dtypes) {
347
+ case SupportedTensorDtypes::REALHBBF16:
348
+ return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
349
+ case SupportedTensorDtypes::BOOL_OR_BYTE:
350
+ return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
351
+ case SupportedTensorDtypes::SAME_AS_COMMON: {
352
+ constexpr auto common_scalar_type =
353
+ CppTypeToScalarType<CTYPE_COMMON>::value;
354
+ ET_CHECK_MSG (
355
+ t.scalar_type () == common_scalar_type,
356
+ " Unhandled dtype %s for %s" ,
357
+ ::executorch::runtime::toString (common_scalar_type),
358
+ op_name);
359
+ return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
360
+ }
361
+ }
362
+ ET_CHECK (false );
363
+ return nullptr ;
364
+ }
365
+
366
+ template <typename CTYPE_COMMON, const char * op_name>
367
+ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn (
368
+ const Tensor& t,
369
+ SupportedTensorDtypes dtypes) {
370
+ switch (dtypes) {
371
+ case SupportedTensorDtypes::REALHBBF16:
372
+ return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
373
+ case SupportedTensorDtypes::BOOL_OR_BYTE:
374
+ return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
375
+ t);
376
+ case SupportedTensorDtypes::SAME_AS_COMMON: {
377
+ constexpr auto common_scalar_type =
378
+ CppTypeToScalarType<CTYPE_COMMON>::value;
379
+ ET_CHECK_MSG (
380
+ t.scalar_type () == common_scalar_type,
381
+ " Unhandled dtype %s for %s" ,
382
+ ::executorch::runtime::toString (common_scalar_type),
383
+ op_name);
384
+ return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
385
+ }
386
+ }
387
+ ET_CHECK (false );
388
+ return nullptr ;
389
+ }
390
+ } // namespace internal
391
+
313
392
/* *
314
393
* Useful for binary elementwise operators. For each element of the inputs,
315
394
* perform a computation and write to the corresponding element of the output.
@@ -356,33 +435,45 @@ inline void apply_binary_elementwise_fn(
356
435
*
357
436
* In order to mitigate build time cost (straightforwardly |CTYPE_A| *
358
437
* |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
359
- * are passed as CTYPE_COMMON. We require compute_fun to return
360
- * CTYPE_COMMON, and we require loading conversion functions from each
361
- * input type to CTYPE_COMMON and a storing conversion from
362
- * CTYPE_COMMON to CTYPE_OUT be provided. Each conversion function
363
- * must take a void* pointing to an element of the corresponding
364
- * tensor, load that element, and convert it to CTYPE_COMMON. The
365
- * storing conversion function must have the signature
366
- * void(CTYPE_COMMON, void*), convert the given element to CTYPE_OUT,
367
- * and store it to the given location.
438
+ * are passed as CTYPE_COMMON.
439
+ *
440
+ * Each tensor's supported dtypes set must be provided. The tensor
441
+ * will be checked to ensure that its dtype falls into that set.
442
+ *
443
+ * op_name is used to support dtype selective build, as with the
444
+ * ET_SWITCH family of macros. Note: because of C++17 quirks, you
445
+ * can't pass a string literal for op_name. Instead, you should do the
446
+ * following:
447
+ *
448
+ * static constexpr const char op_name[] = "my_op";
449
+ * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
368
450
*/
369
- template <typename CTYPE_COMMON, typename Op>
451
+ template <typename CTYPE_COMMON, const char * op_name, typename Op>
370
452
inline void apply_ternary_elementwise_fn (
371
453
const Op& compute_fun,
372
454
const Tensor& a,
455
+ SupportedTensorDtypes a_dtypes,
373
456
const Tensor& b,
457
+ SupportedTensorDtypes b_dtypes,
374
458
const Tensor& c,
459
+ SupportedTensorDtypes c_dtypes,
375
460
const Tensor& out,
376
- CTYPE_COMMON (*load_a_to_common)(const void *),
377
- CTYPE_COMMON (*load_b_to_common)(const void *),
378
- CTYPE_COMMON (*load_c_to_common)(const void *),
379
- void (*store_common_to_out)(CTYPE_COMMON, void *)) {
461
+ SupportedTensorDtypes out_dtypes) {
380
462
const bool a_is_broadcasted = !out.sizes ().equals (a.sizes ());
381
463
const bool b_is_broadcasted = !out.sizes ().equals (b.sizes ());
382
464
const bool c_is_broadcasted = !out.sizes ().equals (c.sizes ());
383
465
const bool any_is_broadcasted =
384
466
(a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
385
467
468
+ const auto load_a_to_common =
469
+ internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
470
+ const auto load_b_to_common =
471
+ internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
472
+ const auto load_c_to_common =
473
+ internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
474
+ const auto store_common_to_out =
475
+ internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
476
+ out, out_dtypes);
386
477
const char * const data_a = reinterpret_cast <const char *>(a.const_data_ptr ());
387
478
const char * const data_b = reinterpret_cast <const char *>(b.const_data_ptr ());
388
479
const char * const data_c = reinterpret_cast <const char *>(c.const_data_ptr ());
0 commit comments