Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 2fef19b

Browse files
authored
[SYCL] Add bfloat16 'hello world' host test. (#1189)
* [SYCL] Add bfloat16 'hello world' host test. Signed-off-by: Konstantin S Bobrovsky <[email protected]>
1 parent 86fc200 commit 2fef19b

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

SYCL/BFloat16/bfloat_hw.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// RUN: %clangxx -fsycl %s -o %t.out
2+
// RUN: %t.out
3+
4+
// "Hello world" bfloat16 test which checks conversion algorithms on host.
5+
6+
#include <sycl/sycl.hpp>
7+
8+
#include <cstdint>
9+
#include <type_traits>
10+
11+
template <size_t Size>
12+
using get_uint_type_of_size = typename std::conditional_t<
13+
Size == 1, uint8_t,
14+
std::conditional_t<
15+
Size == 2, uint16_t,
16+
std::conditional_t<Size == 4, uint32_t,
17+
std::conditional_t<Size == 8, uint64_t, void>>>>;
18+
19+
using bfloat16 = sycl::ext::oneapi::experimental::bfloat16;
20+
using Bfloat16StorageT = get_uint_type_of_size<sizeof(bfloat16)>;
21+
22+
bool test(float Val, Bfloat16StorageT Bits) {
23+
std::cout << "Value: " << Val << " Bits: " << std::hex << "0x" << Bits
24+
<< std::dec << "...\n";
25+
bool Passed = true;
26+
{
27+
std::cout << " float -> bfloat16 conversion ...";
28+
Bfloat16StorageT RawVal = sycl::bit_cast<Bfloat16StorageT>(bfloat16(Val));
29+
bool Res = (RawVal == Bits);
30+
Passed &= Res;
31+
32+
if (Res) {
33+
std::cout << "passed\n";
34+
} else {
35+
std::cout << "failed. " << std::hex << "0x" << RawVal << " != "
36+
<< "0x" << Bits << "(gold)\n"
37+
<< std::dec;
38+
}
39+
}
40+
{
41+
std::cout << " bfloat16 -> float conversion ...";
42+
float NewVal = static_cast<float>(sycl::bit_cast<bfloat16>(Bits));
43+
bool Res = (NewVal == Val);
44+
Passed &= Res;
45+
46+
if (Res) {
47+
std::cout << "passed\n";
48+
} else {
49+
std::cout << "failed. " << Val << "(Gold) != " << NewVal << "\n";
50+
}
51+
}
52+
return Passed;
53+
}
54+
55+
int main() {
56+
bool passed = true;
57+
passed &= test(3.140625f, 0x4049);
58+
std::cout << (passed ? "Test Passed\n" : "Test FAILED\n");
59+
return (passed ? 0 : 1);
60+
}

0 commit comments

Comments
 (0)