Skip to content

Commit 6897210

Browse files
authored
Support Half/Bfloat for rand() and fill(). (#8123)
Summary: . Differential Revision: D68984778
1 parent 92e7dbd commit 6897210

File tree

2 files changed

+52
-5
lines changed

2 files changed

+52
-5
lines changed

extension/tensor/tensor_ptr_maker.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,11 @@ bool extract_scalar(executorch::aten::Scalar scalar, INT_T* out_val) {
3434

3535
template <
3636
typename FLOAT_T,
37-
typename std::enable_if<std::is_floating_point<FLOAT_T>::value, bool>::
38-
type = true>
37+
typename std::enable_if<
38+
std::is_floating_point_v<FLOAT_T> ||
39+
std::is_same_v<FLOAT_T, executorch::aten::BFloat16> ||
40+
std::is_same_v<FLOAT_T, executorch::aten::Half>,
41+
bool>::type = true>
3942
bool extract_scalar(executorch::aten::Scalar scalar, FLOAT_T* out_val) {
4043
double val;
4144
if (scalar.isFloatingPoint()) {
@@ -59,7 +62,7 @@ template <
5962
typename std::enable_if<std::is_same<BOOL_T, bool>::value, bool>::type =
6063
true>
6164
bool extract_scalar(executorch::aten::Scalar scalar, BOOL_T* out_val) {
62-
if (scalar.isIntegral(false)) {
65+
if (scalar.isIntegral(/*includeBool=*/false)) {
6366
*out_val = static_cast<bool>(scalar.to<int64_t>());
6467
return true;
6568
}
@@ -86,7 +89,7 @@ TensorPtr random_strided(
8689
empty_strided(std::move(sizes), std::move(strides), type, dynamism);
8790
std::default_random_engine gen{std::random_device{}()};
8891

89-
ET_SWITCH_REALB_TYPES(type, nullptr, "random_strided", CTYPE, [&] {
92+
ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "random_strided", CTYPE, [&] {
9093
std::generate_n(tensor->mutable_data_ptr<CTYPE>(), tensor->numel(), [&]() {
9194
return static_cast<CTYPE>(distribution(gen));
9295
});
@@ -121,7 +124,7 @@ TensorPtr full_strided(
121124
executorch::aten::TensorShapeDynamism dynamism) {
122125
auto tensor =
123126
empty_strided(std::move(sizes), std::move(strides), type, dynamism);
124-
ET_SWITCH_REALB_TYPES(type, nullptr, "full_strided", CTYPE, [&] {
127+
ET_SWITCH_REALHBBF16_TYPES(type, nullptr, "full_strided", CTYPE, [&] {
125128
CTYPE value;
126129
ET_EXTRACT_SCALAR(fill_value, value);
127130
std::fill(

extension/tensor/test/tensor_ptr_maker_test.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,20 @@ TEST_F(TensorPtrMakerTest, CreateFull) {
234234
EXPECT_EQ(tensor4->size(1), 5);
235235
EXPECT_EQ(tensor4->scalar_type(), executorch::aten::ScalarType::Double);
236236
EXPECT_EQ(tensor4->const_data_ptr<double>()[0], 11);
237+
238+
auto tensor5 = full({4, 5}, 13, executorch::aten::ScalarType::Half);
239+
EXPECT_EQ(tensor5->dim(), 2);
240+
EXPECT_EQ(tensor5->size(0), 4);
241+
EXPECT_EQ(tensor5->size(1), 5);
242+
EXPECT_EQ(tensor5->scalar_type(), executorch::aten::ScalarType::Half);
243+
EXPECT_EQ(tensor5->const_data_ptr<executorch::aten::Half>()[0], 13);
244+
245+
auto tensor6 = full({4, 5}, 15, executorch::aten::ScalarType::BFloat16);
246+
EXPECT_EQ(tensor6->dim(), 2);
247+
EXPECT_EQ(tensor6->size(0), 4);
248+
EXPECT_EQ(tensor6->size(1), 5);
249+
EXPECT_EQ(tensor6->scalar_type(), executorch::aten::ScalarType::BFloat16);
250+
EXPECT_EQ(tensor6->const_data_ptr<executorch::aten::BFloat16>()[0], 15);
237251
}
238252

239253
TEST_F(TensorPtrMakerTest, CreateScalar) {
@@ -363,6 +377,36 @@ TEST_F(TensorPtrMakerTest, CreateRandTensorWithDoubleType) {
363377
}
364378
}
365379

380+
TEST_F(TensorPtrMakerTest, CreateRandTensorWithHalfType) {
381+
auto tensor = rand({4, 5}, executorch::aten::ScalarType::Half);
382+
383+
EXPECT_EQ(tensor->dim(), 2);
384+
EXPECT_EQ(tensor->size(0), 4);
385+
EXPECT_EQ(tensor->size(1), 5);
386+
EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::Half);
387+
388+
for (auto i = 0; i < tensor->numel(); ++i) {
389+
auto val = tensor->const_data_ptr<executorch::aten::Half>()[i];
390+
EXPECT_GE(val, 0.0);
391+
EXPECT_LT(val, 1.0);
392+
}
393+
}
394+
395+
TEST_F(TensorPtrMakerTest, CreateRandTensorWithBFloatType) {
396+
auto tensor = rand({4, 5}, executorch::aten::ScalarType::BFloat16);
397+
398+
EXPECT_EQ(tensor->dim(), 2);
399+
EXPECT_EQ(tensor->size(0), 4);
400+
EXPECT_EQ(tensor->size(1), 5);
401+
EXPECT_EQ(tensor->scalar_type(), executorch::aten::ScalarType::BFloat16);
402+
403+
for (auto i = 0; i < tensor->numel(); ++i) {
404+
auto val = tensor->const_data_ptr<executorch::aten::BFloat16>()[i];
405+
EXPECT_GE(val, 0.0);
406+
EXPECT_LT(val, 1.0);
407+
}
408+
}
409+
366410
TEST_F(TensorPtrMakerTest, CreateRandnTensor) {
367411
auto tensor = randn({100, 100});
368412

0 commit comments

Comments
 (0)