@@ -42,6 +42,24 @@ namespace sub_group {
42
42
template <typename T>
43
43
using SelectBlockT = select_cl_scalar_integral_unsigned_t <T>;
44
44
45
+ template <typename MultiPtrTy> auto convertToBlockPtr (MultiPtrTy MultiPtr) {
46
+ static_assert (is_multi_ptr_v<MultiPtrTy>);
47
+ auto DecoratedPtr = convertToOpenCLType (MultiPtr);
48
+ using DecoratedPtrTy = decltype (DecoratedPtr);
49
+ using ElemTy = remove_decoration_t <std::remove_pointer_t <DecoratedPtrTy>>;
50
+
51
+ using TargetElemTy = SelectBlockT<ElemTy>;
52
+ // TODO: Handle cv qualifiers.
53
+ #ifdef __SYCL_DEVICE_ONLY__
54
+ using ResultTy =
55
+ typename DecoratedType<TargetElemTy,
56
+ deduce_AS<DecoratedPtrTy>::value>::type *;
57
+ #else
58
+ using ResultTy = TargetElemTy *;
59
+ #endif
60
+ return reinterpret_cast <ResultTy>(DecoratedPtr);
61
+ }
62
+
45
63
template <typename T, access::address_space Space>
46
64
using AcceptableForGlobalLoadStore =
47
65
std::bool_constant<!std::is_same_v<void , SelectBlockT<T>> &&
@@ -57,11 +75,7 @@ template <typename T, access::address_space Space,
57
75
access::decorated DecorateAddress>
58
76
T load (const multi_ptr<T, Space, DecorateAddress> src) {
59
77
using BlockT = SelectBlockT<T>;
60
- using PtrT = sycl::detail::ConvertToOpenCLType_t<
61
- const multi_ptr<BlockT, Space, DecorateAddress>>;
62
-
63
- BlockT Ret =
64
- __spirv_SubgroupBlockReadINTEL<BlockT>(reinterpret_cast <PtrT>(src.get ()));
78
+ BlockT Ret = __spirv_SubgroupBlockReadINTEL<BlockT>(convertToBlockPtr (src));
65
79
66
80
return sycl::bit_cast<T>(Ret);
67
81
}
@@ -71,11 +85,7 @@ template <int N, typename T, access::address_space Space,
71
85
vec<T, N> load (const multi_ptr<T, Space, DecorateAddress> src) {
72
86
using BlockT = SelectBlockT<T>;
73
87
using VecT = sycl::detail::ConvertToOpenCLType_t<vec<BlockT, N>>;
74
- using PtrT = sycl::detail::ConvertToOpenCLType_t<
75
- const multi_ptr<BlockT, Space, DecorateAddress>>;
76
-
77
- VecT Ret =
78
- __spirv_SubgroupBlockReadINTEL<VecT>(reinterpret_cast <PtrT>(src.get ()));
88
+ VecT Ret = __spirv_SubgroupBlockReadINTEL<VecT>(convertToBlockPtr (src));
79
89
80
90
return sycl::bit_cast<typename vec<T, N>::vector_t >(Ret);
81
91
}
@@ -84,10 +94,8 @@ template <typename T, access::address_space Space,
84
94
access::decorated DecorateAddress>
85
95
void store (multi_ptr<T, Space, DecorateAddress> dst, const T &x) {
86
96
using BlockT = SelectBlockT<T>;
87
- using PtrT = sycl::detail::ConvertToOpenCLType_t<
88
- multi_ptr<BlockT, Space, DecorateAddress>>;
89
97
90
- __spirv_SubgroupBlockWriteINTEL (reinterpret_cast <PtrT> (dst. get () ),
98
+ __spirv_SubgroupBlockWriteINTEL (convertToBlockPtr (dst),
91
99
sycl::bit_cast<BlockT>(x));
92
100
}
93
101
@@ -96,10 +104,8 @@ template <int N, typename T, access::address_space Space,
96
104
void store (multi_ptr<T, Space, DecorateAddress> dst, const vec<T, N> &x) {
97
105
using BlockT = SelectBlockT<T>;
98
106
using VecT = sycl::detail::ConvertToOpenCLType_t<vec<BlockT, N>>;
99
- using PtrT = sycl::detail::ConvertToOpenCLType_t<
100
- const multi_ptr<BlockT, Space, DecorateAddress>>;
101
107
102
- __spirv_SubgroupBlockWriteINTEL (reinterpret_cast <PtrT> (dst. get () ),
108
+ __spirv_SubgroupBlockWriteINTEL (convertToBlockPtr (dst),
103
109
sycl::bit_cast<VecT>(x));
104
110
}
105
111
#endif // __SYCL_DEVICE_ONLY__
0 commit comments