Skip to content

Commit e5da67b

Browse files
[NFCI][SYCL] Support multi_ptr in convertToOpenCLType (#12693)
1 parent 7999e27 commit e5da67b

File tree

3 files changed

+62
-38
lines changed

3 files changed

+62
-38
lines changed

sycl/include/sycl/detail/generic_type_traits.hpp

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -663,6 +663,14 @@ template <> struct ConvertToOpenCLTypeImpl<Boolean<1>> {
663663
// Or should it be "int"?
664664
using type = Boolean<1>;
665665
};
666+
// The specialization below is compiled only when _HAS_STD_BYTE is either
// undefined or defined to a non-zero value.
// NOTE(review): _HAS_STD_BYTE looks like the MSVC STL switch that disables
// std::byte -- confirm.
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
667+
// TODO: It seems we only use this to convert a pointer's element type. As such,
668+
// although it doesn't look very clean, it should be ok having this case handled
669+
// explicitly until further refactoring of this area.
670+
// Maps std::byte to uint8_t as its converted type.
template <> struct ConvertToOpenCLTypeImpl<std::byte> {
671+
using type = uint8_t;
672+
};
673+
#endif
666674
#endif
667675

668676
template <typename T> struct ConvertToOpenCLTypeImpl<T *> {
@@ -700,8 +708,30 @@ convertDataToType(FROM t) {
700708
// Now fuse the above into a simpler helper that's easy to use.
701709
// TODO: This helper should probably be moved outside of "type_traits".
702710
// Converts `x` to its OpenCL-level representation.
//
// In this diff view the lines tagged `-` are the previous body (a single
// convertDataToType call) and the lines tagged `+` are the new body, which
// dispatches at compile time on the reference-stripped argument type:
//   * multi_ptr   -> recurse on the decorated raw pointer it wraps;
//   * raw pointer -> reinterpret_cast to a pointer to the converted element
//                    type (constness kept; decoration kept on device only);
//   * otherwise   -> convertDataToType to ConvertToOpenCLType_t<no_ref>.
template <typename T> auto convertToOpenCLType(T &&x) {
703-
using OpenCLType = ConvertToOpenCLType_t<std::remove_reference_t<T>>;
704-
return convertDataToType<T, OpenCLType>(std::forward<T>(x));
711+
// T is a forwarding-reference parameter; strip the reference before
// inspecting the type.
using no_ref = std::remove_reference_t<T>;
712+
if constexpr (is_multi_ptr_v<no_ref>) {
713+
// Unwrap the multi_ptr and let the raw-pointer branch do the conversion.
return convertToOpenCLType(x.get_decorated());
714+
} else if constexpr (std::is_pointer_v<no_ref>) {
715+
// TODO: Below ignores volatile, but we didn't have a need for it yet.
716+
// Pointee type with the decoration removed.
using elem_type = remove_decoration_t<std::remove_pointer_t<no_ref>>;
717+
// Convert the element type with const dropped...
using converted_elem_type_no_cv =
718+
ConvertToOpenCLType_t<std::remove_const_t<elem_type>>;
719+
// ...then put const back if the original element type was const.
using converted_elem_type =
720+
std::conditional_t<std::is_const_v<elem_type>,
721+
const converted_elem_type_no_cv,
722+
converted_elem_type_no_cv>;
723+
// Device builds re-apply the decoration for the address space deduced from
// the incoming pointer type; other builds use a plain pointer.
#ifdef __SYCL_DEVICE_ONLY__
724+
using result_type =
725+
typename DecoratedType<converted_elem_type,
726+
deduce_AS<no_ref>::value>::type *;
727+
#else
728+
using result_type = converted_elem_type *;
729+
#endif
730+
// Only the pointer type changes; the pointer value is passed through.
return reinterpret_cast<result_type>(x);
731+
} else {
732+
// Neither multi_ptr nor raw pointer: convert the value itself.
using OpenCLType = ConvertToOpenCLType_t<no_ref>;
733+
return convertDataToType<T, OpenCLType>(std::forward<T>(x));
734+
}
705735
}
706736

707737
template <typename To, typename From> auto convertFromOpenCLTypeFor(From &&x) {

sycl/include/sycl/group.hpp

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -315,12 +315,9 @@ template <int Dimensions = 1> class __SYCL_TYPE(group) group {
315315
global_ptr<dataT> src,
316316
size_t numElements,
317317
size_t srcStride) const {
318-
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
319-
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;
320-
321318
__ocl_event_t E = __SYCL_OpGroupAsyncCopyGlobalToLocal(
322-
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
323-
numElements, srcStride, 0);
319+
__spv::Scope::Workgroup, detail::convertToOpenCLType(dest),
320+
detail::convertToOpenCLType(src), numElements, srcStride, 0);
324321
return device_event(E);
325322
}
326323

@@ -337,12 +334,9 @@ template <int Dimensions = 1> class __SYCL_TYPE(group) group {
337334
size_t numElements,
338335
size_t destStride)
339336
const {
340-
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
341-
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;
342-
343337
__ocl_event_t E = __SYCL_OpGroupAsyncCopyLocalToGlobal(
344-
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
345-
numElements, destStride, 0);
338+
__spv::Scope::Workgroup, detail::convertToOpenCLType(dest),
339+
detail::convertToOpenCLType(src), numElements, destStride, 0);
346340
return device_event(E);
347341
}
348342

@@ -359,12 +353,9 @@ template <int Dimensions = 1> class __SYCL_TYPE(group) group {
359353
async_work_group_copy(decorated_local_ptr<DestDataT> dest,
360354
decorated_global_ptr<SrcDataT> src, size_t numElements,
361355
size_t srcStride) const {
362-
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
363-
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;
364-
365356
__ocl_event_t E = __SYCL_OpGroupAsyncCopyGlobalToLocal(
366-
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
367-
numElements, srcStride, 0);
357+
__spv::Scope::Workgroup, detail::convertToOpenCLType(dest),
358+
detail::convertToOpenCLType(src), numElements, srcStride, 0);
368359
return device_event(E);
369360
}
370361

@@ -381,12 +372,9 @@ template <int Dimensions = 1> class __SYCL_TYPE(group) group {
381372
async_work_group_copy(decorated_global_ptr<DestDataT> dest,
382373
decorated_local_ptr<SrcDataT> src, size_t numElements,
383374
size_t destStride) const {
384-
using DestT = detail::ConvertToOpenCLType_t<decltype(dest)>;
385-
using SrcT = detail::ConvertToOpenCLType_t<decltype(src)>;
386-
387375
__ocl_event_t E = __SYCL_OpGroupAsyncCopyLocalToGlobal(
388-
__spv::Scope::Workgroup, DestT(dest.get()), SrcT(src.get()),
389-
numElements, destStride, 0);
376+
__spv::Scope::Workgroup, detail::convertToOpenCLType(dest),
377+
detail::convertToOpenCLType(src), numElements, destStride, 0);
390378
return device_event(E);
391379
}
392380

sycl/include/sycl/sub_group.hpp

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,24 @@ namespace sub_group {
4242
// Element type used by the block read/write helpers in place of T, as chosen
// by select_cl_scalar_integral_unsigned_t.
// NOTE(review): presumably the unsigned integer type of matching size, or void
// when none applies (AcceptableForGlobalLoadStore compares it against void)
// -- confirm against the trait's definition.
template <typename T>
4343
using SelectBlockT = select_cl_scalar_integral_unsigned_t<T>;
4444

45+
// Turns a multi_ptr into a raw pointer whose element type is the block type
// (SelectBlockT) of the converted element type, for use as the pointer
// argument of the sub-group block read/write calls.
// Only multi_ptr arguments are accepted; this is enforced by the static_assert.
template <typename MultiPtrTy> auto convertToBlockPtr(MultiPtrTy MultiPtr) {
46+
static_assert(is_multi_ptr_v<MultiPtrTy>);
47+
// First obtain the OpenCL-level pointer for the multi_ptr.
auto DecoratedPtr = convertToOpenCLType(MultiPtr);
48+
using DecoratedPtrTy = decltype(DecoratedPtr);
49+
// Pointee type of the converted pointer with the decoration removed.
using ElemTy = remove_decoration_t<std::remove_pointer_t<DecoratedPtrTy>>;
50+
51+
using TargetElemTy = SelectBlockT<ElemTy>;
52+
// TODO: Handle cv qualifiers.
53+
// Device builds re-apply the decoration for the address space deduced from
// the converted pointer type; other builds use a plain pointer.
#ifdef __SYCL_DEVICE_ONLY__
54+
using ResultTy =
55+
typename DecoratedType<TargetElemTy,
56+
deduce_AS<DecoratedPtrTy>::value>::type *;
57+
#else
58+
using ResultTy = TargetElemTy *;
59+
#endif
60+
// Only the pointer type changes; the pointer value is passed through.
return reinterpret_cast<ResultTy>(DecoratedPtr);
61+
}
62+
4563
template <typename T, access::address_space Space>
4664
using AcceptableForGlobalLoadStore =
4765
std::bool_constant<!std::is_same_v<void, SelectBlockT<T>> &&
@@ -57,11 +75,7 @@ template <typename T, access::address_space Space,
5775
access::decorated DecorateAddress>
5876
T load(const multi_ptr<T, Space, DecorateAddress> src) {
5977
using BlockT = SelectBlockT<T>;
60-
using PtrT = sycl::detail::ConvertToOpenCLType_t<
61-
const multi_ptr<BlockT, Space, DecorateAddress>>;
62-
63-
BlockT Ret =
64-
__spirv_SubgroupBlockReadINTEL<BlockT>(reinterpret_cast<PtrT>(src.get()));
78+
BlockT Ret = __spirv_SubgroupBlockReadINTEL<BlockT>(convertToBlockPtr(src));
6579

6680
return sycl::bit_cast<T>(Ret);
6781
}
@@ -71,11 +85,7 @@ template <int N, typename T, access::address_space Space,
7185
vec<T, N> load(const multi_ptr<T, Space, DecorateAddress> src) {
7286
using BlockT = SelectBlockT<T>;
7387
using VecT = sycl::detail::ConvertToOpenCLType_t<vec<BlockT, N>>;
74-
using PtrT = sycl::detail::ConvertToOpenCLType_t<
75-
const multi_ptr<BlockT, Space, DecorateAddress>>;
76-
77-
VecT Ret =
78-
__spirv_SubgroupBlockReadINTEL<VecT>(reinterpret_cast<PtrT>(src.get()));
88+
VecT Ret = __spirv_SubgroupBlockReadINTEL<VecT>(convertToBlockPtr(src));
7989

8090
return sycl::bit_cast<typename vec<T, N>::vector_t>(Ret);
8191
}
@@ -84,10 +94,8 @@ template <typename T, access::address_space Space,
8494
access::decorated DecorateAddress>
8595
void store(multi_ptr<T, Space, DecorateAddress> dst, const T &x) {
8696
using BlockT = SelectBlockT<T>;
87-
using PtrT = sycl::detail::ConvertToOpenCLType_t<
88-
multi_ptr<BlockT, Space, DecorateAddress>>;
8997

90-
__spirv_SubgroupBlockWriteINTEL(reinterpret_cast<PtrT>(dst.get()),
98+
__spirv_SubgroupBlockWriteINTEL(convertToBlockPtr(dst),
9199
sycl::bit_cast<BlockT>(x));
92100
}
93101

@@ -96,10 +104,8 @@ template <int N, typename T, access::address_space Space,
96104
void store(multi_ptr<T, Space, DecorateAddress> dst, const vec<T, N> &x) {
97105
using BlockT = SelectBlockT<T>;
98106
using VecT = sycl::detail::ConvertToOpenCLType_t<vec<BlockT, N>>;
99-
using PtrT = sycl::detail::ConvertToOpenCLType_t<
100-
const multi_ptr<BlockT, Space, DecorateAddress>>;
101107

102-
__spirv_SubgroupBlockWriteINTEL(reinterpret_cast<PtrT>(dst.get()),
108+
__spirv_SubgroupBlockWriteINTEL(convertToBlockPtr(dst),
103109
sycl::bit_cast<VecT>(x));
104110
}
105111
#endif // __SYCL_DEVICE_ONLY__

0 commit comments

Comments
 (0)