Skip to content

Commit 712cb4e

Browse files
[SYCL] Update get_pointer to return T* for target::device specialized accessor. (#8874)
* Update `get_pointer` to return `T*` for the `target::device` specialized accessor, according to the [Specification](https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_buffer_accessor_for_commands)
* Declare the `src` parameter of `async_work_group_copy` in `group.hpp` and `nd_item.hpp` as `global_ptr<const>`
* Implement implicit conversion from `multi_ptr<T>` to `multi_ptr<const T>`
* Modify `is_native_op` to cover both `const` and non-`const` types

---------

Co-authored-by: Steffen Larsen <[email protected]>
1 parent dbadecb commit 712cb4e

File tree

91 files changed

+1507
-994
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

91 files changed

+1507
-994
lines changed

sycl/include/sycl/accessor.hpp

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2124,23 +2124,12 @@ class __SYCL_EBO __SYCL_SPECIAL_CLASS __SYCL_TYPE(accessor) accessor :
21242124
template <access::target AccessTarget_ = AccessTarget,
21252125
typename = std::enable_if_t<
21262126
(AccessTarget_ == access::target::host_buffer) ||
2127-
(AccessTarget_ == access::target::host_task)>>
2128-
#if SYCL_LANGUAGE_VERSION >= 202001
2129-
std::add_pointer_t<value_type> get_pointer() const noexcept
2130-
#else
2131-
DataT *get_pointer() const
2132-
#endif
2133-
{
2127+
(AccessTarget_ == access::target::host_task) ||
2128+
(AccessTarget_ == access::target::device)>>
2129+
std::add_pointer_t<value_type> get_pointer() const noexcept {
21342130
return getPointerAdjusted();
21352131
}
21362132

2137-
template <
2138-
access::target AccessTarget_ = AccessTarget,
2139-
typename = std::enable_if_t<AccessTarget_ == access::target::device>>
2140-
global_ptr<DataT> get_pointer() const {
2141-
return global_ptr<DataT>(getPointerAdjusted());
2142-
}
2143-
21442133
template <access::target AccessTarget_ = AccessTarget,
21452134
typename = std::enable_if_t<AccessTarget_ ==
21462135
access::target::constant_buffer>>

sycl/include/sycl/ext/intel/esimd/detail/util.hpp

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,9 +182,15 @@ template <unsigned N> class ForHelper {
182182
/// Returns the address referenced by the accessor \p Acc and
183183
/// the byte offset \p Offset.
184184
template <typename T, typename AccessorTy, typename OffsetTy = uint32_t>
185-
T *accessorToPointer(AccessorTy Acc, OffsetTy Offset = 0) {
186-
auto BytePtr = reinterpret_cast<char *>(Acc.get_pointer().get()) + Offset;
187-
return reinterpret_cast<T *>(BytePtr);
185+
auto accessorToPointer(AccessorTy Acc, OffsetTy Offset = 0) {
186+
using QualCharPtrType =
187+
std::conditional_t<std::is_const_v<typename AccessorTy::value_type>,
188+
const char *, char *>;
189+
using QualTPtrType =
190+
std::conditional_t<std::is_const_v<typename AccessorTy::value_type>,
191+
const T *, T *>;
192+
auto BytePtr = reinterpret_cast<QualCharPtrType>(Acc.get_pointer()) + Offset;
193+
return reinterpret_cast<QualTPtrType>(BytePtr);
188194
}
189195
#endif // __ESIMD_FORCE_STATELESS_MEM
190196

sycl/include/sycl/ext/intel/experimental/esimd/memory.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -776,8 +776,7 @@ __ESIMD_API std::enable_if_t<!std::is_pointer_v<AccessorTy>,
776776
lsc_gather(AccessorTy acc, __ESIMD_NS::simd<uint32_t, N> offsets,
777777
__ESIMD_NS::simd_mask<N> pred = 1) {
778778
#ifdef __ESIMD_FORCE_STATELESS_MEM
779-
return lsc_gather<T, NElts, DS, L1H, L3H>(acc.get_pointer().get(), offsets,
780-
pred);
779+
return lsc_gather<T, NElts, DS, L1H, L3H>(acc.get_pointer(), offsets, pred);
781780
#else
782781
detail::check_lsc_vector_size<NElts>();
783782
detail::check_lsc_data_size<T, DS>();
@@ -829,8 +828,8 @@ lsc_gather(AccessorTy acc, __ESIMD_NS::simd<uint32_t, N> offsets,
829828
__ESIMD_NS::simd_mask<N> pred,
830829
__ESIMD_NS::simd<T, N * NElts> old_values) {
831830
#ifdef __ESIMD_FORCE_STATELESS_MEM
832-
return lsc_gather<T, NElts, DS, L1H, L3H>(acc.get_pointer().get(), offsets,
833-
pred, old_values);
831+
return lsc_gather<T, NElts, DS, L1H, L3H>(acc.get_pointer(), offsets, pred,
832+
old_values);
834833
#else
835834
detail::check_lsc_vector_size<NElts>();
836835
detail::check_lsc_data_size<T, DS>();

sycl/include/sycl/group_algorithm.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ using native_op_list =
9696

9797
template <typename T, typename BinaryOperation> struct is_native_op {
9898
static constexpr bool value =
99-
is_contained<BinaryOperation, native_op_list<T>>::value ||
99+
is_contained<BinaryOperation,
100+
native_op_list<std::remove_const_t<T>>>::value ||
101+
is_contained<BinaryOperation,
102+
native_op_list<std::add_const_t<T>>>::value ||
100103
is_contained<BinaryOperation, native_op_list<void>>::value;
101104
};
102105

@@ -123,9 +126,9 @@ struct is_complex
123126

124127
// ---- is_arithmetic_or_complex
125128
template <typename T>
126-
using is_arithmetic_or_complex =
127-
std::integral_constant<bool, sycl::detail::is_complex<T>::value ||
128-
sycl::detail::is_arithmetic<T>::value>;
129+
using is_arithmetic_or_complex = std::integral_constant<
130+
bool, sycl::detail::is_complex<typename std::remove_cv_t<T>>::value ||
131+
sycl::detail::is_arithmetic<T>::value>;
129132

130133
template <typename T>
131134
struct is_vector_arithmetic_or_complex

sycl/include/sycl/multi_ptr.hpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,16 @@ class multi_ptr {
113113
: m_Pointer(ptr) {}
114114
multi_ptr(std::nullptr_t) : m_Pointer(nullptr) {}
115115

116+
// Implicit conversion from multi_ptr<T> to multi_ptr<const T>
117+
template <typename NonConstElementType = std::remove_const_t<ElementType>,
118+
typename = typename std::enable_if_t<
119+
std::is_const_v<ElementType> &&
120+
std::is_same_v<NonConstElementType,
121+
std::remove_const_t<ElementType>>>>
122+
explicit multi_ptr(
123+
multi_ptr<NonConstElementType, Space, DecorateAddress> MPtr)
124+
: m_Pointer(MPtr.get_decorated()) {}
125+
116126
// Only if Space is in
117127
// {global_space, ext_intel_global_device_space, generic_space}
118128
template <
@@ -126,8 +136,7 @@ class multi_ptr {
126136
multi_ptr(accessor<ElementType, Dimensions, Mode, access::target::device,
127137
isPlaceholder, PropertyListT>
128138
Accessor)
129-
: multi_ptr(
130-
detail::cast_AS<decorated_type *>(Accessor.get_pointer().get())) {}
139+
: multi_ptr(Accessor.template get_multi_ptr<DecorateAddress>()) {}
131140

132141
// Only if Space == local_space || generic_space
133142
template <int Dimensions, access::mode Mode,
@@ -149,7 +158,7 @@ class multi_ptr {
149158
(Space == access::address_space::generic_space ||
150159
Space == access::address_space::local_space)>>
151160
multi_ptr(local_accessor<ElementType, Dimensions> Accessor)
152-
: m_Pointer(detail::cast_AS<decorated_type *>(Accessor.get_pointer())) {}
161+
: multi_ptr(Accessor.template get_multi_ptr<DecorateAddress>()) {}
153162

154163
// The following constructors are necessary to create multi_ptr<const
155164
// ElementType, Space, DecorateAddress> from accessor<ElementType, ...>.
@@ -177,8 +186,8 @@ class multi_ptr {
177186
multi_ptr(accessor<typename std::remove_const_t<RelayElementType>, Dimensions,
178187
Mode, access::target::device, isPlaceholder, PropertyListT>
179188
Accessor)
180-
: multi_ptr(
181-
detail::cast_AS<decorated_type *>(Accessor.get_pointer().get())) {}
189+
: m_Pointer(Accessor.template get_multi_ptr<DecorateAddress>()
190+
.get_decorated()) {}
182191

183192
// Only if Space == local_space || generic_space and element type is const
184193
template <int Dimensions, access::mode Mode,
@@ -208,7 +217,7 @@ class multi_ptr {
208217
multi_ptr(
209218
local_accessor<typename std::remove_const_t<RelayElementType>, Dimensions>
210219
Accessor)
211-
: m_Pointer(detail::cast_AS<decorated_type *>(Accessor.get_pointer())) {}
220+
: multi_ptr(Accessor.template get_multi_ptr<DecorateAddress>()) {}
212221

213222
// Assignment and access operators
214223
multi_ptr &operator=(const multi_ptr &) = default;
@@ -441,8 +450,7 @@ class multi_ptr<const void, Space, DecorateAddress> {
441450
multi_ptr(accessor<ElementType, Dimensions, Mode, access::target::device,
442451
isPlaceholder, PropertyListT>
443452
Accessor)
444-
: multi_ptr(
445-
detail::cast_AS<decorated_type *>(Accessor.get_pointer().get())) {}
453+
: multi_ptr(Accessor.template get_multi_ptr<DecorateAddress>()) {}
446454

447455
// Only if Space == local_space
448456
template <
@@ -463,7 +471,7 @@ class multi_ptr<const void, Space, DecorateAddress> {
463471
typename = typename std::enable_if_t<
464472
RelaySpace == Space && Space == access::address_space::local_space>>
465473
multi_ptr(local_accessor<ElementType, Dimensions> Accessor)
466-
: m_Pointer(detail::cast_AS<decorated_type *>(Accessor.get_pointer())) {}
474+
: multi_ptr(Accessor.template get_multi_ptr<DecorateAddress>()) {}
467475

468476
// Assignment operators
469477
multi_ptr &operator=(const multi_ptr &) = default;
@@ -567,8 +575,7 @@ class multi_ptr<void, Space, DecorateAddress> {
567575
multi_ptr(accessor<ElementType, Dimensions, Mode, access::target::device,
568576
isPlaceholder, PropertyListT>
569577
Accessor)
570-
: multi_ptr(
571-
detail::cast_AS<decorated_type *>(Accessor.get_pointer().get())) {}
578+
: multi_ptr(Accessor.template get_multi_ptr<DecorateAddress>()) {}
572579

573580
// Only if Space == local_space
574581
template <
@@ -589,7 +596,7 @@ class multi_ptr<void, Space, DecorateAddress> {
589596
typename = typename std::enable_if_t<
590597
RelaySpace == Space && Space == access::address_space::local_space>>
591598
multi_ptr(local_accessor<ElementType, Dimensions> Accessor)
592-
: m_Pointer(detail::cast_AS<decorated_type *>(Accessor.get_pointer())) {}
599+
: multi_ptr(Accessor.template get_multi_ptr<DecorateAddress>()) {}
593600

594601
// Assignment operators
595602
multi_ptr &operator=(const multi_ptr &) = default;
@@ -760,7 +767,7 @@ class multi_ptr<ElementType, Space, access::decorated::legacy> {
760767
multi_ptr(accessor<ElementType, dimensions, Mode, access::target::device,
761768
isPlaceholder, PropertyListT>
762769
Accessor) {
763-
m_Pointer = detail::cast_AS<pointer_t>(Accessor.get_pointer().get());
770+
m_Pointer = detail::cast_AS<pointer_t>(Accessor.get_pointer());
764771
}
765772

766773
// Only if Space == local_space || generic_space

sycl/test-e2e/Basic/multi_ptr.hpp

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ template <typename T> struct point {
3131
};
3232

3333
template <typename T, access::decorated IsDecorated>
34-
void innerFunc(id<1> wiID, global_ptr<T, IsDecorated> ptr_1,
34+
void innerFunc(id<1> wiID, global_ptr<const T, IsDecorated> ptr_1,
3535
global_ptr<T, IsDecorated> ptr_2,
3636
global_ptr<T, IsDecorated> ptr_3,
3737
global_ptr<T, IsDecorated> ptr_4,
@@ -110,9 +110,8 @@ template <typename T, access::decorated IsDecorated> void testMultPtr() {
110110
private_data[i] = 0;
111111
localAccessor[wiID.get_local_id()] = 0;
112112

113-
auto ptr_1 =
114-
multi_ptr<T, access::address_space::global_space, IsDecorated>(
115-
accessorData_1);
113+
auto ptr_1 = multi_ptr<const T, access::address_space::global_space,
114+
IsDecorated>(accessorData_1);
116115
auto ptr_2 =
117116
multi_ptr<T, access::address_space::global_space, IsDecorated>(
118117
accessorData_2);
@@ -136,19 +135,21 @@ template <typename T, access::decorated IsDecorated> void testMultPtr() {
136135

137136
// Construct extension pointer from accessors.
138137
auto dev_ptr =
139-
multi_ptr<T, access::address_space::ext_intel_global_device_space,
138+
multi_ptr<const T,
139+
access::address_space::ext_intel_global_device_space,
140140
IsDecorated>(accessorData_1);
141-
static_assert(std::is_same_v<ext::intel::device_ptr<T, IsDecorated>,
142-
decltype(dev_ptr)>,
143-
"Incorrect type for dev_ptr.");
141+
static_assert(
142+
std::is_same_v<ext::intel::device_ptr<const T, IsDecorated>,
143+
decltype(dev_ptr)>,
144+
"Incorrect type for dev_ptr.");
144145

145146
// General conversions in multi_ptr class
146147
T *RawPtr = nullptr;
147148
global_ptr<T, IsDecorated> ptr_6 =
148149
address_space_cast<access::address_space::global_space,
149150
IsDecorated>(RawPtr);
150151

151-
global_ptr<T, IsDecorated> ptr_7(accessorData_1);
152+
global_ptr<const T, IsDecorated> ptr_7(accessorData_1);
152153

153154
global_ptr<void, IsDecorated> ptr_8 =
154155
address_space_cast<access::address_space::global_space,
@@ -206,12 +207,12 @@ void testMultPtrArrowOperator() {
206207
point<T> private_val = 0;
207208

208209
auto ptr_1 =
209-
multi_ptr<point<T>, access::address_space::global_space,
210+
multi_ptr<const point<T>, access::address_space::global_space,
210211
IsDecorated>(accessorData_1);
211212
auto ptr_2 = multi_ptr<point<T>, access::address_space::local_space,
212213
IsDecorated>(accessorData_2);
213214
auto ptr_3 =
214-
multi_ptr<point<T>,
215+
multi_ptr<const point<T>,
215216
access::address_space::ext_intel_global_device_space,
216217
IsDecorated>(accessorData_3);
217218
auto ptr_4 =

sycl/test-e2e/Basic/multi_ptr_legacy.hpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
#include <cassert>
1010
#include <iostream>
11-
#include <sycl/sycl.hpp>
11+
#include <sycl.hpp>
1212
#include <type_traits>
1313

1414
using namespace sycl;
@@ -30,7 +30,7 @@ template <typename T> struct point {
3030
};
3131

3232
template <typename T>
33-
void innerFunc(id<1> wiID, global_ptr<T> ptr_1, global_ptr<T> ptr_2,
33+
void innerFunc(id<1> wiID, global_ptr<const T> ptr_1, global_ptr<T> ptr_2,
3434
local_ptr<T> local_ptr) {
3535
T t = ptr_1[wiID.get(0)];
3636
local_ptr[wiID.get(0)] = t;
@@ -64,31 +64,33 @@ template <typename T> void testMultPtr() {
6464

6565
cgh.parallel_for<class testMultPtrKernel<T>>(
6666
nd_range<1>{10, 10}, [=](nd_item<1> wiID) {
67-
auto ptr_1 = make_ptr<T, access::address_space::global_space,
67+
auto ptr_1 = make_ptr<const T, access::address_space::global_space,
6868
access::decorated::legacy>(
69-
accessorData_1.get_pointer());
69+
accessorData_1
70+
.template get_multi_ptr<sycl::access::decorated::legacy>());
7071
auto ptr_2 = make_ptr<T, access::address_space::global_space,
7172
access::decorated::legacy>(
72-
accessorData_2.get_pointer());
73+
accessorData_2
74+
.template get_multi_ptr<sycl::access::decorated::legacy>());
7375
auto local_ptr = make_ptr<T, access::address_space::local_space,
7476
access::decorated::legacy>(
7577
localAccessor.get_pointer());
7678

7779
// Construct extension pointer from accessors.
7880
auto dev_ptr =
79-
multi_ptr<T,
81+
multi_ptr<const T,
8082
access::address_space::ext_intel_global_device_space>(
8183
accessorData_1);
82-
static_assert(
83-
std::is_same_v<ext::intel::device_ptr<T>, decltype(dev_ptr)>,
84-
"Incorrect type for dev_ptr.");
84+
static_assert(std::is_same_v<ext::intel::device_ptr<const T>,
85+
decltype(dev_ptr)>,
86+
"Incorrect type for dev_ptr.");
8587

8688
// General conversions in multi_ptr class
8789
T *RawPtr = nullptr;
8890
global_ptr<T> ptr_4(RawPtr);
8991
ptr_4 = RawPtr;
9092

91-
global_ptr<T> ptr_5(accessorData_1);
93+
global_ptr<const T> ptr_5(accessorData_1);
9294

9395
global_ptr<void> ptr_6((void *)RawPtr);
9496

@@ -144,9 +146,11 @@ template <typename T> void testMultPtrArrowOperator() {
144146

145147
cgh.parallel_for<class testMultPtrArrowOperatorKernel<T>>(
146148
sycl::nd_range<1>{1, 1}, [=](sycl::nd_item<1>) {
147-
auto ptr_1 = make_ptr<point<T>, access::address_space::global_space,
148-
access::decorated::legacy>(
149-
accessorData_1.get_pointer());
149+
auto ptr_1 =
150+
make_ptr<const point<T>, access::address_space::global_space,
151+
access::decorated::legacy>(
152+
accessorData_1.template get_multi_ptr<
153+
sycl::access::decorated::legacy>());
150154
auto ptr_2 =
151155
make_ptr<point<T>, access::address_space::constant_space,
152156
access::decorated::legacy>(
@@ -155,7 +159,7 @@ template <typename T> void testMultPtrArrowOperator() {
155159
access::decorated::legacy>(
156160
accessorData_3.get_pointer());
157161
auto ptr_4 =
158-
make_ptr<point<T>,
162+
make_ptr<const point<T>,
159163
access::address_space::ext_intel_global_device_space,
160164
access::decorated::legacy>(
161165
accessorData_4.get_pointer());

sycl/test-e2e/GroupAlgorithm/SYCL2020/all_of.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ void test(queue q, InputContainer input, OutputContainer output,
3333
int lid = it.get_local_id(0);
3434
out[0] = all_of_group(g, pred(in[lid]));
3535
out[1] = all_of_group(g, in[lid], pred);
36-
out[2] = joint_all_of(g, in.get_pointer(), in.get_pointer() + N, pred);
36+
out[2] = joint_all_of(
37+
g, in.template get_multi_ptr<access::decorated::no>(),
38+
in.template get_multi_ptr<access::decorated::no>() + N, pred);
3739
});
3840
});
3941
}

sycl/test-e2e/GroupAlgorithm/SYCL2020/any_of.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@ void test(queue q, InputContainer input, OutputContainer output,
4141
int lid = it.get_local_id(0);
4242
out[0] = any_of_group(g, pred(in[lid]));
4343
out[1] = any_of_group(g, in[lid], pred);
44-
out[2] = joint_any_of(g, in.get_pointer(), in.get_pointer() + N, pred);
44+
out[2] = joint_any_of(
45+
g, in.template get_multi_ptr<access::decorated::no>(),
46+
in.template get_multi_ptr<access::decorated::no>() + N, pred);
4547
});
4648
});
4749
}

sycl/test-e2e/GroupAlgorithm/SYCL2020/exclusive_scan.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,10 @@ void test(queue q, InputContainer input, OutputContainer output,
9292
accessor out{out_buf, cgh, sycl::write_only, sycl::no_init};
9393
cgh.parallel_for<kernel_name2>(nd_range<1>(G, G), [=](nd_item<1> it) {
9494
group<1> g = it.get_group();
95-
joint_exclusive_scan(g, in.get_pointer(), in.get_pointer() + N,
96-
out.get_pointer(), binary_op);
95+
joint_exclusive_scan(
96+
g, in.template get_multi_ptr<access::decorated::no>(),
97+
in.template get_multi_ptr<access::decorated::no>() + N,
98+
out.template get_multi_ptr<access::decorated::no>(), binary_op);
9799
});
98100
});
99101
}
@@ -109,8 +111,11 @@ void test(queue q, InputContainer input, OutputContainer output,
109111
accessor out{out_buf, cgh, sycl::write_only, sycl::no_init};
110112
cgh.parallel_for<kernel_name3>(nd_range<1>(G, G), [=](nd_item<1> it) {
111113
group<1> g = it.get_group();
112-
joint_exclusive_scan(g, in.get_pointer(), in.get_pointer() + N,
113-
out.get_pointer(), init, binary_op);
114+
joint_exclusive_scan(
115+
g, in.template get_multi_ptr<access::decorated::no>(),
116+
in.template get_multi_ptr<access::decorated::no>() + N,
117+
out.template get_multi_ptr<access::decorated::no>(), init,
118+
binary_op);
114119
});
115120
});
116121
}

0 commit comments

Comments
 (0)