@@ -789,8 +789,22 @@ template <typename Type, int NumElements> class vec {
789
789
using bfloat16 = sycl::ext::oneapi::bfloat16;
790
790
static_assert (std::is_integral_v<vec_data_t <convertT>> ||
791
791
detail::is_floating_point<convertT>::value ||
792
- std::is_same_v<convertT, bfloat16>,
792
+ // Conversion to BF16 available only for float.
793
+ (std::is_same_v<convertT, bfloat16> &&
794
+ std::is_same_v<DataT, float >),
793
795
" Unsupported convertT" );
796
+
797
+ // Currently, for BF16 <--> float conversion, we only support
798
+ // Round-to-even rounding mode.
799
+ constexpr bool isFloatToBF16Conv = std::is_same_v<convertT, bfloat16> &&
800
+ std::is_same_v<DataT, float >;
801
+ constexpr bool isBF16ToFloatConv = std::is_same_v<DataT, bfloat16> &&
802
+ std::is_same_v<convertT, float >;
803
+ if constexpr (isFloatToBF16Conv || isBF16ToFloatConv) {
804
+ static_assert (roundingMode == rounding_mode::automatic ||
805
+ roundingMode == rounding_mode::rte);
806
+ }
807
+
794
808
using T = vec_data_t <DataT>;
795
809
using R = vec_data_t <convertT>;
796
810
using OpenCLT = detail::ConvertToOpenCLType_t<T>;
@@ -831,18 +845,18 @@ template <typename Type, int NumElements> class vec {
831
845
// Otherwise, we fallback to per-element conversion:
832
846
for (size_t I = 0 ; I < NumElements; ++I) {
833
847
// For float -> bf16.
834
- if constexpr (std::is_same_v<convertT, bfloat16> ) {
848
+ if constexpr (isFloatToBF16Conv ) {
835
849
Result[I] = bfloat16 ((*this )[I]);
836
850
} else
837
851
// For bf16 -> float.
838
- if constexpr (std::is_same_v<DataT, bfloat16> ) {
852
+ if constexpr (isBF16ToFloatConv ) {
839
853
Result[I] = (float )((*this )[I]);
840
854
}
841
855
else {
842
856
Result.setValue (
843
857
I, vec_data<convertT>::get (
844
858
detail::convertImpl<T, R, roundingMode, 1 , OpenCLT, OpenCLR>(
845
- vec_data<T >::get (getValue (I)))));
859
+ vec_data<DataT >::get (getValue (I)))));
846
860
}
847
861
}
848
862
}
0 commit comments