Skip to content

Commit b5d69df

Browse files
authored
[SYCL] Add E2E for prefetch & improve interface (#11834)
1 parent edd58d9 commit b5d69df

File tree

4 files changed

+348
-22
lines changed

4 files changed

+348
-22
lines changed

sycl/include/sycl/ext/oneapi/experimental/prefetch.hpp

Lines changed: 40 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -100,36 +100,42 @@ void joint_prefetch_impl(Group g, T *ptr, size_t bytes, Properties properties) {
100100
} // namespace detail
101101

102102
template <typename Properties = empty_properties_t>
103-
void prefetch(void *ptr, Properties properties = {}) {
103+
std::enable_if_t<is_property_list_v<std::decay_t<Properties>>>
104+
prefetch(void *ptr, Properties properties = {}) {
104105
detail::prefetch_impl(ptr, 1, properties);
105106
}
106107

107108
template <typename Properties = empty_properties_t>
108-
void prefetch(void *ptr, size_t bytes, Properties properties = {}) {
109+
std::enable_if_t<is_property_list_v<std::decay_t<Properties>>>
110+
prefetch(void *ptr, size_t bytes, Properties properties = {}) {
109111
detail::prefetch_impl(ptr, bytes, properties);
110112
}
111113

112114
template <typename T, typename Properties = empty_properties_t>
113-
void prefetch(T *ptr, Properties properties = {}) {
115+
std::enable_if_t<is_property_list_v<std::decay_t<Properties>>>
116+
prefetch(T *ptr, Properties properties = {}) {
114117
detail::prefetch_impl(ptr, sizeof(T), properties);
115118
}
116119

117120
template <typename T, typename Properties = empty_properties_t>
118-
void prefetch(T *ptr, size_t count, Properties properties = {}) {
121+
std::enable_if_t<is_property_list_v<std::decay_t<Properties>>>
122+
prefetch(T *ptr, size_t count, Properties properties = {}) {
119123
detail::prefetch_impl(ptr, count * sizeof(T), properties);
120124
}
121125

122126
template <access::address_space AddressSpace, access::decorated IsDecorated,
123127
typename Properties = empty_properties_t>
124-
std::enable_if_t<detail::check_prefetch_AS<AddressSpace>>
128+
std::enable_if_t<detail::check_prefetch_AS<AddressSpace> &&
129+
is_property_list_v<std::decay_t<Properties>>>
125130
prefetch(multi_ptr<void, AddressSpace, IsDecorated> ptr,
126131
Properties properties = {}) {
127132
detail::prefetch_impl(ptr.get(), 1, properties);
128133
}
129134

130135
template <access::address_space AddressSpace, access::decorated IsDecorated,
131136
typename Properties = empty_properties_t>
132-
std::enable_if_t<detail::check_prefetch_AS<AddressSpace>>
137+
std::enable_if_t<detail::check_prefetch_AS<AddressSpace> &&
138+
is_property_list_v<std::decay_t<Properties>>>
133139
prefetch(multi_ptr<void, AddressSpace, IsDecorated> ptr, size_t bytes,
134140
Properties properties = {}) {
135141
detail::prefetch_impl(ptr.get(), bytes, properties);
@@ -138,7 +144,8 @@ prefetch(multi_ptr<void, AddressSpace, IsDecorated> ptr, size_t bytes,
138144
template <typename T, access::address_space AddressSpace,
139145
access::decorated IsDecorated,
140146
typename Properties = empty_properties_t>
141-
std::enable_if_t<detail::check_prefetch_AS<AddressSpace>>
147+
std::enable_if_t<detail::check_prefetch_AS<AddressSpace> &&
148+
is_property_list_v<std::decay_t<Properties>>>
142149
prefetch(multi_ptr<T, AddressSpace, IsDecorated> ptr,
143150
Properties properties = {}) {
144151
detail::prefetch_impl(ptr.get(), sizeof(T), properties);
@@ -147,7 +154,8 @@ prefetch(multi_ptr<T, AddressSpace, IsDecorated> ptr,
147154
template <typename T, access::address_space AddressSpace,
148155
access::decorated IsDecorated,
149156
typename Properties = empty_properties_t>
150-
std::enable_if_t<detail::check_prefetch_AS<AddressSpace>>
157+
std::enable_if_t<detail::check_prefetch_AS<AddressSpace> &&
158+
is_property_list_v<std::decay_t<Properties>>>
151159
prefetch(multi_ptr<T, AddressSpace, IsDecorated> ptr, size_t count,
152160
Properties properties = {}) {
153161
detail::prefetch_impl(ptr.get(), count * sizeof(T), properties);
@@ -157,7 +165,8 @@ template <typename DataT, int Dimensions, access_mode AccessMode,
157165
access::placeholder IsPlaceholder,
158166
typename Properties = empty_properties_t>
159167
std::enable_if_t<detail::check_prefetch_acc_mode<AccessMode> &&
160-
(Dimensions > 0)>
168+
(Dimensions > 0) &&
169+
is_property_list_v<std::decay_t<Properties>>>
161170
prefetch(
162171
accessor<DataT, Dimensions, AccessMode, target::device, IsPlaceholder> acc,
163172
id<Dimensions> offset, Properties properties = {}) {
@@ -168,33 +177,38 @@ template <typename DataT, int Dimensions, access_mode AccessMode,
168177
access::placeholder IsPlaceholder,
169178
typename Properties = empty_properties_t>
170179
std::enable_if_t<detail::check_prefetch_acc_mode<AccessMode> &&
171-
(Dimensions > 0)>
180+
(Dimensions > 0) &&
181+
is_property_list_v<std::decay_t<Properties>>>
172182
prefetch(
173183
accessor<DataT, Dimensions, AccessMode, target::device, IsPlaceholder> acc,
174184
size_t offset, size_t count, Properties properties = {}) {
175185
detail::prefetch_impl(&acc[offset], count * sizeof(DataT), properties);
176186
}
177187

178188
template <typename Group, typename Properties = empty_properties_t>
179-
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>>
189+
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
190+
is_property_list_v<std::decay_t<Properties>>>
180191
joint_prefetch(Group g, void *ptr, Properties properties = {}) {
181192
detail::joint_prefetch_impl(g, ptr, 1, properties);
182193
}
183194

184195
template <typename Group, typename Properties = empty_properties_t>
185-
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>>
196+
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
197+
is_property_list_v<std::decay_t<Properties>>>
186198
joint_prefetch(Group g, void *ptr, size_t bytes, Properties properties = {}) {
187199
detail::joint_prefetch_impl(g, ptr, bytes, properties);
188200
}
189201

190202
template <typename Group, typename T, typename Properties = empty_properties_t>
191-
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>>
203+
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
204+
is_property_list_v<std::decay_t<Properties>>>
192205
joint_prefetch(Group g, T *ptr, Properties properties = {}) {
193206
detail::joint_prefetch_impl(g, ptr, sizeof(T), properties);
194207
}
195208

196209
template <typename Group, typename T, typename Properties = empty_properties_t>
197-
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>>>
210+
std::enable_if_t<sycl::is_group_v<std::decay_t<Group>> &&
211+
is_property_list_v<std::decay_t<Properties>>>
198212
joint_prefetch(Group g, T *ptr, size_t count, Properties properties = {}) {
199213
detail::joint_prefetch_impl(g, ptr, count * sizeof(T), properties);
200214
}
@@ -203,7 +217,8 @@ template <typename Group, access::address_space AddressSpace,
203217
access::decorated IsDecorated,
204218
typename Properties = empty_properties_t>
205219
std::enable_if_t<detail::check_prefetch_AS<AddressSpace> &&
206-
sycl::is_group_v<std::decay_t<Group>>>
220+
sycl::is_group_v<std::decay_t<Group>> &&
221+
is_property_list_v<std::decay_t<Properties>>>
207222
joint_prefetch(Group g, multi_ptr<void, AddressSpace, IsDecorated> ptr,
208223
Properties properties = {}) {
209224
detail::joint_prefetch_impl(g, ptr.get(), 1, properties);
@@ -213,7 +228,8 @@ template <typename Group, access::address_space AddressSpace,
213228
access::decorated IsDecorated,
214229
typename Properties = empty_properties_t>
215230
std::enable_if_t<detail::check_prefetch_AS<AddressSpace> &&
216-
sycl::is_group_v<std::decay_t<Group>>>
231+
sycl::is_group_v<std::decay_t<Group>> &&
232+
is_property_list_v<std::decay_t<Properties>>>
217233
joint_prefetch(Group g, multi_ptr<void, AddressSpace, IsDecorated> ptr,
218234
size_t bytes, Properties properties = {}) {
219235
detail::joint_prefetch_impl(g, ptr.get(), bytes, properties);
@@ -223,7 +239,8 @@ template <typename Group, typename T, access::address_space AddressSpace,
223239
access::decorated IsDecorated,
224240
typename Properties = empty_properties_t>
225241
std::enable_if_t<detail::check_prefetch_AS<AddressSpace> &&
226-
sycl::is_group_v<std::decay_t<Group>>>
242+
sycl::is_group_v<std::decay_t<Group>> &&
243+
is_property_list_v<std::decay_t<Properties>>>
227244
joint_prefetch(Group g, multi_ptr<T, AddressSpace, IsDecorated> ptr,
228245
Properties properties = {}) {
229246
detail::joint_prefetch_impl(g, ptr.get(), sizeof(T), properties);
@@ -233,7 +250,8 @@ template <typename Group, typename T, access::address_space AddressSpace,
233250
access::decorated IsDecorated,
234251
typename Properties = empty_properties_t>
235252
std::enable_if_t<detail::check_prefetch_AS<AddressSpace> &&
236-
sycl::is_group_v<std::decay_t<Group>>>
253+
sycl::is_group_v<std::decay_t<Group>> &&
254+
is_property_list_v<std::decay_t<Properties>>>
237255
joint_prefetch(Group g, multi_ptr<T, AddressSpace, IsDecorated> ptr,
238256
size_t count, Properties properties = {}) {
239257
detail::joint_prefetch_impl(g, ptr.get(), count * sizeof(T), properties);
@@ -243,7 +261,8 @@ template <typename Group, typename DataT, int Dimensions,
243261
access_mode AccessMode, access::placeholder IsPlaceholder,
244262
typename Properties = empty_properties_t>
245263
std::enable_if_t<detail::check_prefetch_acc_mode<AccessMode> &&
246-
(Dimensions > 0) && sycl::is_group_v<std::decay_t<Group>>>
264+
(Dimensions > 0) && sycl::is_group_v<std::decay_t<Group>> &&
265+
is_property_list_v<std::decay_t<Properties>>>
247266
joint_prefetch(
248267
Group g,
249268
accessor<DataT, Dimensions, AccessMode, target::device, IsPlaceholder> acc,
@@ -255,7 +274,8 @@ template <typename Group, typename DataT, int Dimensions,
255274
access_mode AccessMode, access::placeholder IsPlaceholder,
256275
typename Properties = empty_properties_t>
257276
std::enable_if_t<detail::check_prefetch_acc_mode<AccessMode> &&
258-
(Dimensions > 0) && sycl::is_group_v<std::decay_t<Group>>>
277+
(Dimensions > 0) && sycl::is_group_v<std::decay_t<Group>> &&
278+
is_property_list_v<std::decay_t<Properties>>>
259279
joint_prefetch(
260280
Group g,
261281
accessor<DataT, Dimensions, AccessMode, target::device, IsPlaceholder> acc,
Lines changed: 151 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,151 @@
1+
// REQUIRES: gpu && (level_zero || opencl)
2+
// RUN: %{build} -o %t.out
3+
// RUN: %{run} %t.out
4+
5+
#include <numeric>
6+
#include <sycl/sycl.hpp>
7+
8+
using namespace sycl;
9+
namespace syclex = sycl::ext::oneapi::experimental;
10+
11+
// Total number of elements in every input and result array used by the test.
constexpr size_t N = 128;
12+
// Number of work-items per kernel launch; used as both the global and the
// local range, so each launch consists of a single work-group.
constexpr size_t NumWI = 4;
13+
// Number of consecutive elements each work-item processes (N split evenly
// across the NumWI work-items).
constexpr size_t arrSize = N / NumWI;
14+
15+
// Expands to a literal comma so that a template argument list can be written
// inside a single macro argument without the preprocessor splitting it.
#define COMMA ,
16+
17+
// Tests a syclex::joint_prefetch overload that takes no size argument.
//
// PREFETCH_ARG is pasted verbatim as the pointer argument of joint_prefetch
// and is evaluated inside the kernel, so it may refer to the nd_item `idx`.
// The expansion relies on two names being visible at the point of use:
//   q        - the queue the kernel is submitted to;
//   dataChar - the input pointer, dereferenced both in the kernel and on the
//              host.
//
// Each of the NumWI work-items issues the prefetch and then fills its own
// arrSize-element slice of the result with dataChar[0] squared. The buffer
// goes out of scope before the host checks all N results with assert.
// Loop indices are size_t to match N, arrSize and get_local_linear_id() and
// avoid signed/unsigned comparisons.
#define TEST_PREFETCH_WO_COUNT(PREFETCH_ARG)                                   \
  {                                                                            \
    std::vector<int> res(N);                                                   \
    {                                                                          \
      buffer<int, 1> buf(res.data(), N);                                       \
      q.submit([&](handler &h) {                                               \
        auto acc = buf.get_access<access_mode::write>(h);                      \
        h.parallel_for(                                                        \
            nd_range<1>(range<1>(NumWI), range<1>(NumWI)),                     \
            [=](nd_item<1> idx) {                                              \
              syclex::joint_prefetch(                                          \
                  idx.get_group(), PREFETCH_ARG,                               \
                  syclex::properties{syclex::prefetch_hint_L1});               \
              const size_t sliceBegin = idx.get_local_linear_id() * arrSize;   \
              for (size_t i = sliceBegin; i < sliceBegin + arrSize; i++)       \
                acc[i] = dataChar[0] * dataChar[0];                            \
            });                                                                \
      });                                                                      \
      q.wait();                                                                \
    }                                                                          \
    for (size_t i = 0; i < N; i++)                                             \
      assert(res[i] == dataChar[0] * dataChar[0]);                             \
  }
40+
41+
// Tests a syclex::joint_prefetch overload that takes a size argument.
//
// PREFETCH_ARG is pasted verbatim as the pointer argument of joint_prefetch,
// followed by arrSize as the size argument. It is evaluated inside the
// kernel, so it may refer to the nd_item `idx`.
// The expansion relies on two names being visible at the point of use:
//   q        - the queue the kernel is submitted to;
//   dataChar - the input pointer, dereferenced both in the kernel and on the
//              host.
//
// Each of the NumWI work-items issues the prefetch and then writes
// dataChar[i] squared for every index i of its own arrSize-element slice.
// The buffer goes out of scope before the host checks all N results with
// assert. Loop indices are size_t to match N, arrSize and
// get_local_linear_id() and avoid signed/unsigned comparisons.
#define TEST_PREFETCH_W_COUNT(PREFETCH_ARG)                                    \
  {                                                                            \
    std::vector<int> res(N);                                                   \
    {                                                                          \
      buffer<int, 1> buf(res.data(), N);                                       \
      q.submit([&](handler &h) {                                               \
        auto acc = buf.get_access<access_mode::write>(h);                      \
        h.parallel_for(                                                        \
            nd_range<1>(range<1>(NumWI), range<1>(NumWI)),                     \
            [=](nd_item<1> idx) {                                              \
              syclex::joint_prefetch(                                          \
                  idx.get_group(), PREFETCH_ARG, arrSize,                      \
                  syclex::properties{syclex::prefetch_hint_L1});               \
              const size_t sliceBegin = idx.get_local_linear_id() * arrSize;   \
              for (size_t i = sliceBegin; i < sliceBegin + arrSize; i++)       \
                acc[i] = dataChar[i] * dataChar[i];                            \
            });                                                                \
      });                                                                      \
      q.wait();                                                                \
    }                                                                          \
    for (size_t i = 0; i < N; i++)                                             \
      assert(res[i] == dataChar[i] * dataChar[i]);                             \
  }
64+
65+
// Exercises the accessor-based syclex::joint_prefetch calls.
//
// q               - queue the kernel is submitted to.
// data            - input values; must hold at least N elements because a
//                   buffer of N ints is created over data.data().
// prefetchOneElem - when true, the group prefetches at offset id(0) and every
//                   result is data[0] squared; when false, each work-item
//                   passes its own slice start plus arrSize as the count and
//                   result i is data[i] squared.
//
// Both buffers go out of scope before the host checks all N results with
// assert. Indices are size_t to match N, arrSize and get_local_linear_id()
// and avoid signed/unsigned comparisons.
void testPrefetchWithAcc(queue q, const std::vector<int> &data,
                         bool prefetchOneElem = true) {
  std::vector<int> res(N);
  {
    buffer<int, 1> bufRes(res.data(), N);
    buffer<int, 1> bufData(data.data(), N);
    q.submit([&](handler &h) {
      auto accRes = bufRes.get_access<access_mode::write>(h);
      auto accData = bufData.get_access<access_mode::read>(h);
      h.parallel_for(
          nd_range<1>(range<1>(NumWI), range<1>(NumWI)), [=](nd_item<1> idx) {
            // First element of the slice owned by this work-item.
            const size_t sliceBegin = idx.get_local_linear_id() * arrSize;
            if (prefetchOneElem)
              syclex::joint_prefetch(
                  idx.get_group(), accData, id(0),
                  syclex::properties{syclex::prefetch_hint_L1});
            else
              syclex::joint_prefetch(
                  idx.get_group(), accData, id(sliceBegin), arrSize,
                  syclex::properties{syclex::prefetch_hint_L1});
            for (size_t i = sliceBegin; i < sliceBegin + arrSize; i++)
              accRes[i] = prefetchOneElem ? accData[0] * accData[0]
                                          : accData[i] * accData[i];
          });
    });
    q.wait();
  }
  for (size_t i = 0; i < N; i++)
    assert(res[i] == (prefetchOneElem ? data[0] * data[0] : data[i] * data[i]));
}
96+
97+
int main() {
98+
queue q;
99+
100+
if (q.get_device().has(aspect::usm_shared_allocations)) {
101+
auto *dataChar = malloc_shared<char>(N, q);
102+
auto *dataVoid = reinterpret_cast<void *>(dataChar);
103+
auto mPtrChar = address_space_cast<access::address_space::global_space,
104+
access::decorated::yes>(dataChar);
105+
auto mPtrVoid = address_space_cast<access::address_space::global_space,
106+
access::decorated::yes>(dataVoid);
107+
108+
std::iota(dataChar, dataChar + N, 0);
109+
110+
// void prefetch(void* ptr, Properties properties = {});
111+
TEST_PREFETCH_WO_COUNT(dataVoid)
112+
// void prefetch(T* ptr, Properties properties = {});
113+
TEST_PREFETCH_WO_COUNT(dataChar)
114+
// void prefetch(void* ptr, size_t bytes, Properties properties = {});
115+
TEST_PREFETCH_W_COUNT(reinterpret_cast<void *>(
116+
&dataChar[idx.get_local_linear_id() * arrSize]))
117+
// void prefetch(T* ptr, size_t count, Properties properties = {});
118+
TEST_PREFETCH_W_COUNT(&dataChar[idx.get_local_linear_id() * arrSize])
119+
// void prefetch(multi_ptr<void, AddressSpace, IsDecorated> ptr,
120+
// Properties properties = {});
121+
TEST_PREFETCH_WO_COUNT(mPtrVoid)
122+
// void prefetch(multi_ptr<T, AddressSpace, IsDecorated> ptr,
123+
// Properties properties = {});
124+
TEST_PREFETCH_WO_COUNT(mPtrChar)
125+
// void prefetch(multi_ptr<void, AddressSpace, IsDecorated> ptr,
126+
// size_t bytes, Properties properties = {});
127+
TEST_PREFETCH_W_COUNT(
128+
address_space_cast<access::address_space::global_space COMMA
129+
access::decorated::yes>(reinterpret_cast<void *>(
130+
&dataChar[idx.get_local_linear_id() * arrSize])))
131+
// void prefetch(multi_ptr<T, AddressSpace, IsDecorated> ptr, size_t count,
132+
// Properties properties = {});
133+
TEST_PREFETCH_W_COUNT(
134+
address_space_cast<
135+
access::address_space::global_space COMMA access::decorated::yes>(
136+
&dataChar[idx.get_local_linear_id() * arrSize]))
137+
}
138+
{
139+
std::vector<int> data(N);
140+
std::iota(data.begin(), data.end(), 0);
141+
142+
// void prefetch(accessor<DataT, Dimensions, AccessMode, target::device,
143+
// IsPlaceholder> acc, id<Dimensions> offset,
144+
// Properties properties = {});
145+
testPrefetchWithAcc(q, data);
146+
// void prefetch(accessor<DataT, Dimensions, AccessMode, target::device,
147+
// IsPlaceholder> acc, id<Dimensions> offset, size_t count,
148+
// Properties properties = {});
149+
testPrefetchWithAcc(q, data, false);
150+
}
151+
}

0 commit comments

Comments
 (0)