Skip to content

Commit 9baa2df

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Enable Half for scalar_tensor, full_like, where.self_out (#2663)
Summary: Pull Request resolved: #2663 Reviewed By: digantdesai Differential Revision: D55334158 fbshipit-source-id: a858e576b2208e8a078757a29d2899963336ef61
1 parent cde514c commit 9baa2df

File tree

7 files changed

+97
-57
lines changed

7 files changed

+97
-57
lines changed

backends/xnnpack/test/models/llama2_et_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ class TestLlama2ETExample(unittest.TestCase):
1616
def test_f32(self):
1717
self._test()
1818

19-
@unittest.skip("T183420542: Add proper fp16 support.")
2019
def test_f16(self):
2120
self._test(torch.float16)
2221

kernels/portable/cpu/op_full_like.cpp

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,20 @@ Tensor& full_like_out(
4545
ScalarType val_type = utils::get_scalar_dtype(fill_value);
4646
ScalarType out_type = out.scalar_type();
4747

48-
ET_SWITCH_REAL_TYPES_AND(
49-
Bool, val_type, ctx, "full_like.out", CTYPE_VAL, [&] {
50-
CTYPE_VAL val;
51-
utils::extract_scalar(fill_value, &val);
48+
constexpr auto name = "scalar_tensor.out";
5249

53-
ET_SWITCH_REAL_TYPES_AND(
54-
Bool, out_type, ctx, "full_like.out", CTYPE_OUT, [&] {
55-
CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
56-
auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
57-
for (size_t i = 0; i < out.numel(); ++i) {
58-
data_out[i] = val_casted;
59-
}
60-
});
61-
});
50+
ET_SWITCH_REALB_TYPES(val_type, ctx, name, CTYPE_VAL, [&] {
51+
CTYPE_VAL val;
52+
utils::extract_scalar(fill_value, &val);
53+
54+
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&] {
55+
CTYPE_OUT val_casted = static_cast<CTYPE_OUT>(val);
56+
auto data_out = out.mutable_data_ptr<CTYPE_OUT>();
57+
for (size_t i = 0; i < out.numel(); ++i) {
58+
data_out[i] = val_casted;
59+
}
60+
});
61+
});
6262

6363
return out;
6464
}

kernels/portable/cpu/op_scalar_tensor.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ Tensor& scalar_tensor_out(RuntimeContext& ctx, const Scalar& s, Tensor& out) {
2222
ScalarType s_type = utils::get_scalar_dtype(s);
2323
ScalarType out_type = out.scalar_type();
2424

25-
ET_SWITCH_REAL_TYPES_AND(
26-
Bool, out_type, ctx, "scalar_tensor.out", CTYPE, [&]() {
27-
ET_SWITCH_SCALAR_OBJ_TYPES(
28-
s_type, ctx, "scalar_tensor.out", CTYPE_S, [&]() {
29-
CTYPE_S val_s;
30-
utils::extract_scalar(s, &val_s);
31-
out.mutable_data_ptr<CTYPE>()[0] = convert<CTYPE, CTYPE_S>(val_s);
32-
});
33-
});
25+
constexpr auto name = "scalar_tensor.out";
26+
27+
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE, [&]() {
28+
ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, name, CTYPE_S, [&]() {
29+
CTYPE_S val_s;
30+
utils::extract_scalar(s, &val_s);
31+
out.mutable_data_ptr<CTYPE>()[0] = convert<CTYPE, CTYPE_S>(val_s);
32+
});
33+
});
3434

3535
return out;
3636
}

kernels/portable/cpu/op_where.cpp

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -35,41 +35,28 @@ Tensor& where_out(
3535
InvalidArgument,
3636
out);
3737

38-
ET_SWITCH_TWO_TYPES(
39-
Bool, Byte, cond_type, ctx, "where.self_out", CTYPE_COND, [&]() {
40-
ET_SWITCH_REAL_TYPES_AND(
41-
Bool, a_type, ctx, "where.self_out", CTYPE_A, [&]() {
42-
ET_SWITCH_REAL_TYPES_AND(
43-
Bool, b_type, ctx, "where.self_out", CTYPE_B, [&]() {
44-
ET_SWITCH_REAL_TYPES_AND(
45-
Bool,
46-
out_type,
47-
ctx,
48-
"where.self_out",
49-
CTYPE_OUT,
50-
[&]() {
51-
apply_ternary_elementwise_fn<
52-
CTYPE_A,
53-
CTYPE_B,
54-
CTYPE_COND,
55-
CTYPE_OUT>(
56-
[](const CTYPE_A val_a,
57-
const CTYPE_B val_b,
58-
const CTYPE_COND val_c) {
59-
CTYPE_OUT a_casted =
60-
static_cast<CTYPE_OUT>(val_a);
61-
CTYPE_OUT b_casted =
62-
static_cast<CTYPE_OUT>(val_b);
63-
return val_c ? a_casted : b_casted;
64-
},
65-
a,
66-
b,
67-
cond,
68-
out);
69-
});
70-
});
71-
});
38+
constexpr auto name = "where.self_out";
39+
40+
ET_SWITCH_TWO_TYPES(Bool, Byte, cond_type, ctx, name, CTYPE_COND, [&]() {
41+
ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
42+
ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
43+
ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
44+
apply_ternary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_COND, CTYPE_OUT>(
45+
[](const CTYPE_A val_a,
46+
const CTYPE_B val_b,
47+
const CTYPE_COND val_c) {
48+
CTYPE_OUT a_casted = static_cast<CTYPE_OUT>(val_a);
49+
CTYPE_OUT b_casted = static_cast<CTYPE_OUT>(val_b);
50+
return val_c ? a_casted : b_casted;
51+
},
52+
a,
53+
b,
54+
cond,
55+
out);
56+
});
7257
});
58+
});
59+
});
7360

