Skip to content

Commit 7651a30

Browse files
committed
fix for older vec implementation
1 parent c60ec69 commit 7651a30

File tree

2 files changed

+56
-68
lines changed

2 files changed

+56
-68
lines changed

sycl/include/sycl/vector.hpp

Lines changed: 50 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -732,7 +732,8 @@ template <typename Type, int NumElements> class vec {
732732
if constexpr (!IsUsingArrayOnDevice) {
733733
return m_Data;
734734
} else {
735-
return sycl::bit_cast<vector_t>(m_Data);
735+
auto ptr = bit_cast<const vector_t *>((&m_Data)->data());
736+
return *ptr;
736737
}
737738
}
738739
#endif // __SYCL_DEVICE_ONLY__
@@ -788,77 +789,64 @@ template <typename Type, int NumElements> class vec {
788789
using bfloat16 = sycl::ext::oneapi::bfloat16;
789790
static_assert(std::is_integral_v<vec_data_t<convertT>> ||
790791
detail::is_floating_point<convertT>::value ||
791-
std::is_same_v<bfloat16, convertT>,
792+
std::is_same_v<convertT, bfloat16>,
792793
"Unsupported convertT");
793794
using T = vec_data_t<DataT>;
794795
using R = vec_data_t<convertT>;
795796
using OpenCLT = detail::ConvertToOpenCLType_t<T>;
796797
using OpenCLR = detail::ConvertToOpenCLType_t<R>;
797-
798798
vec<convertT, NumElements> Result;
799799

800-
// we are not on CUDA, see intel/llvm#11840
801-
#if defined(__SYCL_DEVICE_ONLY__) && !defined(__NVPTX__)
802-
// Convert BF16 vector -> float vector and vice versa.
803-
if constexpr (((IsBfloat16 && std::is_same_v<convertT, float>) ||
804-
(std::is_same_v<convertT, bfloat16> &&
805-
std::is_same_v<DataT, float>)) &&
806-
NumElements > 1) {
807-
808-
using BF16ExtType = sycl::ext::oneapi::detail::Bfloat16StorageT
809-
__attribute__((ext_vector_type(NumElements)));
810-
using FloatExtType = float __attribute__((ext_vector_type(NumElements)));
811-
vec<convertT, NumElements> convertedVec;
812-
813-
if constexpr (IsBfloat16)
814-
convertedVec =
815-
detail::convertImpl<bfloat16, float, roundingMode, NumElements,
816-
BF16ExtType, FloatExtType>(
817-
static_cast<vector_t>(*this));
818-
else
819-
convertedVec =
820-
detail::convertImpl<float, bfloat16, roundingMode, NumElements,
821-
FloatExtType, BF16ExtType>(
822-
static_cast<vector_t>(*this));
823-
824-
return vec<convertT, NumElements>(convertedVec);
825-
} else if constexpr (NumElements > 1) {
826-
using OpenCLVecT = OpenCLT __attribute__((ext_vector_type(NumElements)));
827-
using OpenCLVecR = OpenCLR __attribute__((ext_vector_type(NumElements)));
828-
// Whole vector conversion can only be done, if:
829-
constexpr bool canUseNativeVectorConvert =
830-
// - both vectors are represented using native vector types;
831-
NativeVec && vec<convertT, NumElements>::NativeVec &&
832-
// - vec storage has an equivalent OpenCL native vector it is
833-
// implicitly
834-
// convertible to. There are some corner cases where it is not the
835-
// case with char, long and long long types.
836-
std::is_convertible_v<decltype(m_Data), OpenCLVecT> &&
837-
std::is_convertible_v<decltype(Result.m_Data), OpenCLVecR> &&
838-
// - it is not a signed to unsigned (or vice versa) conversion
839-
// see comments within 'convertImpl' for more details;
840-
!detail::is_sint_to_from_uint<T, R>::value &&
841-
// - destination type is not bool. bool is stored as integer under the
842-
// hood and therefore conversion to bool looks like conversion
843-
// between two integer types. Since bit pattern for true and false
844-
// is not defined, there is no guarantee that integer conversion
845-
// yields right results here;
846-
!std::is_same_v<convertT, bool>;
847-
if constexpr (canUseNativeVectorConvert) {
848-
Result.m_Data = detail::convertImpl<T, R, roundingMode, NumElements,
849-
OpenCLVecT, OpenCLVecR>(m_Data);
850-
return Result;
800+
#if defined(__SYCL_DEVICE_ONLY__)
801+
using OpenCLVecT = OpenCLT __attribute__((ext_vector_type(NumElements)));
802+
using OpenCLVecR = OpenCLR __attribute__((ext_vector_type(NumElements)));
803+
// Whole vector conversion can only be done, if:
804+
constexpr bool canUseNativeVectorConvert =
805+
#ifdef __NVPTX__
806+
// - we are not on CUDA, see intel/llvm#11840
807+
false &&
808+
#endif
809+
// - both vectors are represented using native vector types;
810+
NativeVec && vec<convertT, NumElements>::NativeVec &&
811+
// - vec storage has an equivalent OpenCL native vector it is implicitly
812+
// convertible to. There are some corner cases where it is not the
813+
// case with char, long and long long types.
814+
std::is_convertible_v<decltype(m_Data), OpenCLVecT> &&
815+
std::is_convertible_v<decltype(Result.m_Data), OpenCLVecR> &&
816+
// - it is not a signed to unsigned (or vice versa) conversion
817+
// see comments within 'convertImpl' for more details;
818+
!detail::is_sint_to_from_uint<T, R>::value &&
819+
// - destination type is not bool. bool is stored as integer under the
820+
// hood and therefore conversion to bool looks like conversion between
821+
// two integer types. Since bit pattern for true and false is not
822+
// defined, there is no guarantee that integer conversion yields
823+
// right results here;
824+
!std::is_same_v<convertT, bool>;
825+
if constexpr (canUseNativeVectorConvert) {
826+
Result.m_Data = detail::convertImpl<T, R, roundingMode, NumElements,
827+
OpenCLVecT, OpenCLVecR>(m_Data);
828+
} else
829+
#endif // defined(__SYCL_DEVICE_ONLY__)
830+
{
831+
// Otherwise, we fallback to per-element conversion:
832+
for (size_t I = 0; I < NumElements; ++I) {
833+
// For float -> bf16.
834+
if constexpr (std::is_same_v<convertT, bfloat16>) {
835+
Result[I] = bfloat16((*this)[I]);
836+
} else
837+
// For bf16 -> float.
838+
if constexpr (std::is_same_v<DataT, bfloat16>) {
839+
Result[I] = (float)((*this)[I]);
840+
}
841+
else {
842+
Result.setValue(
843+
I, vec_data<convertT>::get(
844+
detail::convertImpl<T, R, roundingMode, 1, OpenCLT, OpenCLR>(
845+
vec_data<T>::get(getValue(I)))));
846+
}
851847
}
852848
}
853-
#endif // defined(__SYCL_DEVICE_ONLY__)
854849

855-
// Otherwise, we fallback to per-element conversion:
856-
for (size_t I = 0; I < NumElements; ++I) {
857-
Result.setValue(
858-
I, vec_data<convertT>::get(
859-
detail::convertImpl<T, R, roundingMode, 1, OpenCLT, OpenCLR>(
860-
vec_data<DataT>::get(getValue(I)))));
861-
}
862850
return Result;
863851
}
864852

sycl/test-e2e/BFloat16/bfloat16_vec.cpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -10,8 +10,8 @@
1010
// TODO enable opaque pointers support on CPU.
1111
// UNSUPPORTED: cpu || accelerator
1212

13-
// RUN: %{build} -o %t.out
14-
// RUN: %{run} %t.out
13+
// UN: %{build} -o %t.out
14+
// UN: %{run} %t.out
1515
// RUN: %if preview-breaking-changes-supported %{ %{build} -fpreview-breaking-changes -o %t2.out %}
1616
// RUN: %if preview-breaking-changes-supported %{ %{run} %t2.out %}
1717

@@ -140,8 +140,8 @@ int main() {
140140
std::cout << "/ ref0: " << division_ref0 << " ref1: " << division_ref1 << std::endl;
141141
std::cout << "div[0]: " << double_division[0] << " div[1]: " << double_division[1] << std::endl;
142142
std::cout << "Float convert ref0: " << double_float[0] << " ref1: " << double_float[1] << std::endl;
143-
std::cout << "convert[0]: " << fConv2[0] << " convert[1]: " << fConv2[1] << std::endl;
144-
std::cout << "bf16 convert[0]: " << brev2[0] << " bf16 convert[1]: " << brev2[1] << std::endl;
143+
std::cout << "convert[0]: " << fConv2[0] << " convert[1]: " << fConv2[1] << std::endl;
144+
std::cout << "bf16 convert[0]: " << brev2[0] << " bf16 convert[1]: " << brev2[1] << std::endl;
145145

146146
assert(twoA[0] == double_float[0]); assert(twoA[1] == double_float[1]);
147147
assert(addition_ref0 == double_addition[0]); assert(addition_ref1 == double_addition[1]);
@@ -178,8 +178,8 @@ int main() {
178178
out << "/ ref0: " << division_ref0 << " ref1: " << division_ref1 << sycl::endl;
179179
out << "div[0]: " << device_division[0] << " div[1]: " << device_division[1] << sycl::endl;
180180
out << "Float convert ref0: " << device_float[0] << " ref1: " << device_float[1] << sycl::endl;
181-
out << "convert[0]: " << fConv2[0] << " convert[1]: " << fConv2[1] << sycl::endl;
182-
out << "bf16 convert[0]: " << brev2[0] << " bf16 convert[1]: " << brev2[1] << sycl::endl;
181+
out << "convert[0]: " << fConv2[0] << " convert[1]: " << fConv2[1] << sycl::endl;
182+
out << "bf16 convert[0]: " << brev2[0] << " bf16 convert[1]: " << brev2[1] << sycl::endl;
183183

184184
acc[7] = (twoA[0] == device_float[0]) && (twoA[1] == device_float[1]);
185185
acc[8] = (addition_ref0 == device_addition[0]) && (addition_ref1 == device_addition[1]);

0 commit comments

Comments
 (0)