Skip to content

Commit 0225671

Browse files
manuelcandaleskirklandsign
authored andcommitted
Complex Support: diagonal_copy
Differential Revision: D72189747 Pull Request resolved: #9777
1 parent c568c38 commit 0225671

File tree

2 files changed

+40
-1
lines changed

2 files changed

+40
-1
lines changed

kernels/portable/cpu/op_diagonal_copy.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ Tensor& diagonal_copy_out(
100100

101101
constexpr auto name = "diagonal_copy.out";
102102

103-
ET_SWITCH_REALHBBF16_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
103+
ET_SWITCH_ALL_TYPES(in.scalar_type(), ctx, name, CTYPE, [&] {
104104
diagonal_copy_impl<CTYPE>(in, offset, dim1, dim2, out);
105105
});
106106

kernels/test/op_diagonal_copy_test.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,38 @@ class OpDiagonalCopyOutTest : public ::testing::Test {
5050
op_diagonal_copy_out(input, 1, 1, 0, out);
5151
EXPECT_TENSOR_CLOSE(out, out_expected);
5252
}
53+
54+
template <typename CTYPE, ScalarType DTYPE>
55+
void run_2d_complex_dtype() {
56+
TensorFactory<DTYPE> tf;
57+
constexpr auto REAL_DTYPE = executorch::runtime::toRealValueType(DTYPE);
58+
using REAL_CTYPE =
59+
typename executorch::runtime::ScalarTypeToCppType<REAL_DTYPE>::type;
60+
Tensor input = tf.make(
61+
{3, 4},
62+
{CTYPE{REAL_CTYPE(1), REAL_CTYPE(1)},
63+
CTYPE{REAL_CTYPE(2), REAL_CTYPE(2)},
64+
CTYPE{REAL_CTYPE(3), REAL_CTYPE(3)},
65+
CTYPE{REAL_CTYPE(4), REAL_CTYPE(4)},
66+
CTYPE{REAL_CTYPE(5), REAL_CTYPE(5)},
67+
CTYPE{REAL_CTYPE(6), REAL_CTYPE(6)},
68+
CTYPE{REAL_CTYPE(7), REAL_CTYPE(7)},
69+
CTYPE{REAL_CTYPE(8), REAL_CTYPE(8)},
70+
CTYPE{REAL_CTYPE(9), REAL_CTYPE(9)},
71+
CTYPE{REAL_CTYPE(10), REAL_CTYPE(10)},
72+
CTYPE{REAL_CTYPE(11), REAL_CTYPE(11)},
73+
CTYPE{REAL_CTYPE(12), REAL_CTYPE(12)}});
74+
Tensor out = tf.make(
75+
{2},
76+
{CTYPE{REAL_CTYPE(0), REAL_CTYPE(0)},
77+
CTYPE{REAL_CTYPE(0), REAL_CTYPE(0)}});
78+
Tensor out_expected = tf.make(
79+
{2},
80+
{CTYPE{REAL_CTYPE(5), REAL_CTYPE(5)},
81+
CTYPE{REAL_CTYPE(10), REAL_CTYPE(10)}});
82+
op_diagonal_copy_out(input, 1, 1, 0, out);
83+
EXPECT_TENSOR_CLOSE(out, out_expected);
84+
}
5385
};
5486

5587
TEST_F(OpDiagonalCopyOutTest, SmokeTest2D) {
@@ -58,6 +90,13 @@ TEST_F(OpDiagonalCopyOutTest, SmokeTest2D) {
5890
#undef TEST_ENTRY
5991
}
6092

93+
TEST_F(OpDiagonalCopyOutTest, ComplexSmokeTest2D) {
94+
#define TEST_ENTRY(ctype, dtype) \
95+
run_2d_complex_dtype<ctype, ScalarType::dtype>();
96+
ET_FORALL_COMPLEXH_TYPES(TEST_ENTRY);
97+
#undef TEST_ENTRY
98+
}
99+
61100
TEST_F(OpDiagonalCopyOutTest, SmokeTest3D) {
62101
TensorFactory<ScalarType::Float> tfFloat;
63102

0 commit comments

Comments
 (0)