Skip to content

Commit c9ac212

Browse files
authored
[ExecuTorch] support BF16 in op_to_copy
Differential Revision: D61981356
Pull Request resolved: #4976
1 parent b8a1899 commit c9ac212

File tree

6 files changed

+334
-15
lines changed

6 files changed

+334
-15
lines changed

kernels/portable/cpu/op_to_copy.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,11 @@ Tensor& to_copy_out(
4646
InvalidArgument,
4747
out);
4848

49-
ET_SWITCH_REALHB_TYPES(self.scalar_type(), ctx, "to_copy", CTYPE_IN, [&] {
50-
ET_SWITCH_REALHB_TYPES(out.scalar_type(), ctx, "to_copy", CTYPE_OUT, [&] {
51-
_to_impl<CTYPE_IN, CTYPE_OUT>(self, out);
52-
});
49+
ET_SWITCH_REALHBBF16_TYPES(self.scalar_type(), ctx, "to_copy", CTYPE_IN, [&] {
50+
ET_SWITCH_REALHBBF16_TYPES(
51+
out.scalar_type(), ctx, "to_copy", CTYPE_OUT, [&] {
52+
_to_impl<CTYPE_IN, CTYPE_OUT>(self, out);
53+
});
5354
});
5455

5556
return out;

kernels/test/op_to_copy_test.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ typedef std::map<
3636
std::type_index,
3737
std::variant<
3838
std::vector<float>,
39-
std::vector<double>>>
39+
std::vector<double>,
40+
std::vector<exec_aten::Half>,
41+
std::vector<exec_aten::BFloat16>>>
4042
FloatingTypeToDataMap;
4143

4244
typedef std::map<
@@ -309,9 +311,9 @@ TEST_F(OpToTest, AllDtypesSupported) {
309311
ScalarType::OUTPUT_DTYPE>(test_cases);
310312

311313
#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
312-
ET_FORALL_REAL_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
314+
ET_FORALL_REALHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
313315

314-
ET_FORALL_REAL_TYPES(TEST_ENTRY);
316+
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
315317

316318
#undef TEST_ENTRY
317319
#undef TEST_KERNEL
@@ -323,14 +325,14 @@ TEST_F(OpToTest, BoolTests) {
323325
#define TEST_TO_BOOL(INPUT_CTYPE, INPUT_DTYPE) \
324326
test_runner_to_bool<INPUT_CTYPE, ScalarType::INPUT_DTYPE>( \
325327
test_case_to_bool, result_to_bool);
326-
ET_FORALL_REAL_TYPES(TEST_TO_BOOL);
328+
ET_FORALL_REALHBF16_TYPES(TEST_TO_BOOL);
327329

328330
std::vector<uint8_t> test_case_from_bool = {true, true, false};
329331
std::vector<double> result_from_bool = {1.0, 1.0, 0};
330332
#define TEST_FROM_BOOL(OUTPUT_CTYPE, OUTPUT_DTYPE) \
331333
test_runner_from_bool<OUTPUT_CTYPE, ScalarType::OUTPUT_DTYPE>( \
332334
test_case_from_bool, result_from_bool);
333-
ET_FORALL_REAL_TYPES(TEST_FROM_BOOL);
335+
ET_FORALL_REALHBF16_TYPES(TEST_FROM_BOOL);
334336
}
335337

336338
TEST_F(OpToTest, NanInfSupported) {
@@ -349,9 +351,9 @@ TEST_F(OpToTest, NanInfSupported) {
349351
ScalarType::OUTPUT_DTYPE>(test_cases);
350352

351353
#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
352-
ET_FORALL_FLOAT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
354+
ET_FORALL_FLOATHBF16_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
353355

354-
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
356+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
355357

356358
#undef TEST_ENTRY
357359
#undef TEST_KERNEL
@@ -381,6 +383,13 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) {
381383
-0.30919688936285893988};
382384
// clang-format on
383385

386+
std::vector<exec_aten::Half> half_data;
387+
std::vector<exec_aten::BFloat16> bf16_data;
388+
for (auto d : double_data) {
389+
half_data.emplace_back(d);
390+
bf16_data.emplace_back(d);
391+
}
392+
384393
std::vector<int64_t> int64_data = {
385394
-1, -4, 2, -2, 3, 3, -3, -4, 3, 3, 0, 2, 0, -1, 0};
386395
std::vector<int32_t> int32_data = {
@@ -394,6 +403,8 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) {
394403
FloatingTypeToDataMap floating_point_data;
395404
floating_point_data[typeid(float)] = float_data;
396405
floating_point_data[typeid(double)] = double_data;
406+
floating_point_data[typeid(exec_aten::Half)] = half_data;
407+
floating_point_data[typeid(exec_aten::BFloat16)] = bf16_data;
397408

398409
// Gathering all int data together for better traversal
399410
IntTypeToDataMap int_data;
@@ -412,7 +423,7 @@ TEST_F(OpToTest, HardcodeFloatConvertInt) {
412423
#define TEST_ENTRY(INPUT_CTYPE, INPUT_DTYPE) \
413424
ET_FORALL_INT_TYPES_WITH2(INPUT_CTYPE, INPUT_DTYPE, TEST_KERNEL);
414425

415-
ET_FORALL_FLOAT_TYPES(TEST_ENTRY);
426+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
416427
}
417428

418429
TEST_F(OpToTest, MismatchedSizesDie) {

runtime/core/exec_aten/exec_aten.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <c10/core/MemoryFormat.h> // @manual
1818
#include <c10/core/Scalar.h> // @manual
1919
#include <c10/util/ArrayRef.h> // @manual
20+
#include <c10/util/BFloat16-math.h> // @manual
2021
#include <c10/util/BFloat16.h> // @manual
2122
#include <c10/util/Half.h> // @manual
2223
#include <c10/util/Optional.h> // @manual
@@ -31,6 +32,7 @@
3132
#else // use executor
3233
#include <executorch/runtime/core/array_ref.h> // @manual
3334
#include <executorch/runtime/core/portable_type/bfloat16.h> // @manual
35+
#include <executorch/runtime/core/portable_type/bfloat16_math.h> // @manual
3436
#include <executorch/runtime/core/portable_type/complex.h> // @manual
3537
#include <executorch/runtime/core/portable_type/device.h> // @manual
3638
#include <executorch/runtime/core/portable_type/half.h> // @manual

runtime/core/exec_aten/testing_util/tensor_util.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
1717
#include <executorch/runtime/core/exec_aten/util/tensor_util.h>
1818

19+
using exec_aten::BFloat16;
20+
using exec_aten::Half;
1921
using exec_aten::ScalarType;
2022
using exec_aten::Tensor;
2123

@@ -32,9 +34,7 @@ namespace {
3234
* T must be a floating point type. Non-floating point data should be compared
3335
* directly.
3436
*/
35-
template <
36-
typename T,
37-
typename = std::enable_if_t<std::is_floating_point<T>::value>>
37+
template <typename T>
3838
bool data_is_close(
3939
const T* a,
4040
const T* b,
@@ -119,6 +119,20 @@ bool tensors_are_close(
119119
a.numel(),
120120
rtol,
121121
atol);
122+
} else if (a.scalar_type() == ScalarType::Half) {
123+
return data_is_close<Half>(
124+
a.const_data_ptr<Half>(),
125+
b.const_data_ptr<Half>(),
126+
a.numel(),
127+
rtol,
128+
atol);
129+
} else if (a.scalar_type() == ScalarType::BFloat16) {
130+
return data_is_close<BFloat16>(
131+
a.const_data_ptr<BFloat16>(),
132+
b.const_data_ptr<BFloat16>(),
133+
a.numel(),
134+
rtol,
135+
atol);
122136
} else {
123137
// Non-floating-point types can be compared bitwise.
124138
return memcmp(a.const_data_ptr(), b.const_data_ptr(), a.nbytes()) == 0;

0 commit comments

Comments
 (0)