Skip to content

Commit 433ead0

Browse files
swolchok authored and facebook-github-bot committed
Just pass SupportedTensorDtypes for each tensor to apply_ternary_elementwise_fn (#5834)
Summary:
Pull Request resolved: #5834
No more function pointers at the call sites of apply_ternary_elementwise_fn! Also, we check that each Tensor's dtype is in the allowed set.
ghstack-source-id: 246132261
exported-using-ghexport
Reviewed By: manuelcandales
Differential Revision: D63794199
fbshipit-source-id: c811b0359f0465ad6fa0a7e742ecb932433a4122
1 parent b1fd74c commit 433ead0

File tree

3 files changed

+125
-38
lines changed

3 files changed

+125
-38
lines changed

kernels/portable/cpu/op_clamp.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ Tensor& clamp_tensor_out(
215215
static constexpr const char op_name[] = "clamp.Tensor_out";
216216

217217
ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
218-
apply_ternary_elementwise_fn<CTYPE_COMMON>(
218+
apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
219219
[has_min, has_max](
220220
const CTYPE_COMMON val_in,
221221
const CTYPE_COMMON val_min,
@@ -230,13 +230,13 @@ Tensor& clamp_tensor_out(
230230
return val_out;
231231
},
232232
in,
233+
SupportedTensorDtypes::REALHBBF16,
233234
min,
235+
SupportedTensorDtypes::REALHBBF16,
234236
max,
237+
SupportedTensorDtypes::REALHBBF16,
235238
out,
236-
get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(in),
237-
get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(min),
238-
get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(max),
239-
get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(out));
239+
SupportedTensorDtypes::REALHBBF16);
240240
});
241241

242242
return out;

kernels/portable/cpu/op_where.cpp

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,29 +38,25 @@ Tensor& where_out(
3838
ET_KERNEL_CHECK(
3939
ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out);
4040

41-
constexpr auto name = "where.self_out";
41+
static constexpr const char op_name[] = "where.self_out";
4242

4343
ET_CHECK_MSG(
4444
cond_type == ScalarType::Bool || cond_type == ScalarType::Byte,
4545
"Unhandled dtype %s for where.self_out",
4646
torch::executor::toString(cond_type));
47-
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
48-
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
49-
using CTYPE_OUT =
50-
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
51-
apply_ternary_elementwise_fn<CTYPE_OUT>(
52-
[](const CTYPE_OUT val_a,
53-
const CTYPE_OUT val_b,
54-
const CTYPE_OUT val_c) { return val_c ? val_a : val_b; },
55-
a,
56-
b,
57-
cond,
58-
out,
59-
internal::load_and_convert<CTYPE_OUT, CTYPE_A>,
60-
internal::load_and_convert<CTYPE_OUT, CTYPE_B>,
61-
internal::load_and_convert<CTYPE_OUT, uint8_t>,
62-
internal::convert_and_store<CTYPE_OUT, CTYPE_OUT>);
63-
});
47+
ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
48+
apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
49+
[](const CTYPE_COMMON val_a,
50+
const CTYPE_COMMON val_b,
51+
const CTYPE_COMMON val_c) { return val_c ? val_a : val_b; },
52+
a,
53+
SupportedTensorDtypes::REALHBBF16,
54+
b,
55+
SupportedTensorDtypes::REALHBBF16,
56+
cond,
57+
SupportedTensorDtypes::BOOL_OR_BYTE,
58+
out,
59+
SupportedTensorDtypes::SAME_AS_COMMON);
6460
});
6561

6662
return out;

kernels/portable/cpu/util/broadcast_util.h

Lines changed: 106 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -280,7 +280,6 @@ template <typename To, typename From>
280280
void convert_and_store(From f, void* dst) {
281281
*reinterpret_cast<To*>(dst) = static_cast<To>(f);
282282
}
283-
} // namespace internal
284283

285284
template <typename CTYPE_COMMON>
286285
using load_to_common_fn = CTYPE_COMMON (*)(const void*);
@@ -296,6 +295,17 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
296295
return result;
297296
}
298297

298+
template <typename CTYPE_COMMON, const char* op_name>
299+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
300+
const Tensor& t) {
301+
CTYPE_COMMON (*result)(const void*) = nullptr;
302+
ET_SWITCH_TWO_TYPES(
303+
Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
304+
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
305+
});
306+
return result;
307+
}
308+
299309
template <typename CTYPE_COMMON>
300310
using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
301311

@@ -310,6 +320,75 @@ get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
310320
return result;
311321
}
312322

323+
template <typename CTYPE_COMMON, const char* op_name>
324+
store_common_to_tensor_fn<CTYPE_COMMON>
325+
get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
326+
void (*result)(CTYPE_COMMON, void*) = nullptr;
327+
ET_SWITCH_TWO_TYPES(
328+
Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
329+
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
330+
});
331+
return result;
332+
}
333+
} // namespace internal
334+
335+
/**
 * Names the set of dtypes a tensor passed to apply_ternary_elementwise_fn
 * is allowed to have. The value selects which dtype dispatch is used when
 * looking up the tensor's load/store conversion function:
 *
 * - REALHBBF16: dispatched through the *_realhbbf16 helpers (presumably the
 *   real types plus Half, Bool and BFloat16 -- confirm against the
 *   ET_SWITCH_REALHBBF16_TYPES macro).
 * - BOOL_OR_BYTE: dispatched through the *_bool_or_byte helpers, which
 *   switch over Bool and Byte.
 * - SAME_AS_COMMON: the tensor's scalar type must equal the scalar type
 *   corresponding to CTYPE_COMMON.
 */
