@@ -732,7 +732,8 @@ template <typename Type, int NumElements> class vec {
732
732
if constexpr (!IsUsingArrayOnDevice) {
733
733
return m_Data;
734
734
} else {
735
- return sycl::bit_cast<vector_t >(m_Data);
735
+ auto ptr = bit_cast<const vector_t *>((&m_Data)->data ());
736
+ return *ptr;
736
737
}
737
738
}
738
739
#endif // __SYCL_DEVICE_ONLY__
@@ -788,77 +789,64 @@ template <typename Type, int NumElements> class vec {
788
789
using bfloat16 = sycl::ext::oneapi::bfloat16;
789
790
static_assert (std::is_integral_v<vec_data_t <convertT>> ||
790
791
detail::is_floating_point<convertT>::value ||
791
- std::is_same_v<bfloat16, convertT >,
792
+ std::is_same_v<convertT, bfloat16 >,
792
793
" Unsupported convertT" );
793
794
using T = vec_data_t <DataT>;
794
795
using R = vec_data_t <convertT>;
795
796
using OpenCLT = detail::ConvertToOpenCLType_t<T>;
796
797
using OpenCLR = detail::ConvertToOpenCLType_t<R>;
797
-
798
798
vec<convertT, NumElements> Result;
799
799
800
- // we are not on CUDA, see intel/llvm#11840
801
- #if defined(__SYCL_DEVICE_ONLY__) && !defined(__NVPTX__)
802
- // Convert BF16 vector -> float vector and vice versa.
803
- if constexpr (((IsBfloat16 && std::is_same_v<convertT, float >) ||
804
- (std::is_same_v<convertT, bfloat16> &&
805
- std::is_same_v<DataT, float >)) &&
806
- NumElements > 1 ) {
807
-
808
- using BF16ExtType = sycl::ext::oneapi::detail::Bfloat16StorageT
809
- __attribute__ ((ext_vector_type (NumElements)));
810
- using FloatExtType = float __attribute__ ((ext_vector_type (NumElements)));
811
- vec<convertT, NumElements> convertedVec;
812
-
813
- if constexpr (IsBfloat16)
814
- convertedVec =
815
- detail::convertImpl<bfloat16, float , roundingMode, NumElements,
816
- BF16ExtType, FloatExtType>(
817
- static_cast <vector_t >(*this ));
818
- else
819
- convertedVec =
820
- detail::convertImpl<float , bfloat16, roundingMode, NumElements,
821
- FloatExtType, BF16ExtType>(
822
- static_cast <vector_t >(*this ));
823
-
824
- return vec<convertT, NumElements>(convertedVec);
825
- } else if constexpr (NumElements > 1 ) {
826
- using OpenCLVecT = OpenCLT __attribute__ ((ext_vector_type (NumElements)));
827
- using OpenCLVecR = OpenCLR __attribute__ ((ext_vector_type (NumElements)));
828
- // Whole vector conversion can only be done, if:
829
- constexpr bool canUseNativeVectorConvert =
830
- // - both vectors are represented using native vector types;
831
- NativeVec && vec<convertT, NumElements>::NativeVec &&
832
- // - vec storage has an equivalent OpenCL native vector it is
833
- // implicitly
834
- // convertible to. There are some corner cases where it is not the
835
- // case with char, long and long long types.
836
- std::is_convertible_v<decltype (m_Data), OpenCLVecT> &&
837
- std::is_convertible_v<decltype (Result.m_Data ), OpenCLVecR> &&
838
- // - it is not a signed to unsigned (or vice versa) conversion
839
- // see comments within 'convertImpl' for more details;
840
- !detail::is_sint_to_from_uint<T, R>::value &&
841
- // - destination type is not bool. bool is stored as integer under the
842
- // hood and therefore conversion to bool looks like conversion
843
- // between two integer types. Since bit pattern for true and false
844
- // is not defined, there is no guarantee that integer conversion
845
- // yields right results here;
846
- !std::is_same_v<convertT, bool >;
847
- if constexpr (canUseNativeVectorConvert) {
848
- Result.m_Data = detail::convertImpl<T, R, roundingMode, NumElements,
849
- OpenCLVecT, OpenCLVecR>(m_Data);
850
- return Result;
800
+ #if defined(__SYCL_DEVICE_ONLY__)
801
+ using OpenCLVecT = OpenCLT __attribute__ ((ext_vector_type (NumElements)));
802
+ using OpenCLVecR = OpenCLR __attribute__ ((ext_vector_type (NumElements)));
803
+ // Whole vector conversion can only be done, if:
804
+ constexpr bool canUseNativeVectorConvert =
805
+ #ifdef __NVPTX__
806
+ // - we are not on CUDA, see intel/llvm#11840
807
+ false &&
808
+ #endif
809
+ // - both vectors are represented using native vector types;
810
+ NativeVec && vec<convertT, NumElements>::NativeVec &&
811
+ // - vec storage has an equivalent OpenCL native vector it is implicitly
812
+ // convertible to. There are some corner cases where it is not the
813
+ // case with char, long and long long types.
814
+ std::is_convertible_v<decltype (m_Data), OpenCLVecT> &&
815
+ std::is_convertible_v<decltype (Result.m_Data ), OpenCLVecR> &&
816
+ // - it is not a signed to unsigned (or vice versa) conversion
817
+ // see comments within 'convertImpl' for more details;
818
+ !detail::is_sint_to_from_uint<T, R>::value &&
819
+ // - destination type is not bool. bool is stored as integer under the
820
+ // hood and therefore conversion to bool looks like conversion between
821
+ // two integer types. Since bit pattern for true and false is not
822
+ // defined, there is no guarantee that integer conversion yields
823
+ // right results here;
824
+ !std::is_same_v<convertT, bool >;
825
+ if constexpr (canUseNativeVectorConvert) {
826
+ Result.m_Data = detail::convertImpl<T, R, roundingMode, NumElements,
827
+ OpenCLVecT, OpenCLVecR>(m_Data);
828
+ } else
829
+ #endif // defined(__SYCL_DEVICE_ONLY__)
830
+ {
831
+ // Otherwise, we fall back to per-element conversion:
832
+ for (size_t I = 0 ; I < NumElements; ++I) {
833
+ // For float -> bf16.
834
+ if constexpr (std::is_same_v<convertT, bfloat16>) {
835
+ Result[I] = bfloat16 ((*this )[I]);
836
+ } else
837
+ // For bf16 -> float.
838
+ if constexpr (std::is_same_v<DataT, bfloat16>) {
839
+ Result[I] = (float )((*this )[I]);
840
+ }
841
+ else {
842
+ Result.setValue (
843
+ I, vec_data<convertT>::get (
844
+ detail::convertImpl<T, R, roundingMode, 1 , OpenCLT, OpenCLR>(
845
+ vec_data<T>::get (getValue (I)))));
846
+ }
851
847
}
852
848
}
853
- #endif // defined(__SYCL_DEVICE_ONLY__)
854
849
855
- // Otherwise, we fallback to per-element conversion:
856
- for (size_t I = 0 ; I < NumElements; ++I) {
857
- Result.setValue (
858
- I, vec_data<convertT>::get (
859
- detail::convertImpl<T, R, roundingMode, 1 , OpenCLT, OpenCLR>(
860
- vec_data<DataT>::get (getValue (I)))));
861
- }
862
850
return Result;
863
851
}
864
852
0 commit comments