@@ -280,7 +280,6 @@ template <typename To, typename From>
280
280
void convert_and_store (From f, void * dst) {
281
281
*reinterpret_cast <To*>(dst) = static_cast <To>(f);
282
282
}
283
- } // namespace internal
284
283
285
284
/**
 * Pointer to a function that accepts a const void* and returns a
 * CTYPE_COMMON value. The getters in this file hand back
 * load_and_convert instantiations through this type.
 */
template <typename CTYPE_COMMON>
using load_to_common_fn = CTYPE_COMMON (*)(const void*);
@@ -296,6 +295,15 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
296
295
return result;
297
296
}
298
297
298
+ template <typename CTYPE_COMMON, const char * op_name>
299
+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte (const Tensor& t) {
300
+ CTYPE_COMMON (*result)(const void *) = nullptr ;
301
+ ET_SWITCH_TWO_TYPES (Bool, Byte, t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
302
+ result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
303
+ });
304
+ return result;
305
+ }
306
+
299
307
/**
 * Pointer to a function that accepts a CTYPE_COMMON value plus a void*
 * and returns nothing. The getters in this file hand back
 * convert_and_store instantiations through this type.
 */
template <typename CTYPE_COMMON>
using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
301
309
@@ -310,6 +318,72 @@ get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
310
318
return result;
311
319
}
312
320
321
+ template <typename CTYPE_COMMON, const char * op_name>
322
+ store_common_to_tensor_fn<CTYPE_COMMON>
323
+ get_store_common_to_tensor_fn_bool_or_byte (const Tensor& t) {
324
+ void (*result)(CTYPE_COMMON, void *) = nullptr ;
325
+ ET_SWITCH_TWO_TYPES (Bool, Byte,
326
+ t.scalar_type (), unused, op_name, TENSOR_CTYPE, [&]() {
327
+ result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
328
+ });
329
+ return result;
330
+ }
331
+ } // namespace internal
332
+
333
/**
 * Names the dtype set a tensor argument is allowed to have. The value is
 * used to choose which load/store function getter is applied to the tensor.
 */
enum class SupportedTensorDtypes {
  // Routed to the *_realhbbf16 function getters.
  REALHBBF16,
  // Routed to the *_bool_or_byte function getters (Bool / Byte dispatch).
  BOOL_OR_BYTE,
  // The tensor's dtype must equal the scalar type of CTYPE_COMMON.
  SAME_AS_COMMON,
};
338
+
339
+ namespace internal {
340
+ template <typename CTYPE_COMMON, const char * op_name>
341
+ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn (
342
+ const Tensor& t,
343
+ SupportedTensorDtypes dtypes) {
344
+ switch (dtypes) {
345
+ case SupportedTensorDtypes::REALHBBF16:
346
+ return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
347
+ case SupportedTensorDtypes::BOOL_OR_BYTE:
348
+ return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
349
+ case SupportedTensorDtypes::SAME_AS_COMMON: {
350
+ constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
351
+ ET_CHECK_MSG (
352
+ t.scalar_type () == common_scalar_type,
353
+ " Unhandled dtype %s for %s" ,
354
+ ::executorch::runtime::toString (common_scalar_type),
355
+ op_name);
356
+ return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
357
+ }
358
+ }
359
+ ET_CHECK (false );
360
+ return nullptr ;
361
+ }
362
+
363
+ template <typename CTYPE_COMMON, const char * op_name>
364
+ store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn (
365
+ const Tensor& t,
366
+ SupportedTensorDtypes dtypes) {
367
+ switch (dtypes) {
368
+ case SupportedTensorDtypes::REALHBBF16:
369
+ return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
370
+ case SupportedTensorDtypes::BOOL_OR_BYTE:
371
+ return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
372
+ case SupportedTensorDtypes::SAME_AS_COMMON: {
373
+ constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
374
+ ET_CHECK_MSG (
375
+ t.scalar_type () == common_scalar_type,
376
+ " Unhandled dtype %s for %s" ,
377
+ ::executorch::runtime::toString (common_scalar_type),
378
+ op_name);
379
+ return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
380
+ }
381
+ }
382
+ ET_CHECK (false );
383
+ return nullptr ;
384
+ }
385
+ } // namespace internal
386
+
313
387
/* *
314
388
* Useful for binary elementwise operators. For each element of the inputs,
315
389
* perform a computation and write to the corresponding element of the output.
@@ -356,33 +430,45 @@ inline void apply_binary_elementwise_fn(
356
430
*
357
431
* In order to mitigate build time cost (straightforwardly |CTYPE_A| *
358
432
* |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
359
- * are passed as CTYPE_COMMON. We require compute_fun to return
360
- * CTYPE_COMMON, and we require loading conversion functions from each
361
- * input type to CTYPE_COMMON and a storing conversion from
362
- * CTYPE_COMMON to CTYPE_OUT be provided. Each conversion function
363
- * must take a void* pointing to an element of the corresponding
364
- * tensor, load that element, and convert it to CTYPE_COMMON. The
365
- * storing conversion function must have the signature
366
- * void(CTYPE_COMMON, void*), convert the given element to CTYPE_OUT,
367
- * and store it to the given location.
433
+ * are passed as CTYPE_COMMON.
434
+ *
435
+ * Each tensor's supported dtypes set must be provided. The tensor
436
+ * will be checked to ensure that its dtype falls into that set.
437
+ *
438
+ * op_name is used to support dtype selective build, as with the
439
+ * ET_SWITCH family of macros. Note: because of C++17 quirks, you
440
+ * can't pass a string literal for op_name. Instead, you should do the
441
+ * following:
442
+ *
443
+ * static constexpr const char op_name[] = "my_op";
444
+ * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(...);
368
445
*/
369
- template <typename CTYPE_COMMON, typename Op>
446
+ template <typename CTYPE_COMMON, const char * op_name, typename Op>
370
447
inline void apply_ternary_elementwise_fn (
371
448
const Op& compute_fun,
372
449
const Tensor& a,
450
+ SupportedTensorDtypes a_dtypes,
373
451
const Tensor& b,
452
+ SupportedTensorDtypes b_dtypes,
374
453
const Tensor& c,
454
+ SupportedTensorDtypes c_dtypes,
375
455
const Tensor& out,
376
- CTYPE_COMMON (*load_a_to_common)(const void *),
377
- CTYPE_COMMON (*load_b_to_common)(const void *),
378
- CTYPE_COMMON (*load_c_to_common)(const void *),
379
- void (*store_common_to_out)(CTYPE_COMMON, void *)) {
456
+ SupportedTensorDtypes out_dtypes) {
380
457
const bool a_is_broadcasted = !out.sizes ().equals (a.sizes ());
381
458
const bool b_is_broadcasted = !out.sizes ().equals (b.sizes ());
382
459
const bool c_is_broadcasted = !out.sizes ().equals (c.sizes ());
383
460
const bool any_is_broadcasted =
384
461
(a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
385
462
463
+ const auto load_a_to_common =
464
+ internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
465
+ const auto load_b_to_common =
466
+ internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
467
+ const auto load_c_to_common =
468
+ internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
469
+ const auto store_common_to_out =
470
+ internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
471
+ out, out_dtypes);
386
472
const char * const data_a = reinterpret_cast <const char *>(a.const_data_ptr ());
387
473
const char * const data_b = reinterpret_cast <const char *>(b.const_data_ptr ());
388
474
const char * const data_c = reinterpret_cast <const char *>(c.const_data_ptr ());
0 commit comments