Skip to content

Commit 9fafdb0

Browse files
Add Half support: full.out
Differential Revision: D61864622 Pull Request resolved: #4934
1 parent d91f612 commit 9fafdb0

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

kernels/portable/cpu/op_full.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,13 @@ Tensor& full_out(
3434
out,
3535
"Failed to resize output tensor.");
3636

37-
ET_SWITCH_REAL_TYPES_AND(Bool, val_type, ctx, "full.out", CTYPE_VAL, [&] {
37+
constexpr auto name = "full.out";
38+
39+
ET_SWITCH_SCALAR_OBJ_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
3840
CTYPE_VAL val;
3941
utils::extract_scalar(fill_value, &val);
4042

41-
ET_SWITCH_REAL_TYPES_AND(Bool, out_type, ctx, "full.out", CTYPE_OUT, [&] {
43+
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
4244
CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
4345
auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
4446
for (size_t i = 0; i < out.numel(); ++i) {

kernels/test/op_full_test.cpp

Lines changed: 65 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,25 @@ class OpFullOutTest : public OperatorTest {
3838
std::vector<int64_t> size_int64_t(size_int32_t.begin(), size_int32_t.end());
3939
auto aref = IntArrayRef(size_int64_t.data(), size_int64_t.size());
4040

41+
// Boolean Scalar
4142
// Before: `out` consists of 0s.
4243
Tensor out = tf.zeros(size_int32_t);
44+
// After: `out` consists of 1s.
45+
op_full_out(aref, true, out);
46+
EXPECT_TENSOR_EQ(out, tf.ones(size_int32_t));
4347

48+
// Integral Scalar
49+
// Before: `out` consists of 0s.
50+
out = tf.zeros(size_int32_t);
4451
// After: `out` consists of 1s.
4552
op_full_out(aref, 1, out);
53+
EXPECT_TENSOR_EQ(out, tf.ones(size_int32_t));
4654

55+
// Floating Point Scalar
56+
// Before: `out` consists of 0s.
57+
out = tf.zeros(size_int32_t);
58+
// After: `out` consists of 1s.
59+
op_full_out(aref, 1.0, out);
4760
EXPECT_TENSOR_EQ(out, tf.ones(size_int32_t));
4861
}
4962
};
@@ -57,4 +70,55 @@ class OpFullOutTest : public OperatorTest {
5770
test_ones_out<ScalarType::DTYPE>({2, 3, 4}); \
5871
}
5972

60-
ET_FORALL_REAL_TYPES(GENERATE_TEST)
73+
ET_FORALL_REALH_TYPES(GENERATE_TEST)
74+
75+
TEST_F(OpFullOutTest, ValueOverflow) {
76+
if (torch::executor::testing::SupportedFeatures::get()->is_aten) {
77+
GTEST_SKIP() << "ATen kernel doesn't handle overflow";
78+
}
79+
TensorFactory<ScalarType::Byte> tf;
80+
81+
std::vector<int64_t> sizes_int64_t_vec = {2, 3};
82+
std::vector<int32_t> sizes_in32_t_vec = {2, 3};
83+
auto sizes = IntArrayRef(sizes_int64_t_vec.data(), sizes_int64_t_vec.size());
84+
85+
Tensor out = tf.zeros(sizes_in32_t_vec);
86+
87+
op_full_out(sizes, 1000, out);
88+
}
89+
90+
TEST_F(OpFullOutTest, HalfSupport) {
91+
TensorFactory<ScalarType::Half> tf;
92+
93+
std::vector<int64_t> sizes_int64_t_vec = {2, 3};
94+
std::vector<int32_t> sizes_in32_t_vec = {2, 3};
95+
auto sizes = IntArrayRef(sizes_int64_t_vec.data(), sizes_int64_t_vec.size());
96+
97+
// Boolean Scalar
98+
Tensor out = tf.zeros(sizes_in32_t_vec);
99+
op_full_out(sizes, true, out);
100+
EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
101+
102+
// Integral Scalar
103+
out = tf.zeros(sizes_in32_t_vec);
104+
op_full_out(sizes, 1, out);
105+
EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
106+
107+
// Floating Point Scalar
108+
out = tf.zeros(sizes_in32_t_vec);
109+
op_full_out(sizes, 3.1415926535, out);
110+
EXPECT_TENSOR_EQ(out, tf.full(sizes_in32_t_vec, 3.1415926535));
111+
}
112+
113+
TEST_F(OpFullOutTest, ZeroDim) {
114+
TensorFactory<ScalarType::Half> tf;
115+
116+
std::vector<int64_t> sizes_int64_t_vec = {};
117+
std::vector<int32_t> sizes_in32_t_vec = {};
118+
auto sizes = IntArrayRef(sizes_int64_t_vec.data(), sizes_int64_t_vec.size());
119+
120+
// Boolean Scalar
121+
Tensor out = tf.zeros(sizes_in32_t_vec);
122+
op_full_out(sizes, true, out);
123+
EXPECT_TENSOR_EQ(out, tf.ones(sizes_in32_t_vec));
124+
}

0 commit comments

Comments
 (0)