Skip to content

Commit 9817e0a

Browse files
committed
Implement portable abs for complex input
Absolute value of a complex number is straightforward enough. Had to fix a couple other things because this is (I think) the first use of complex types in ExecuTorch. Differential Revision: [D69058051](https://our.internmc.facebook.com/intern/diff/D69058051/) ghstack-source-id: 264405527 Pull Request resolved: #8146
1 parent a5c7609 commit 9817e0a

File tree

3 files changed

+90
-19
lines changed

3 files changed

+90
-19
lines changed

kernels/portable/cpu/op_abs.cpp

Lines changed: 39 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,23 +27,48 @@ Tensor& abs_out(KernelRuntimeContext& ctx, const Tensor& in, Tensor& out) {
2727
out,
2828
"Failed to resize output tensor.");
2929

30-
ET_KERNEL_CHECK(ctx, tensors_have_same_dtype(in, out), InvalidArgument, out);
30+
const bool in_is_complex =
31+
executorch::runtime::isComplexType(in.scalar_type());
32+
ET_KERNEL_CHECK(
33+
ctx,
34+
in_is_complex || tensors_have_same_dtype(in, out),
35+
InvalidArgument,
36+
out);
3137
ET_KERNEL_CHECK(
3238
ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out);
3339

34-
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
35-
apply_unary_map_fn(
36-
[](const CTYPE val_in) {
37-
if (val_in < 0) {
38-
return static_cast<CTYPE>(-val_in);
39-
} else {
40-
return static_cast<CTYPE>(val_in);
41-
}
42-
},
43-
in.const_data_ptr<CTYPE>(),
44-
out.mutable_data_ptr<CTYPE>(),
45-
in.numel());
46-
});
40+
if (in_is_complex) {
41+
// NOTE: Elected not to add COMPLEXH to dtype_util.h for now
42+
// because I am not planning wide rollout of complex support; if
43+
// we do add SupportedTensorDtypes::COMPLEXH support, then we
44+
// should use it here.
45+
ET_SWITCH_COMPLEXH_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE_IN, [&] {
46+
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "abs.out", CTYPE_OUT, [&] {
47+
apply_unary_map_fn<CTYPE_IN, CTYPE_OUT>(
48+
[](const CTYPE_IN val_in) -> CTYPE_OUT {
49+
return sqrt(
50+
val_in.real_ * val_in.real_ + val_in.imag_ * val_in.imag_);
51+
},
52+
in.const_data_ptr<CTYPE_IN>(),
53+
out.mutable_data_ptr<CTYPE_OUT>(),
54+
in.numel());
55+
});
56+
});
57+
} else {
58+
ET_SWITCH_REALHBF16_TYPES(in.scalar_type(), ctx, "abs.out", CTYPE, [&] {
59+
apply_unary_map_fn(
60+
[](const CTYPE val_in) {
61+
if (val_in < 0) {
62+
return static_cast<CTYPE>(-val_in);
63+
} else {
64+
return static_cast<CTYPE>(val_in);
65+
}
66+
},
67+
in.const_data_ptr<CTYPE>(),
68+
out.mutable_data_ptr<CTYPE>(),
69+
in.numel());
70+
});
71+
}
4772

4873
return out;
4974
}

kernels/test/op_abs_test.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,39 @@ class OpAbsTest : public OperatorTest {
3838
EXPECT_TENSOR_EQ(out, ret);
3939
EXPECT_TENSOR_EQ(out, expected);
4040
}
41+
42+
template <typename CTYPE, ScalarType DTYPE>
43+
void run_complex_smoke_test() {
44+
TensorFactory<DTYPE> tf;
45+
constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);
46+
TensorFactory<REAL_DTYPE> tf_out;
47+
using REAL_CTYPE =
48+
typename executorch::runtime::ScalarTypeToCppType<REAL_DTYPE>::type;
49+
Tensor in = tf.make(
50+
{1, 2},
51+
{CTYPE{REAL_CTYPE(3), REAL_CTYPE(4)},
52+
CTYPE{REAL_CTYPE(5), REAL_CTYPE(12)}});
53+
Tensor out = tf_out.zeros({1, 2});
54+
Tensor expected = tf_out.make({1, 2}, {5, 13});
55+
Tensor ret = op_abs_out(in, out);
56+
EXPECT_TENSOR_EQ(out, ret);
57+
EXPECT_TENSOR_CLOSE(out, expected);
58+
}
4159
};
4260

