[ExecuTorch] Just pass SupportedTensorDtypes for each tensor to apply_ternary_elementwise_fn #5834
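
This PR changes apply_ternary_elementwise_fn so that callers tag each tensor with a SupportedTensorDtypes value instead of supplying explicit load/store conversion function pointers; the helper now resolves those conversions internally, and gains an op_name template parameter for dtype selective build. op_clamp.cpp and op_where.cpp are migrated to the new interface. Condensed from the broadcast_util.h diff below, the signature goes from

    template <typename CTYPE_COMMON, typename Op>
    inline void apply_ternary_elementwise_fn(
        const Op& compute_fun,
        const Tensor& a, const Tensor& b, const Tensor& c, const Tensor& out,
        CTYPE_COMMON (*load_a_to_common)(const void*),
        CTYPE_COMMON (*load_b_to_common)(const void*),
        CTYPE_COMMON (*load_c_to_common)(const void*),
        void (*store_common_to_out)(CTYPE_COMMON, void*));

to

    template <typename CTYPE_COMMON, const char* op_name, typename Op>
    inline void apply_ternary_elementwise_fn(
        const Op& compute_fun,
        const Tensor& a, SupportedTensorDtypes a_dtypes,
        const Tensor& b, SupportedTensorDtypes b_dtypes,
        const Tensor& c, SupportedTensorDtypes c_dtypes,
        const Tensor& out, SupportedTensorDtypes out_dtypes);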

Closed · wants to merge 2 commits
10 changes: 5 additions & 5 deletions kernels/portable/cpu/op_clamp.cpp
@@ -215,7 +215,7 @@ Tensor& clamp_tensor_out(
   static constexpr const char op_name[] = "clamp.Tensor_out";
 
   ET_SWITCH_REALHB_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
-    apply_ternary_elementwise_fn<CTYPE_COMMON>(
+    apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
         [has_min, has_max](
             const CTYPE_COMMON val_in,
             const CTYPE_COMMON val_min,
@@ -230,13 +230,13 @@ Tensor& clamp_tensor_out(
           return val_out;
         },
         in,
+        SupportedTensorDtypes::REALHBBF16,
         min,
+        SupportedTensorDtypes::REALHBBF16,
         max,
+        SupportedTensorDtypes::REALHBBF16,
         out,
-        get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(in),
-        get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(min),
-        get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(max),
-        get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(out));
+        SupportedTensorDtypes::REALHBBF16);
   });
 
   return out;
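
Here the three explicit get_load_to_common_fn_realhbbf16 calls and the store helper collapse into a single SupportedTensorDtypes::REALHBBF16 tag per tensor; the same conversion functions are now resolved inside apply_ternary_elementwise_fn (see the broadcast_util.h changes below).
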
32 changes: 14 additions & 18 deletions kernels/portable/cpu/op_where.cpp
@@ -38,29 +38,25 @@ Tensor& where_out(
   ET_KERNEL_CHECK(
       ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out);
 
-  constexpr auto name = "where.self_out";
+  static constexpr const char op_name[] = "where.self_out";
 
   ET_CHECK_MSG(
       cond_type == ScalarType::Bool || cond_type == ScalarType::Byte,
       "Unhandled dtype %s for where.self_out",
       torch::executor::toString(cond_type));
-  ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
-    ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
-      using CTYPE_OUT =
-          typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
-      apply_ternary_elementwise_fn<CTYPE_OUT>(
-          [](const CTYPE_OUT val_a,
-             const CTYPE_OUT val_b,
-             const CTYPE_OUT val_c) { return val_c ? val_a : val_b; },
-          a,
-          b,
-          cond,
-          out,
-          internal::load_and_convert<CTYPE_OUT, CTYPE_A>,
-          internal::load_and_convert<CTYPE_OUT, CTYPE_B>,
-          internal::load_and_convert<CTYPE_OUT, uint8_t>,
-          internal::convert_and_store<CTYPE_OUT, CTYPE_OUT>);
-    });
+  ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
+    apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
+        [](const CTYPE_COMMON val_a,
+           const CTYPE_COMMON val_b,
+           const CTYPE_COMMON val_c) { return val_c ? val_a : val_b; },
+        a,
+        SupportedTensorDtypes::REALHBBF16,
+        b,
+        SupportedTensorDtypes::REALHBBF16,
+        cond,
+        SupportedTensorDtypes::BOOL_OR_BYTE,
+        out,
+        SupportedTensorDtypes::SAME_AS_COMMON);
   });
 
   return out;
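
Two things change here beyond the tagging. The nested a_type/b_type switches collapse into a single switch on common_type, so one lambda is instantiated per common type rather than per (CTYPE_A, CTYPE_B) pair. And the per-tensor tags encode the constraints directly: a and b accept the full REALHBBF16 set, cond is BOOL_OR_BYTE (matching the ET_CHECK_MSG above it), and out is SAME_AS_COMMON because its dtype must equal the common compute type.
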
121 changes: 106 additions & 15 deletions kernels/portable/cpu/util/broadcast_util.h
@@ -280,7 +280,6 @@ template <typename To, typename From>
 void convert_and_store(From f, void* dst) {
   *reinterpret_cast<To*>(dst) = static_cast<To>(f);
 }
-} // namespace internal
 
 template <typename CTYPE_COMMON>
 using load_to_common_fn = CTYPE_COMMON (*)(const void*);
@@ -296,6 +295,17 @@ load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
   return result;
 }
 
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
+    const Tensor& t) {
+  CTYPE_COMMON (*result)(const void*) = nullptr;
+  ET_SWITCH_TWO_TYPES(
+      Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
+      });
+  return result;
+}
+
 template <typename CTYPE_COMMON>
 using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
 
@@ -310,6 +320,75 @@ get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
   return result;
 }
 
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON>
+get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
+  void (*result)(CTYPE_COMMON, void*) = nullptr;
+  ET_SWITCH_TWO_TYPES(
+      Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
+        result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
+      });
+  return result;
+}
+} // namespace internal
+
+enum class SupportedTensorDtypes {
+  REALHBBF16,
+  BOOL_OR_BYTE,
+  SAME_AS_COMMON,
+};
+
+namespace internal {
+template <typename CTYPE_COMMON, const char* op_name>
+load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
+    const Tensor& t,
+    SupportedTensorDtypes dtypes) {
+  switch (dtypes) {
+    case SupportedTensorDtypes::REALHBBF16:
+      return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::SAME_AS_COMMON: {
+      constexpr auto common_scalar_type =
+          CppTypeToScalarType<CTYPE_COMMON>::value;
+      ET_CHECK_MSG(
+          t.scalar_type() == common_scalar_type,
+          "Unhandled dtype %s for %s",
+          ::executorch::runtime::toString(common_scalar_type),
+          op_name);
+      return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
+    }
+  }
+  ET_CHECK(false);
+  return nullptr;
+}
+
+template <typename CTYPE_COMMON, const char* op_name>
+store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
+    const Tensor& t,
+    SupportedTensorDtypes dtypes) {
+  switch (dtypes) {
+    case SupportedTensorDtypes::REALHBBF16:
+      return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
+    case SupportedTensorDtypes::BOOL_OR_BYTE:
+      return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
+          t);
+    case SupportedTensorDtypes::SAME_AS_COMMON: {
+      constexpr auto common_scalar_type =
+          CppTypeToScalarType<CTYPE_COMMON>::value;
+      ET_CHECK_MSG(
+          t.scalar_type() == common_scalar_type,
+          "Unhandled dtype %s for %s",
+          ::executorch::runtime::toString(common_scalar_type),
+          op_name);
+      return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
+    }
+  }
+  ET_CHECK(false);
+  return nullptr;
+}
+} // namespace internal
+
 /**
  * Useful for binary elementwise operators. For each element of the inputs,
  * perform a computation and write to the corresponding element of the output.
@@ -356,33 +435,45 @@ inline void apply_binary_elementwise_fn(
  *
  * In order to mitigate build time cost (straightforwardly |CTYPE_A| *
  * |CTYPE_B| * |CTYPE_C| * |CTYPE_OUT|), all arguments to compute_fun
- * are passed as CTYPE_COMMON. We require compute_fun to return
- * CTYPE_COMMON, and we require loading conversion functions from each
- * input type to CTYPE_COMMON and a storing conversion from
- * CTYPE_COMMON to CTYPE_OUT be provided. Each conversion function
- * must take a void* pointing to an element of the corresponding
- * tensor, load that element, and convert it to CTYPE_COMMON. The
- * storing conversion function must have the signature
- * void(CTYPE_COMMON, void*), convert the given element to CTYPE_OUT,
- * and store it to the given location.
+ * are passed as CTYPE_COMMON.
+ *
+ * Each tensor's supported dtypes set must be provided. The tensor
+ * will be checked to ensure that its dtype falls into that set.
+ *
+ * op_name is used to support dtype selective build, as with the
+ * ET_SWITCH family of macros. Note: because of C++17 quirks, you
+ * can't pass a string literal for op_name. Instead, you should do the
+ * following:
+ *
+ * static constexpr const char op_name[] = "my_op";
+ * apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>.
  */
-template <typename CTYPE_COMMON, typename Op>
+template <typename CTYPE_COMMON, const char* op_name, typename Op>
 inline void apply_ternary_elementwise_fn(
     const Op& compute_fun,
     const Tensor& a,
+    SupportedTensorDtypes a_dtypes,
     const Tensor& b,
+    SupportedTensorDtypes b_dtypes,
    const Tensor& c,
+    SupportedTensorDtypes c_dtypes,
     const Tensor& out,
-    CTYPE_COMMON (*load_a_to_common)(const void*),
-    CTYPE_COMMON (*load_b_to_common)(const void*),
-    CTYPE_COMMON (*load_c_to_common)(const void*),
-    void (*store_common_to_out)(CTYPE_COMMON, void*)) {
+    SupportedTensorDtypes out_dtypes) {
   const bool a_is_broadcasted = !out.sizes().equals(a.sizes());
   const bool b_is_broadcasted = !out.sizes().equals(b.sizes());
   const bool c_is_broadcasted = !out.sizes().equals(c.sizes());
   const bool any_is_broadcasted =
       (a_is_broadcasted || b_is_broadcasted || c_is_broadcasted);
 
+  const auto load_a_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(a, a_dtypes);
+  const auto load_b_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(b, b_dtypes);
+  const auto load_c_to_common =
+      internal::get_load_to_common_fn<CTYPE_COMMON, op_name>(c, c_dtypes);
+  const auto store_common_to_out =
+      internal::get_store_common_to_tensor_fn<CTYPE_COMMON, op_name>(
+          out, out_dtypes);
   const char* const data_a = reinterpret_cast<const char*>(a.const_data_ptr());
   const char* const data_b = reinterpret_cast<const char*>(b.const_data_ptr());
   const char* const data_c = reinterpret_cast<const char*>(c.const_data_ptr());
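
The SupportedTensorDtypes tags drive the dispatch above: REALHBBF16 routes to the existing _realhbbf16 helpers (the set handled by the ET_SWITCH_REALHBBF16 macros), BOOL_OR_BYTE switches over just Bool and Byte, and SAME_AS_COMMON requires the tensor's dtype to equal CTYPE_COMMON exactly, enforced by the ET_CHECK_MSG.

For reference, a minimal sketch of a caller under the new interface. The op ("my_op") and its multiply-add lambda are hypothetical, and the usual portable-kernel surroundings (ctx, common_type, and tensors a, b, c, out) are assumed to be in scope; the pattern follows the op_clamp.cpp and op_where.cpp changes above.

    // C++17 quirk (see the doc comment above): op_name must be a named
    // constant array, not a string literal passed directly.
    static constexpr const char op_name[] = "my_op";

    ET_SWITCH_REALHBBF16_TYPES(common_type, ctx, op_name, CTYPE_COMMON, [&]() {
      apply_ternary_elementwise_fn<CTYPE_COMMON, op_name>(
          // The computation sees only CTYPE_COMMON; per-tensor load/store
          // conversions are resolved internally from the dtype tags below.
          [](const CTYPE_COMMON val_a,
             const CTYPE_COMMON val_b,
             const CTYPE_COMMON val_c) { return val_a * val_b + val_c; },
          a,
          SupportedTensorDtypes::REALHBBF16,
          b,
          SupportedTensorDtypes::REALHBBF16,
          c,
          SupportedTensorDtypes::REALHBBF16,
          out,
          SupportedTensorDtypes::SAME_AS_COMMON);
    });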