Skip to content

Commit 98588cd

Browse files
committed
Fix test failure and add assert
1 parent 1aa0304 commit 98588cd

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

sycl/include/sycl/vector.hpp

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -789,8 +789,22 @@ template <typename Type, int NumElements> class vec {
789789
using bfloat16 = sycl::ext::oneapi::bfloat16;
790790
static_assert(std::is_integral_v<vec_data_t<convertT>> ||
791791
detail::is_floating_point<convertT>::value ||
792-
std::is_same_v<convertT, bfloat16>,
792+
// Conversion to BF16 available only for float.
793+
(std::is_same_v<convertT, bfloat16> &&
794+
std::is_same_v<DataT, float>),
793795
"Unsupported convertT");
796+
797+
// Currently, for BF16 <--> float conversion, we only support
798+
// Round-to-even rounding mode.
799+
constexpr bool isFloatToBF16Conv = std::is_same_v<convertT, bfloat16> &&
800+
std::is_same_v<DataT, float>;
801+
constexpr bool isBF16ToFloatConv = std::is_same_v<DataT, bfloat16> &&
802+
std::is_same_v<convertT, float>;
803+
if constexpr (isFloatToBF16Conv || isBF16ToFloatConv) {
804+
static_assert(roundingMode == rounding_mode::automatic ||
805+
roundingMode == rounding_mode::rte);
806+
}
807+
794808
using T = vec_data_t<DataT>;
795809
using R = vec_data_t<convertT>;
796810
using OpenCLT = detail::ConvertToOpenCLType_t<T>;
@@ -831,18 +845,18 @@ template <typename Type, int NumElements> class vec {
831845
// Otherwise, we fallback to per-element conversion:
832846
for (size_t I = 0; I < NumElements; ++I) {
833847
// For float -> bf16.
834-
if constexpr (std::is_same_v<convertT, bfloat16>) {
848+
if constexpr (isFloatToBF16Conv) {
835849
Result[I] = bfloat16((*this)[I]);
836850
} else
837851
// For bf16 -> float.
838-
if constexpr (std::is_same_v<DataT, bfloat16>) {
852+
if constexpr (isBF16ToFloatConv) {
839853
Result[I] = (float)((*this)[I]);
840854
}
841855
else {
842856
Result.setValue(
843857
I, vec_data<convertT>::get(
844858
detail::convertImpl<T, R, roundingMode, 1, OpenCLT, OpenCLR>(
845-
vec_data<T>::get(getValue(I)))));
859+
vec_data<DataT>::get(getValue(I)))));
846860
}
847861
}
848862
}

sycl/include/sycl/vector_preview.hpp

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,10 +419,24 @@ class vec : public detail::vec_arith<DataT, NumElements> {
419419

420420
using T = ConvertBoolAndByteT<DataT>;
421421
using R = ConvertBoolAndByteT<convertT>;
422+
using bfloat16 = sycl::ext::oneapi::bfloat16;
422423
static_assert(std::is_integral_v<R> || detail::is_floating_point<R>::value ||
423-
std::is_same_v<R, sycl::ext::oneapi::bfloat16>,
424+
std::is_same_v<R, bfloat16>,
424425
"Unsupported convertT");
425426

427+
{
428+
// Currently, for BF16 <--> float conversion, we only support
429+
// Round-to-even rounding mode.
430+
constexpr bool isFloatToBF16Conv = std::is_same_v<convertT, bfloat16> &&
431+
std::is_same_v<DataT, float>;
432+
constexpr bool isBF16ToFloatConv = std::is_same_v<DataT, bfloat16> &&
433+
std::is_same_v<convertT, float>;
434+
if constexpr (isFloatToBF16Conv || isBF16ToFloatConv) {
435+
static_assert(roundingMode == rounding_mode::automatic ||
436+
roundingMode == rounding_mode::rte);
437+
}
438+
}
439+
426440
using OpenCLT = detail::ConvertToOpenCLType_t<T>;
427441
using OpenCLR = detail::ConvertToOpenCLType_t<R>;
428442
vec<convertT, NumElements> Result;
@@ -482,7 +496,7 @@ class vec : public detail::vec_arith<DataT, NumElements> {
482496
getValue(I));
483497
#ifdef __SYCL_DEVICE_ONLY__
484498
// On device, we interpret BF16 as uint16.
485-
if constexpr (std::is_same_v<convertT, sycl::ext::oneapi::bfloat16>)
499+
if constexpr (std::is_same_v<convertT, bfloat16>)
486500
Result[I] = sycl::bit_cast<convertT>(val);
487501
else
488502
#endif

0 commit comments

Comments
 (0)