Skip to content

Commit f1b19ae

Browse files
committed
Improve bfloat16::to_float not to use std::memcpy, fix test.
1 parent 94f982a commit f1b19ae

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

sycl/include/sycl/ext/oneapi/experimental/bfloat16.hpp

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
#if !defined(__SYCL_DEVICE_ONLY__)
1515
#include <cmath>
16-
#include <cstring> // for std::memcpy
1716
#endif
1817

1918
namespace sycl {
@@ -64,11 +63,13 @@ class bfloat16 {
6463
return __spirv_ConvertBF16ToFINTEL(a);
6564
#endif
6665
#else
67-
uint32_t bits = a;
68-
bits <<= 16;
69-
float res;
70-
std::memcpy(&res, &bits, sizeof(res));
71-
return res;
66+
union {
67+
uint32_t bits;
68+
float res;
69+
} val;
70+
val.bits = a;
71+
val.bits <<= 16;
72+
return val.res;
7273
#endif
7374
}
7475

sycl/test/extensions/bfloat16_host.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,15 +73,17 @@ int main() {
7373
Success &= check_bf16_from_float(std::numeric_limits<float>::quiet_NaN(),
7474
std::stoi("1111111111000001", nullptr, 2));
7575

76+
// see https://float.exposed/b0xffff
7677
Success &= check_bf16_to_float(
7778
0, bitsToFloatConv(std::string("00000000000000000000000000000000")));
7879
Success &= check_bf16_to_float(
79-
1, bitsToFloatConv(std::string("01000111100000000000000000000000")));
80+
1, bitsToFloatConv(std::string("00000000000000010000000000000000")));
8081
Success &= check_bf16_to_float(
81-
42, bitsToFloatConv(std::string("01001010001010000000000000000000")));
82+
42, bitsToFloatConv(std::string("00000000001010100000000000000000")));
8283
Success &= check_bf16_to_float(
83-
std::numeric_limits<uint16_t>::max(),
84-
bitsToFloatConv(std::string("01001111011111111111111100000000")));
84+
// std::numeric_limits<uint16_t>::max() - 0xffff is bfloat16 -Nan and
85+
// -Nan == -Nan check in check_bf16_to_float would fail, so use not Nan:
86+
65407, bitsToFloatConv(std::string("11111111011111110000000000000000")));
8587
if (!Success)
8688
return -1;
8789
return 0;

0 commit comments

Comments
 (0)