Skip to content

Reapply #9841: Migrate elementwise_util callers to the variants with out_dtypes in template arguments #10491

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
May 9, 2025
20 changes: 12 additions & 8 deletions kernels/portable/cpu/op_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,17 +52,19 @@ Tensor& add_out(

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[val_alpha](const auto val_a, const auto val_b) {
return val_a + val_alpha * val_b;
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
out);
});

return out;
Expand Down Expand Up @@ -100,17 +102,19 @@ Tensor& add_scalar_out(
static constexpr const char op_name[] = "add.Scalar_out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[b, alpha](const CTYPE_COMPUTE val_a) {
utils::apply_unitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[b, alpha](const auto val_a) {
CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
CTYPE_COMPUTE val_alpha = utils::scalar_to<CTYPE_COMPUTE>(alpha);
return val_a + val_alpha * val_b;
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
out);
});

return out;
Expand Down
10 changes: 6 additions & 4 deletions kernels/portable/cpu/op_addmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,19 @@ Tensor& addmm_out(
n,
p);

utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
[alpha_val, beta_val](const CTYPE val_a, const CTYPE val_b) {
utils::apply_bitensor_elementwise_fn<
CTYPE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
[alpha_val, beta_val](const auto val_a, const auto val_b) {
return val_a * alpha_val + val_b * beta_val;
},
ctx,
out,
utils::SupportedTensorDtypes::REALHBF16,
in,
utils::SupportedTensorDtypes::REALHBF16,
out,
utils::SupportedTensorDtypes::REALHBF16);
out);
}
});

Expand Down
10 changes: 6 additions & 4 deletions kernels/portable/cpu/op_atan2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,19 @@ Tensor& atan2_out(
static constexpr const char op_name[] = "atan2.out";

ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::FLOATHBF16>(
[](const auto val_a, const auto val_b) {
return std::atan2(val_a, val_b);
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::FLOATHBF16);
out);
});

return out;
Expand Down
18 changes: 12 additions & 6 deletions kernels/portable/cpu/op_clamp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,12 @@ Tensor& clamp_out(
static constexpr const char op_name[] = "clamp.out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_unitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[has_min, min_opt, has_max, max_opt](const CTYPE_COMPUTE val_in) {
// TODO: rewrite this to be vectorization-capable.
CTYPE_COMPUTE val_out = val_in;
if (has_min) {
val_out = utils::max_override(
Expand All @@ -150,8 +154,7 @@ Tensor& clamp_out(
ctx,
in,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
out);
});

return out;
Expand Down Expand Up @@ -210,11 +213,15 @@ Tensor& clamp_tensor_out(
static constexpr const char op_name[] = "clamp.Tensor_out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_tritensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_tritensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[has_min, has_max](
const CTYPE_COMPUTE val_in,
const CTYPE_COMPUTE val_min,
const CTYPE_COMPUTE val_max) {
// TODO: rewrite this to be vectorization-capable.
CTYPE_COMPUTE val_out = val_in;
if (has_min) {
val_out = utils::max_override(val_out, val_min);
Expand All @@ -231,8 +238,7 @@ Tensor& clamp_tensor_out(
utils::SupportedTensorDtypes::REALHBBF16,
max,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
out);
});

return out;
Expand Down
16 changes: 10 additions & 6 deletions kernels/portable/cpu/op_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,17 @@ Tensor& copy_out(
std::memcpy(out.mutable_data_ptr(), src.const_data_ptr(), src.nbytes());
} else {
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy.out", CTYPE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
ctx,
in,
utils::SupportedTensorDtypes::REALHBBF16,
src,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
out);
});
}

Expand Down Expand Up @@ -93,15 +95,17 @@ Tensor& copy_(
std::memcpy(in.mutable_data_ptr(), src.const_data_ptr(), in.nbytes());
} else {
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, "copy_", CTYPE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[](ET_UNUSED const CTYPE _, const CTYPE val_src) { return val_src; },
ctx,
in,
utils::SupportedTensorDtypes::REALHBBF16,
src,
utils::SupportedTensorDtypes::REALHBBF16,
in,
utils::SupportedTensorDtypes::REALHBBF16);
in);
});
}

Expand Down
31 changes: 18 additions & 13 deletions kernels/portable/cpu/op_div.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,17 @@ Tensor& div_out(
static constexpr const char op_name[] = "div.out";

ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return val_a / val_b;
},
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::FLOATHBF16>(
[](const auto val_a, const auto val_b) { return val_a / val_b; },
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::FLOATHBF16);
out);
});

