|
| 1 | +// RUN: %clangxx -fsycl %s -o %t.out |
| 2 | +// RUN: %t.out |
| 3 | + |
| 4 | +// "Hello world" bfloat16 test which checks conversion algorithms on host. |
| 5 | + |
| 6 | +#include <sycl/sycl.hpp> |
| 7 | + |
| 8 | +#include <cstdint> |
| 9 | +#include <type_traits> |
| 10 | + |
| 11 | +template <size_t Size> |
| 12 | +using get_uint_type_of_size = typename std::conditional_t< |
| 13 | + Size == 1, uint8_t, |
| 14 | + std::conditional_t< |
| 15 | + Size == 2, uint16_t, |
| 16 | + std::conditional_t<Size == 4, uint32_t, |
| 17 | + std::conditional_t<Size == 8, uint64_t, void>>>>; |
| 18 | + |
| 19 | +using bfloat16 = sycl::ext::oneapi::experimental::bfloat16; |
| 20 | +using Bfloat16StorageT = get_uint_type_of_size<sizeof(bfloat16)>; |
| 21 | + |
| 22 | +bool test(float Val, Bfloat16StorageT Bits) { |
| 23 | + std::cout << "Value: " << Val << " Bits: " << std::hex << "0x" << Bits |
| 24 | + << std::dec << "...\n"; |
| 25 | + bool Passed = true; |
| 26 | + { |
| 27 | + std::cout << " float -> bfloat16 conversion ..."; |
| 28 | + Bfloat16StorageT RawVal = sycl::bit_cast<Bfloat16StorageT>(bfloat16(Val)); |
| 29 | + bool Res = (RawVal == Bits); |
| 30 | + Passed &= Res; |
| 31 | + |
| 32 | + if (Res) { |
| 33 | + std::cout << "passed\n"; |
| 34 | + } else { |
| 35 | + std::cout << "failed. " << std::hex << "0x" << RawVal << " != " |
| 36 | + << "0x" << Bits << "(gold)\n" |
| 37 | + << std::dec; |
| 38 | + } |
| 39 | + } |
| 40 | + { |
| 41 | + std::cout << " bfloat16 -> float conversion ..."; |
| 42 | + float NewVal = static_cast<float>(sycl::bit_cast<bfloat16>(Bits)); |
| 43 | + bool Res = (NewVal == Val); |
| 44 | + Passed &= Res; |
| 45 | + |
| 46 | + if (Res) { |
| 47 | + std::cout << "passed\n"; |
| 48 | + } else { |
| 49 | + std::cout << "failed. " << Val << "(Gold) != " << NewVal << "\n"; |
| 50 | + } |
| 51 | + } |
| 52 | + return Passed; |
| 53 | +} |
| 54 | + |
| 55 | +int main() { |
| 56 | + bool passed = true; |
| 57 | + passed &= test(3.140625f, 0x4049); |
| 58 | + std::cout << (passed ? "Test Passed\n" : "Test FAILED\n"); |
| 59 | + return (passed ? 0 : 1); |
| 60 | +} |
0 commit comments