7461
return out;
7562
}

kernels/test/op_full_like_test.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,25 @@ TEST_F(OpFullLikeTest, DynamicShapeUnbound) {
187187
Tensor ret = op_full_like_out(x, Scalar(3.0), MemoryFormat::Contiguous, out);
188188
EXPECT_TENSOR_CLOSE(out, expected_result);
189189
}
190+
191+
TEST_F(OpFullLikeTest, HalfSupport) {
192+
TensorFactory<ScalarType::Half> tf;
193+
optional<MemoryFormat> memory_format;
194+
Tensor in = tf.ones({2, 3});
195+
Tensor out = tf.zeros({2, 3});
196+
197+
op_full_like_out(in, false, memory_format, out);
198+
EXPECT_TENSOR_CLOSE(out, tf.full({2, 3}, 0));
199+
200+
op_full_like_out(in, true, memory_format, out);
201+
EXPECT_TENSOR_CLOSE(out, tf.full({2, 3}, 1));
202+
203+
op_full_like_out(in, 7, memory_format, out);
204+
EXPECT_TENSOR_CLOSE(out, tf.full({2, 3}, 7));
205+
206+
op_full_like_out(in, 2.5, memory_format, out);
207+
EXPECT_TENSOR_CLOSE(out, tf.full({2, 3}, 2.5));
208+
209+
op_full_like_out(in, INFINITY, memory_format, out);
210+
EXPECT_TENSOR_CLOSE(out, tf.full({2, 3}, INFINITY));
211+
}

kernels/test/op_scalar_tensor_test.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,23 @@ TEST_F(OpScalarTensorOutTest, InvalidOutShapeFails) {
111111
Tensor out = tf.ones(sizes);
112112
ET_EXPECT_KERNEL_FAILURE(context_, op_scalar_tensor_out(7, out));
113113
}
114+
115+
TEST_F(OpScalarTensorOutTest, HalfSupport) {
116+
TensorFactory<ScalarType::Half> tf;
117+
Tensor out = tf.zeros({});
118+
119+
op_scalar_tensor_out(false, out);
120+
EXPECT_TENSOR_CLOSE(out, tf.make({}, {0}));
121+
122+
op_scalar_tensor_out(true, out);
123+
EXPECT_TENSOR_CLOSE(out, tf.make({}, {1}));
124+
125+
op_scalar_tensor_out(7, out);
126+
EXPECT_TENSOR_CLOSE(out, tf.make({}, {7}));
127+
128+
op_scalar_tensor_out(2.5, out);
129+
EXPECT_TENSOR_CLOSE(out, tf.make({}, {2.5}));
130+
131+
op_scalar_tensor_out(INFINITY, out);
132+
EXPECT_TENSOR_CLOSE(out, tf.make({}, {INFINITY}));
133+
}

kernels/test/op_where_test.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,15 @@ TEST_F(OpWhereOutTest, DynamicShapeUnbound) {
449449
test_dynamic_shape(
450450
{1, 1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
451451
}
452+
453+
TEST_F(OpWhereOutTest, HalfSupport) {
454+
TensorFactory<ScalarType::Bool> tb;
455+
TensorFactory<ScalarType::Half> tf;
456+
Tensor cond = tb.make({2, 3}, {true, false, true, false, true, false});
457+
Tensor a = tf.full({2, 3}, 1.5);
458+
Tensor b = tf.full({2, 3}, 2.5);
459+
Tensor out = tf.zeros({2, 3});
460+
461+
op_where_self_out(cond, a, b, out);
462+
EXPECT_TENSOR_CLOSE(out, tf.make({2, 3}, {1.5, 2.5, 1.5, 2.5, 1.5, 2.5}));
463+
}

0 commit comments

Comments
 (0)