4361
TEST_F(OpAbsTest, SmokeTest) {
4462
#define RUN_SMOKE_TEST(ctype, dtype) run_smoke_test<ScalarType::dtype>();
4563
// TODO: cover all REALHBF16 types with generalized unary function test
4664
// harness.
4765
ET_FORALL_FLOATHBF16_TYPES(RUN_SMOKE_TEST);
66+
#undef RUN_SMOKE_TEST
67+
}
68+
69+
TEST_F(OpAbsTest, ComplexSmokeTest) {
70+
#define RUN_SMOKE_TEST(ctype, dtype) \
71+
run_complex_smoke_test<ctype, ScalarType::dtype>();
72+
ET_FORALL_COMPLEXH_TYPES(RUN_SMOKE_TEST);
73+
#undef RUN_SMOKE_TEST
4874
}
4975

5076
TEST_F(OpAbsTest, MemoryFormatCheck) {

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -348,9 +348,14 @@ ET_FORALL_SCALAR_TYPES(SPECIALIZE_CppTypeToScalarType)
348348

349349
// In this context, "COMPLEX" means complex types based on primitive C types,
350350
// which is why ComplexHalf is not included.
351-
#define ET_FORALL_COMPLEX_TYPES(_) \
352-
_(::torch::executor::complex<float>, ComplexFloat) \
353-
_(::torch::executor::complex<double>, ComplexDouble)
351+
#define ET_FORALL_COMPLEX_TYPES(_) \
352+
_(::executorch::aten::complex<float>, ComplexFloat) \
353+
_(::executorch::aten::complex<double>, ComplexDouble)
354+
355+
#define ET_FORALL_COMPLEXH_TYPES(_) \
356+
_(::executorch::aten::complex<::executorch::aten::Half>, ComplexHalf) \
357+
_(::executorch::aten::complex<float>, ComplexFloat) \
358+
_(::executorch::aten::complex<double>, ComplexDouble)
354359

355360
//
356361
// Utility functions to retrieve metadata for a given ScalarType
@@ -593,7 +598,7 @@ inline bool isUnderlying(
593598
return type == ::executorch::runtime::toUnderlying(qtype);
594599
}
595600

596-
inline ::executorch::aten::ScalarType toRealValueType(
601+
inline constexpr ::executorch::aten::ScalarType toRealValueType(
597602
::executorch::aten::ScalarType t) {
598603
switch (t) {
599604
case ::executorch::aten::ScalarType::ComplexHalf:
@@ -607,7 +612,7 @@ inline ::executorch::aten::ScalarType toRealValueType(
607612
}
608613
}
609614

610-
inline ::executorch::aten::ScalarType toComplexType(
615+
inline constexpr ::executorch::aten::ScalarType toComplexType(
611616
::executorch::aten::ScalarType t) {
612617
switch (t) {
613618
case ::executorch::aten::ScalarType::BFloat16:
@@ -1060,6 +1065,14 @@ struct promote_types {
10601065
ET_INTERNAL_SWITCH_CASE( \
10611066
::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__)
10621067

1068+
#define ET_INTERNAL_SWITCH_CASE_COMPLEXH_TYPES(CTYPE_ALIAS, ...) \
1069+
ET_INTERNAL_SWITCH_CASE( \
1070+
::executorch::aten::ScalarType::ComplexHalf, CTYPE_ALIAS, __VA_ARGS__) \
1071+
ET_INTERNAL_SWITCH_CASE( \
1072+
::executorch::aten::ScalarType::ComplexFloat, CTYPE_ALIAS, __VA_ARGS__) \
1073+
ET_INTERNAL_SWITCH_CASE( \
1074+
::executorch::aten::ScalarType::ComplexDouble, CTYPE_ALIAS, __VA_ARGS__)
1075+
10631076
#define ET_INTERNAL_SWITCH_CASE_SCALAR_OBJ_TYPES(CTYPE_ALIAS, ...) \
10641077
ET_INTERNAL_SWITCH_CASE( \
10651078
::executorch::aten::ScalarType::Bool, CTYPE_ALIAS, __VA_ARGS__) \
@@ -1278,6 +1291,13 @@ struct promote_types {
12781291
NAME, \
12791292
ET_INTERNAL_SWITCH_CASE_COMPLEX_TYPES(CTYPE_ALIAS, __VA_ARGS__))
12801293

1294+
#define ET_SWITCH_COMPLEXH_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
1295+
ET_INTERNAL_SWITCH( \
1296+
TYPE, \
1297+
CONTEXT, \
1298+
NAME, \
1299+
ET_INTERNAL_SWITCH_CASE_COMPLEXH_TYPES(CTYPE_ALIAS, __VA_ARGS__))
1300+
12811301
#define ET_SWITCH_SCALAR_OBJ_TYPES(TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
12821302
ET_INTERNAL_SWITCH( \
12831303
TYPE, \

0 commit comments

Comments
 (0)