Commit 0a9ef70

[ExecuTorch] Just pass SupportedTensorDtypes for each tensor to apply_ternary_elementwise_fn

No more function pointers! Also, we check that each Tensor's type is in the allowed set.

Differential Revision: D63794199 (https://our.internmc.facebook.com/intern/diff/D63794199/)
ghstack-source-id: 246019385
Pull Request resolved: #5834

Parent: ebdf3d1
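The shape of the change, reduced to a minimal self-contained C++ sketch (Dtype, Supported, load_fn, and pick_loader below are illustrative stand-ins, not the ExecuTorch API, and the real dtype sets are much wider): callers now describe each tensor's allowed dtypes with an enum, and the helper resolves the per-element conversion function pointer internally while checking that the tensor's runtime dtype is in the allowed set.

#include <cassert>
#include <cstdio>

// Illustrative stand-ins for ScalarType and SupportedTensorDtypes.
enum class Dtype { Float, Bool };
enum class Supported { REALHBBF16, BOOL_OR_BYTE };

// Loader signature: read one element through void* and widen it to the
// common compute type (float here, CTYPE_COMMON in the real code).
using load_fn = float (*)(const void*);

template <typename T>
float load_and_convert(const void* p) {
  return static_cast<float>(*static_cast<const T*>(p));
}

// Before: callers passed a load_fn per tensor. After: they pass an enum,
// and the helper both picks the pointer and validates the dtype.
load_fn pick_loader(Dtype t, Supported allowed) {
  switch (allowed) {
    case Supported::REALHBBF16:
      assert(t == Dtype::Float);  // sketch; the real set has many members
      return load_and_convert<float>;
    case Supported::BOOL_OR_BYTE:
      assert(t == Dtype::Bool);
      return load_and_convert<bool>;
  }
  return nullptr;
}

int main() {
  bool cond = true;
  load_fn load = pick_loader(Dtype::Bool, Supported::BOOL_OR_BYTE);
  std::printf("%f\n", load(&cond));  // prints 1.000000
}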

File tree: 3 files changed, +116, -34 lines

- kernels/portable/cpu/op_clamp.cpp
- kernels/portable/cpu/op_where.cpp
- kernels/portable/cpu/util/broadcast_util.h

kernels/portable/cpu/op_clamp.cpp

Lines changed: 5 additions & 5 deletions
@@ -215,7 +215,7 @@ Tensor& clamp_tensor_out(
   static constexpr const char op_name[] = "clamp.Tensor_out";
 
   ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
-    apply_ternary_elementwise_fn<CTYPE_COMMON>(
+    apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
         [has_min, has_max](
             const CTYPE_COMMON val_in,
             const CTYPE_COMMON val_min,
@@ -230,13 +230,13 @@ Tensor& clamp_tensor_out(
           return val_out;
         },
         in,
+        SupportedTensorDtypes::REALHBBF16,
         min,
+        SupportedTensorDtypes::REALHBBF16,
         max,
+        SupportedTensorDtypes::REALHBBF16,
         out,
-        get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(in),
-        get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(min),
-        get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(max),
-        get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(out));
+        SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;

kernels/portable/cpu/op_where.cpp

Lines changed: 10 additions & 14 deletions
@@ -38,29 +38,25 @@ Tensor& where_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out);
 
-  constexpr auto name = "where.self_out";
+  static constexpr const char op_name[] = "where.self_out";
 
   ET_CHECK_MSG(
       cond_type == ScalarType::Bool || cond_type == ScalarType::Byte,
       "Unhandled dtype %s for where.self_out",
       torch::executor::toString(cond_type));
-  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
-    ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
-      using CTYPE_OUT =
-          typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-      apply_ternary_elementwise_fn<CTYPE_OUT>(
-          [](const CTYPE_OUT val_a,
-             const CTYPE_OUT val_b,
-             const CTYPE_OUT val_c) { return val_c ? val_a : val_b; },
+  ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
+    apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
+        [](const CTYPE_COMMON val_a,
+           const CTYPE_COMMON val_b,
+           const CTYPE_COMMON val_c) { return val_c ? val_a : val_b; },
         a,
+        SupportedTensorDtypes::REALHBBF16,
         b,
+        SupportedTensorDtypes::REALHBBF16,
         cond,
+        SupportedTensorDtypes::BOOL_OR_BYTE,
         out,
-          internal::load_and_convert<CTYPE_OUT, CTYPE_A>,
-          internal::load_and_convert<CTYPE_OUT, CTYPE_B>,
-          internal::load_and_convert<CTYPE_OUT, uint8_t>,
-          internal::convert_and_store<CTYPE_OUT, CTYPE_OUT>);
-    });
+        SupportedTensorDtypes::SAME_AS_COMMON);
   });
 
   return out;
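One effect worth noting in where_out: the old code nested two ET_SWITCH macros, so the kernel lambda was stamped out once per (a_type, b_type) pair, while the new code switches once on the promoted common type; this is the build-time cost the helper's doc comment mitigates. A schematic of the instantiation counts (hypothetical pair_kernel/common_kernel, with the dtype set trimmed to three):

#include <cstdio>

template <typename A, typename B>
void pair_kernel() {}    // old shape: one instantiation per (a_type, b_type)

template <typename C>
void common_kernel() {}  // new shape: one instantiation per common type

int main() {
  // Old nested switches over {float, int, bool}: 3 x 3 = 9 instantiations.
  pair_kernel<float, float>(); pair_kernel<float, int>(); pair_kernel<float, bool>();
  pair_kernel<int, float>();   pair_kernel<int, int>();   pair_kernel<int, bool>();
  pair_kernel<bool, float>();  pair_kernel<bool, int>();  pair_kernel<bool, bool>();
  // New single switch: 3 instantiations; per-tensor conversions go through
  // runtime function pointers instead of extra template parameters. The real
  // saving compounds across the REALHBBF16 set, which has far more members.
  common_kernel<float>(); common_kernel<int>(); common_kernel<bool>();
  std::puts("9 pair instantiations vs. 3 common instantiations");
}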

kernels/portable/cpu/util/broadcast_util.h

Lines changed: 101 additions & 15 deletions
@@ -280,7 +280,6 @@ template <typename To, typename From>
 void convert_and_store(From f, void* dst) {
   *reinterpret_cast<To*>(dst) = static_cast<To>(f);
 }
-} // namespace internal
 
 template <typename CTYPE_COMMON>
 using load_to_common_fn = CTYPE_COMMON (*)(const void*);
@@ -296,6 +295,15 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
   return result;
 }
 
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_TWO_TYPES(Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+    result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+  });
+  return result;
+}
+
 template <typename CTYPE_COMMON>
 using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
 
@@ -310,6 +318,72 @@ get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
   return result;
 }
 
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_TWO_TYPES(Bool, Byte,
+      t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+    result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+  });
+  return result;
+}
+} // namespace internal
+
+enum class SupportedTensorDtypes {
+  REALHBBF16,
+  BOOL_OR_BYTE,
+  SAME_AS_COMMON,
+};
+
+namespace internal {
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
+    const Tensor& t,
+    SupportedTensorDtypes dtypes) {
+  switch (dtypes) {
+    case SupportedTensorDtypes::REALHBBF16:
+      return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::SAME_AS_COMMON: {
+      constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
+      ET_CHECK_MSG(
+          t.scalar_type() == common_scalar_type,
+          "Unhandled dtype %s for %s",
+          ::executorch::runtime::toString(common_scalar_type),
+          op_name);
+      return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
+    }
+  }
+  ET_CHECK(false);
+  return nullptr;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
+    const Tensor& t,
+    SupportedTensorDtypes dtypes) {
+  switch (dtypes) {
+    case SupportedTensorDtypes::REALHBBF16:
+      return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::SAME_AS_COMMON: {
+      constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
+      ET_CHECK_MSG(
+          t.scalar_type() == common_scalar_type,
+          "Unhandled dtype %s for %s",
+          ::executorch::runtime::toString(common_scalar_type),
+          op_name);
+      return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
+    }
+  }
+  ET_CHECK(false);
+  return nullptr;
+}
+} // namespace internal
+
 /**
  * Useful for binary elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
@@ -356,33 +430,45 @@ inline void apply_binary_elementwise_fn(
  *
  * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
  * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
- * are passed as CTYPE_COMMON. We require compute_fun to return
- * CTYPE_COMMON, and we require loading conversion functions from each
- * input type to CTYPE_COMMON and a storing conversion from
- * CTYPE_COMMON to CTYPE_OUT be provided. Each conversion function
- * must take a void* pointing to an element of the corresponding
- * tensor, load that element, and convert it to CTYPE_COMMON. The
- * storing conversion function must have the signature
- * void(CTYPE_COMMON, void*), convert the given element to CTYPE_OUT,
- * and store it to the given location.
+ * are passed as CTYPE_COMMON.
+ *
+ * Each tensor's supported dtypes set must be provided. The tensor
+ * will be checked to ensure that its dtype falls into that set.
+ *
+ * op_name is used to support dtype selective build, as with the
+ * ET_SWITCH family of macros. Note: because of C++17 quirks, you
+ * can't pass a string literal for op_name. Instead, you should do the
+ * following:
+ *
+ *  static constexpr const char op_name[] = "my_op";
+ *  apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
  */
-template <typename CTYPE_COMMON, typename Op>
+template <typename CTYPE_COMMON, const char* op_name, typename Op>
 inline void apply_ternary_elementwise_fn(
     const Op& compute_fun,
     const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
     const Tensor& b,
+    SupportedTensorDtypes b_dtypes,
     const Tensor& c,
+    SupportedTensorDtypes c_dtypes,
     const Tensor& out,
-    CTYPE_COMMON (*load_a_to_common)(const void*),
-    CTYPE_COMMON (*load_b_to_common)(const void*),
-    CTYPE_COMMON (*load_c_to_common)(const void*),
-    void (*store_common_to_out)(CTYPE_COMMON, void*)) {
+    SupportedTensorDtypes out_dtypes) {
   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
   const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
   const bool any_is_broadcasted =
       (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
 
+  const auto load_a_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
+  const auto load_b_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
+  const auto load_c_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
+  const auto store_common_to_out =
+      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
+          out, out_dtypes);
   const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
   const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
   const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
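The C++17 caveat in the new doc comment, as a compilable sketch: C++ (through C++17) does not accept a string literal itself as a template argument, so op_name must be a named constant array with static storage duration, exactly as op_clamp.cpp and op_where.cpp declare it before calling the helper.

template <const char* op_name>
const char* name() {
  return op_name;
}

static constexpr const char kOpName[] = "my_op";  // named array: valid argument

int main() {
  // name<"my_op">();  // ill-formed: string literals are not valid
  //                   // non-type template arguments
  return name<kOpName>()[0] == 'm' ? 0 : 1;  // OK in C++17
}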
