Commit 19bfbc3

Add complex dtype support to op_sum
Differential Revision: D73614059
Pull Request resolved: #10559
1 parent 087fe59 commit 19bfbc3

2 files changed, 128 insertions(+), 8 deletions(-)

kernels/portable/cpu/op_sum.cpp

Lines changed: 39 additions & 8 deletions
@@ -52,25 +52,56 @@ Tensor& sum_dim_out(
   }
 
   // @lint-ignore CLANGTIDY facebook-hte-CArray
   static constexpr const char op_name[] = "sum.IntList_out";
-  ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
-    ET_SWITCH_REALHBBF16_TYPES(out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
-      CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+
+  if (executorch::runtime::isComplexType(in.scalar_type())) {
+    ET_KERNEL_CHECK(
+        ctx, in.scalar_type() == out.scalar_type(), InvalidArgument, out);
+
+    ET_SWITCH_COMPLEXH_TYPES(in.scalar_type(), ctx, op_name, CTYPE, [&] {
+      CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
       const bool success = parallel_for_each_reduce_over_dim_list_output_index(
           in, dim_list, out, [&](const auto begin, const auto end) {
             for (const auto out_ix : c10::irange(begin, end)) {
-              CTYPE_OUT sum = 0;
+              CTYPE sum(0, 0);
               if (plan.has_value()) {
-                sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
-                    [](CTYPE_IN v) { return static_cast<CTYPE_OUT>(v); },
-                    [](CTYPE_OUT outv, CTYPE_OUT acc) { return acc + outv; },
+                sum = plan->execute<CTYPE, CTYPE>(
+                    [](CTYPE v) { return v; },
+                    [](CTYPE outv, CTYPE acc) { return acc + outv; },
                     out_ix);
               }
               out_data[out_ix] = sum;
             }
           });
       ET_KERNEL_CHECK_MSG(ctx, success, Internal, , "parallel_for failed");
     });
-  });
+  } else {
+    ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, op_name, CTYPE_IN, [&] {
+      ET_SWITCH_REALHBBF16_TYPES(
+          out.scalar_type(), ctx, op_name, CTYPE_OUT, [&] {
+            CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
+            const bool success =
+                parallel_for_each_reduce_over_dim_list_output_index(
+                    in, dim_list, out, [&](const auto begin, const auto end) {
+                      for (const auto out_ix : c10::irange(begin, end)) {
+                        CTYPE_OUT sum = 0;
+                        if (plan.has_value()) {
+                          sum = plan->execute<CTYPE_IN, CTYPE_OUT>(
+                              [](CTYPE_IN v) {
+                                return static_cast<CTYPE_OUT>(v);
+                              },
+                              [](CTYPE_OUT outv, CTYPE_OUT acc) {
+                                return acc + outv;
+                              },
+                              out_ix);
+                        }
+                        out_data[out_ix] = sum;
+                      }
+                    });
+            ET_KERNEL_CHECK_MSG(
+                ctx, success, Internal, , "parallel_for failed");
+          });
+    });
+  }
 
   return out;
 }
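
Note on the new branch: for complex inputs the kernel requires the input and output dtypes to match (the ET_KERNEL_CHECK above, unlike the real branch, which switches over both dtypes), and it folds values with acc + outv starting from CTYPE(0, 0), so the real and imaginary parts accumulate independently. Below is a minimal standalone sketch of that accumulation using std::complex; the reduce_sum helper is hypothetical, not part of the ExecuTorch API, and the reduce-plan/parallel_for machinery is omitted.

// Sketch of the accumulation the complex branch performs.
#include <complex>
#include <iostream>
#include <vector>

// Mirrors the kernel: start from sum(0, 0) and fold with acc + outv,
// so real and imaginary parts are summed independently.
std::complex<float> reduce_sum(const std::vector<std::complex<float>>& vals) {
  std::complex<float> sum(0, 0);
  for (const auto& v : vals) {
    sum += v;
  }
  return sum;
}

int main() {
  const std::vector<std::complex<float>> vals = {{1, 1}, {2, 2}, {3, 3}};
  const auto sum = reduce_sum(vals);
  std::cout << sum.real() << " + " << sum.imag() << "i\n";  // prints 6 + 6i
  return 0;
}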

kernels/test/op_sum_test.cpp

Lines changed: 89 additions & 0 deletions
@@ -111,6 +111,85 @@ class OpSumOutTest : public OperatorTest {
         self, optional_dim_list, /*keepdim=*/false, dtype, out));
   }
 
+  template <typename CTYPE, ScalarType DTYPE>
+  void test_complex_dtype() {
+    TensorFactory<DTYPE> tf;
+
+    Tensor self = tf.make(
+        {2, 3, 2},
+        {CTYPE(1, 1),
+         CTYPE(2, 2),
+         CTYPE(3, 3),
+         CTYPE(4, 4),
+         CTYPE(5, 5),
+         CTYPE(6, 6),
+
+         CTYPE(7, 7),
+         CTYPE(8, 8),
+         CTYPE(9, 9),
+         CTYPE(10, 10),
+         CTYPE(11, 11),
+         CTYPE(12, 12)});
+
+    Tensor out1 = tf.make(
+        {2, 3, 1},
+        {
+            CTYPE(0, 0),
+            CTYPE(0, 0),
+            CTYPE(0, 0),
+            CTYPE(0, 0),
+            CTYPE(0, 0),
+            CTYPE(0, 0),
+        });
+    int64_t dims_1[1] = {2};
+    optional<ArrayRef<int64_t>> dim_list1{ArrayRef<int64_t>{dims_1, 1}};
+    optional<ScalarType> dtype = DTYPE;
+
+    op_sum_intlist_out(self, dim_list1, true, dtype, out1);
+
+    Tensor expected1 = tf.make(
+        {2, 3, 1},
+        {CTYPE(3, 3),
+         CTYPE(7, 7),
+         CTYPE(11, 11),
+
+         CTYPE(15, 15),
+         CTYPE(19, 19),
+         CTYPE(23, 23)});
+
+    EXPECT_TENSOR_CLOSE(out1, expected1);
+
+    Tensor out2 = tf.make(
+        {2, 1, 2},
+        {
+            CTYPE(0, 0),
+            CTYPE(0, 0),
+            CTYPE(0, 0),
+            CTYPE(0, 0),
+        });
+    int64_t dims_2[1] = {1};
+    optional<ArrayRef<int64_t>> dim_list2{ArrayRef<int64_t>{dims_2, 1}};
+
+    op_sum_intlist_out(self, dim_list2, true, dtype, out2);
+
+    Tensor expected2 = tf.make(
+        {2, 1, 2}, {CTYPE(9, 9), CTYPE(12, 12), CTYPE(27, 27), CTYPE(30, 30)});
+    EXPECT_TENSOR_CLOSE(out2, expected2);
+
+    Tensor out3 = tf.make(
+        {1, 1, 1},
+        {
+            CTYPE(0, 0),
+        });
+    optional<ArrayRef<int64_t>> null_dim_list;
+
+    op_sum_intlist_out(self, null_dim_list, true, dtype, out3);
+
+    Tensor expected3 = tf.make({1, 1, 1}, {CTYPE(78, 78)});
+
+    EXPECT_TENSOR_CLOSE(out3, expected3);
+  }
+
   template <ScalarType IN_DTYPE, ScalarType OUT_DTYPE>
   void test_sum_dim_out_dtype() {
     TensorFactory<IN_DTYPE> tf_in;

@@ -366,6 +445,16 @@ TEST_F(OpSumOutTest, TypeConversionTest) {
   // clang-format on
 }
 
+TEST_F(OpSumOutTest, AllComplexDtypesSupported) {
+#define TEST_ENTRY(ctype, dtype) test_complex_dtype<ctype, ScalarType::dtype>();
+  if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
+    ET_FORALL_COMPLEX_TYPES(TEST_ENTRY);
+  } else {
+    ET_FORALL_COMPLEXH_TYPES(TEST_ENTRY);
+  }
+#undef TEST_ENTRY
+}
+
 TEST_F(OpSumOutTest, InfinityAndNANTest) {
   TensorFactory<ScalarType::Float> tf_float;
   // clang-format off