return out;
Expand Down Expand Up @@ -122,9 +122,13 @@ Tensor& div_out_mode(
bool div_by_zero_error = false;

ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
[mode_is_trunc, &div_by_zero_error](
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
// TODO: rewrite this to be vectorization-capable.
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
if (val_b == 0) {
div_by_zero_error = true;
Expand All @@ -146,8 +150,7 @@ Tensor& div_out_mode(
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBF16);
out);
});

ET_KERNEL_CHECK_MSG(
Expand Down Expand Up @@ -188,13 +191,15 @@ Tensor& div_scalar_out(

ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[val_b](const CTYPE_COMPUTE val_a) { return val_a / val_b; },
utils::apply_unitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[val_b](const auto val_a) { return val_a / val_b; },
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
out);
});

return out;
Expand Down
11 changes: 7 additions & 4 deletions kernels/portable/cpu/op_elu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,20 @@ Tensor& elu_out(
ET_EXTRACT_SCALAR(scale, math_scale);
ET_EXTRACT_SCALAR(input_scale, math_input_scale);
const auto negcoef = math_alpha * math_scale;
utils::apply_unitensor_elementwise_fn<CTYPE, op_name>(
[negcoef, math_scale, math_input_scale](auto x) {
utils::apply_unitensor_elementwise_fn<
CTYPE,
op_name,
utils::SupportedTensorDtypes::SAME_AS_COMMON>(
[negcoef, math_scale, math_input_scale](const auto x) {
// TODO: rewrite this to be vectorization-capable.
return MathT(x) <= MathT(0)
? std::expm1(MathT(x) * math_input_scale) * negcoef
: MathT(x) * math_scale;
},
ctx,
in,
utils::SupportedTensorDtypes::FLOATHBF16,
out,
utils::SupportedTensorDtypes::SAME_AS_COMMON);
out);
});
return out;
}
Expand Down
9 changes: 6 additions & 3 deletions kernels/portable/cpu/op_floor_divide.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,13 @@ Tensor& floor_divide_out(
bool div_by_zero_error = false;

ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
[&div_by_zero_error](
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
// TODO: rewrite this to be vectorization-capable.
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
if (val_b == 0) {
div_by_zero_error = true;
Expand All @@ -69,8 +73,7 @@ Tensor& floor_divide_out(
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBF16);
out);
});

ET_KERNEL_CHECK_MSG(
Expand Down
18 changes: 12 additions & 6 deletions kernels/portable/cpu/op_fmod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,13 @@ Tensor& fmod_Tensor_out(
bool div_by_zero_error = false;

ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
[&div_by_zero_error](
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
// TODO: rewrite this to be vectorization-capable.
CTYPE_COMPUTE value = 0;
if (is_integral_type<CTYPE_COMPUTE, /*includeBool=*/true>::value) {
if (val_b == 0) {
Expand All @@ -73,8 +77,7 @@ Tensor& fmod_Tensor_out(
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBF16);
out);
});

ET_KERNEL_CHECK_MSG(
Expand Down Expand Up @@ -131,16 +134,19 @@ Tensor& fmod_Scalar_out(

ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_b = utils::scalar_to<CTYPE_COMPUTE>(b);
utils::apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_unitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBF16>(
[val_b](const CTYPE_COMPUTE val_a) {
// TODO: rewrite this to be vectorization-capable.
CTYPE_COMPUTE value = std::fmod(val_a, val_b);
return value;
},
ctx,
a,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBF16);
out);
});

return out;
Expand Down
8 changes: 5 additions & 3 deletions kernels/portable/cpu/op_maximum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ Tensor& maximum_out(
static constexpr const char op_name[] = "maximum.out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
utils::apply_bitensor_elementwise_fn<
CTYPE_COMPUTE,
op_name,
utils::SupportedTensorDtypes::REALHBBF16>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return utils::max_override(val_a, val_b);
},
Expand All @@ -54,8 +57,7 @@ Tensor& maximum_out(
utils::SupportedTensorDtypes::REALHBBF16,
b,
utils::SupportedTensorDtypes::REALHBBF16,
out,
utils::SupportedTensorDtypes::REALHBBF16);
out);
});

return out;
Expand Down
Loading
Loading