enum class SupportedTensorDtypes {
336+
REALHBBF16,
337+
BOOL_OR_BYTE,
338+
SAME_AS_COMMON,
339+
};
340+
341+
namespace internal {
342+
template <typename CTYPE_COMMON, const char* op_name>
343+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
344+
const Tensor& t,
345+
SupportedTensorDtypes dtypes) {
346+
switch (dtypes) {
347+
case SupportedTensorDtypes::REALHBBF16:
348+
return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
349+
case SupportedTensorDtypes::BOOL_OR_BYTE:
350+
return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
351+
case SupportedTensorDtypes::SAME_AS_COMMON: {
352+
constexpr auto common_scalar_type =
353+
CppTypeToScalarType<CTYPE_COMMON>::value;
354+
ET_CHECK_MSG(
355+
t.scalar_type() == common_scalar_type,
356+
"Unhandled dtype %s for %s",
357+
::executorch::runtime::toString(common_scalar_type),
358+
op_name);
359+
return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
360+
}
361+
}
362+
ET_CHECK(false);
363+
return nullptr;
364+
}
365+
366+
template <typename CTYPE_COMMON, const char* op_name>
367+
store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
368+
const Tensor& t,
369+
SupportedTensorDtypes dtypes) {
370+
switch (dtypes) {
371+
case SupportedTensorDtypes::REALHBBF16:
372+
return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
373+
case SupportedTensorDtypes::BOOL_OR_BYTE:
374+
return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
375+
t);
376+
case SupportedTensorDtypes::SAME_AS_COMMON: {
377+
constexpr auto common_scalar_type =
378+
CppTypeToScalarType<CTYPE_COMMON>::value;
379+
ET_CHECK_MSG(
380+
t.scalar_type() == common_scalar_type,
381+
"Unhandled dtype %s for %s",
382+
::executorch::runtime::toString(common_scalar_type),
383+
op_name);
384+
return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
385+
}
386+
}
387+
ET_CHECK(false);
388+
return nullptr;
389+
}
390+
} // namespace internal
391+
313392
/**
314393
* Useful for binary elementwise operators. For each element of the inputs,
315394
* perform a computation and write to the corresponding element of the output.
@@ -356,33 +435,45 @@ inline void apply_binary_elementwise_fn(
356435
*
357436
* In order to mitigate build time cost (straightforwardly |CTYPE_A| *
358437
* |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
359-
* are passed as CTYPE_COMMON. We require compute_fun to return
360-
* CTYPE_COMMON, and we require loading conversion functions from each
361-
* input type to CTYPE_COMMON and a storing conversion from
362-
* CTYPE_COMMON to CTYPE_OUT be provided. Each conversion function
363-
* must take a void* pointing to an element of the corresponding
364-
* tensor, load that element, and convert it to CTYPE_COMMON. The
365-
* storing conversion function must have the signature
366-
* void(CTYPE_COMMON, void*), convert the given element to CTYPE_OUT,
367-
* and store it to the given location.
438+
* are passed as CTYPE_COMMON.
439+
*
440+
* Each tensor's supported dtypes set must be provided. The tensor
441+
* will be checked to ensure that its dtype falls into that set.
442+
*
443+
* op_name is used to support dtype selective build, as with the
444+
* ET_SWITCH family of macros. Note: because of C++17 quirks, you
445+
* can't pass a string literal for op_name. Instead, you should do the
446+
* following:
447+
*
448+
* static constexpr const char op_name[] = "my_op";
449+
* apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
368450
*/
369-
template <typename CTYPE_COMMON, typename Op>
451+
template <typename CTYPE_COMMON, const char* op_name, typename Op>
370452
inline void apply_ternary_elementwise_fn(
371453
const Op& compute_fun,
372454
const Tensor& a,
455+
SupportedTensorDtypes a_dtypes,
373456
const Tensor& b,
457+
SupportedTensorDtypes b_dtypes,
374458
const Tensor& c,
459+
SupportedTensorDtypes c_dtypes,
375460
const Tensor& out,
376-
CTYPE_COMMON (*load_a_to_common)(const void*),
377-
CTYPE_COMMON (*load_b_to_common)(const void*),
378-
CTYPE_COMMON (*load_c_to_common)(const void*),
379-
void (*store_common_to_out)(CTYPE_COMMON, void*)) {
461+
SupportedTensorDtypes out_dtypes) {
380462
const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
381463
const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
382464
const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
383465
const bool any_is_broadcasted =
384466
(a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
385467

468+
const auto load_a_to_common =
469+
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
470+
const auto load_b_to_common =
471+
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
472+
const auto load_c_to_common =
473+
internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
474+
const auto store_common_to_out =
475+
internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
476+
out, out_dtypes);
386477
const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
387478
const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
388479
const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());

0 commit comments

Comments
 (0)