Skip to content

Commit 73d59ce

Browse files
authored
[SYCL] Add support for SYCL 2020 in class group (#5447)
Implement missing methods for class group according to SYCL 2020 4.9.1.7.
1 parent 7dc214a commit 73d59ce

File tree

3 files changed

+178
-23
lines changed

3 files changed

+178
-23
lines changed

sycl/include/CL/sycl/group.hpp

Lines changed: 130 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -104,53 +104,67 @@ template <int Dimensions = 1> class group {
104104

105105
group() = delete;
106106

107+
__SYCL2020_DEPRECATED("use sycl::group::get_group_id() instead")
107108
id<Dimensions> get_id() const { return index; }
108109

110+
__SYCL2020_DEPRECATED("use sycl::group::get_group_id() instead")
109111
size_t get_id(int dimension) const { return index[dimension]; }
110112

113+
id<Dimensions> get_group_id() const { return index; }
114+
115+
size_t get_group_id(int dimension) const { return index[dimension]; }
116+
111117
range<Dimensions> get_global_range() const { return globalRange; }
112118

113119
size_t get_global_range(int dimension) const {
114120
return globalRange[dimension];
115121
}
116122

123+
id<Dimensions> get_local_id() const {
124+
#ifdef __SYCL_DEVICE_ONLY__
125+
return __spirv::initLocalInvocationId<Dimensions, id<Dimensions>>();
126+
#else
127+
throw runtime_error("get_local_id() is not implemented on host device",
128+
PI_INVALID_DEVICE);
129+
// Implementing get_local_id() on host device requires ABI breaking change.
130+
// It requires extending class group with local item which represents
131+
// local_id. Currently this local id is only used in nd_item and group
132+
// cannot access it.
133+
#endif
134+
}
135+
136+
size_t get_local_linear_id() const {
137+
return get_local_linear_id_impl<Dimensions>();
138+
}
139+
117140
range<Dimensions> get_local_range() const { return localRange; }
118141

119142
size_t get_local_range(int dimension) const { return localRange[dimension]; }
120143

144+
size_t get_local_linear_range() const {
145+
return get_local_linear_range_impl();
146+
}
147+
121148
range<Dimensions> get_group_range() const { return groupRange; }
122149

123150
size_t get_group_range(int dimension) const {
124151
return get_group_range()[dimension];
125152
}
126153

154+
size_t get_group_linear_range() const {
155+
return get_group_linear_range_impl();
156+
}
157+
158+
range<Dimensions> get_max_local_range() const { return get_local_range(); }
159+
127160
size_t operator[](int dimension) const { return index[dimension]; }
128161

129-
template <int dims = Dimensions>
130-
typename detail::enable_if_t<(dims == 1), size_t> get_linear_id() const {
131-
return index[0];
132-
}
162+
__SYCL2020_DEPRECATED("use sycl::group::get_group_linear_id() instead")
163+
size_t get_linear_id() const { return get_group_linear_id(); }
133164

134-
template <int dims = Dimensions>
135-
typename detail::enable_if_t<(dims == 2), size_t> get_linear_id() const {
136-
return index[0] * groupRange[1] + index[1];
137-
}
165+
size_t get_group_linear_id() const { return get_group_linear_id_impl(); }
138166

139-
// SYCL specification 1.2.1rev5, section 4.7.6.5 "Buffer accessor":
140-
// Whenever a multi-dimensional index is passed to a SYCL accessor the
141-
// linear index is calculated based on the index {id1, id2, id3} provided
142-
// and the range of the SYCL accessor {r1, r2, r3} according to row-major
143-
// ordering as follows:
144-
// id3 + (id2 · r3) + (id1 · r3 · r2) (4.3)
145-
// section 4.8.1.8 "group class":
146-
// size_t get_linear_id()const
147-
// Get a linearized version of the work-group id. Calculating a linear
148-
// work-group id from a multi-dimensional index follows the equation 4.3.
149-
template <int dims = Dimensions>
150-
typename detail::enable_if_t<(dims == 3), size_t> get_linear_id() const {
151-
return (index[0] * groupRange[1] * groupRange[2]) +
152-
(index[1] * groupRange[2]) + index[2];
153-
}
167+
bool leader() const { return (get_local_linear_id() == 0); }
154168

155169
template <typename WorkItemFunctionT>
156170
void parallel_for_work_item(WorkItemFunctionT Func) const {
@@ -397,6 +411,99 @@ template <int Dimensions = 1> class group {
397411
range<Dimensions> groupRange;
398412
id<Dimensions> index;
399413

414+
template <int dims = Dimensions>
415+
typename detail::enable_if_t<(dims == 1), size_t>
416+
get_local_linear_id_impl() const {
417+
id<Dimensions> localId = get_local_id();
418+
return localId[0];
419+
}
420+
421+
template <int dims = Dimensions>
422+
typename detail::enable_if_t<(dims == 2), size_t>
423+
get_local_linear_id_impl() const {
424+
id<Dimensions> localId = get_local_id();
425+
return localId[0] * groupRange[1] + localId[1];
426+
}
427+
428+
template <int dims = Dimensions>
429+
typename detail::enable_if_t<(dims == 3), size_t>
430+
get_local_linear_id_impl() const {
431+
id<Dimensions> localId = get_local_id();
432+
return (localId[0] * groupRange[1] * groupRange[2]) +
433+
(localId[1] * groupRange[2]) + localId[2];
434+
}
435+
436+
template <int dims = Dimensions>
437+
typename detail::enable_if_t<(dims == 1), size_t>
438+
get_local_linear_range_impl() const {
439+
auto localRange = get_local_range();
440+
return localRange[0];
441+
}
442+
443+
template <int dims = Dimensions>
444+
typename detail::enable_if_t<(dims == 2), size_t>
445+
get_local_linear_range_impl() const {
446+
auto localRange = get_local_range();
447+
return localRange[0] * localRange[1];
448+
}
449+
450+
template <int dims = Dimensions>
451+
typename detail::enable_if_t<(dims == 3), size_t>
452+
get_local_linear_range_impl() const {
453+
auto localRange = get_local_range();
454+
return localRange[0] * localRange[1] * localRange[2];
455+
}
456+
457+
template <int dims = Dimensions>
458+
typename detail::enable_if_t<(dims == 1), size_t>
459+
get_group_linear_range_impl() const {
460+
auto groupRange = get_group_range();
461+
return groupRange[0];
462+
}
463+
464+
template <int dims = Dimensions>
465+
typename detail::enable_if_t<(dims == 2), size_t>
466+
get_group_linear_range_impl() const {
467+
auto groupRange = get_group_range();
468+
return groupRange[0] * groupRange[1];
469+
}
470+
471+
template <int dims = Dimensions>
472+
typename detail::enable_if_t<(dims == 3), size_t>
473+
get_group_linear_range_impl() const {
474+
auto groupRange = get_group_range();
475+
return groupRange[0] * groupRange[1] * groupRange[2];
476+
}
477+
478+
template <int dims = Dimensions>
479+
typename detail::enable_if_t<(dims == 1), size_t>
480+
get_group_linear_id_impl() const {
481+
return index[0];
482+
}
483+
484+
template <int dims = Dimensions>
485+
typename detail::enable_if_t<(dims == 2), size_t>
486+
get_group_linear_id_impl() const {
487+
return index[0] * groupRange[1] + index[1];
488+
}
489+
490+
// SYCL specification 1.2.1rev5, section 4.7.6.5 "Buffer accessor":
491+
// Whenever a multi-dimensional index is passed to a SYCL accessor the
492+
// linear index is calculated based on the index {id1, id2, id3} provided
493+
// and the range of the SYCL accessor {r1, r2, r3} according to row-major
494+
// ordering as follows:
495+
// id3 + (id2 · r3) + (id1 · r3 · r2) (4.3)
496+
// section 4.8.1.8 "group class":
497+
// size_t get_linear_id()const
498+
// Get a linearized version of the work-group id. Calculating a linear
499+
// work-group id from a multi-dimensional index follows the equation 4.3.
500+
template <int dims = Dimensions>
501+
typename detail::enable_if_t<(dims == 3), size_t>
502+
get_group_linear_id_impl() const {
503+
return (index[0] * groupRange[1] * groupRange[2]) +
504+
(index[1] * groupRange[2]) + index[2];
505+
}
506+
400507
void waitForHelper() const {}
401508

402509
void waitForHelper(device_event Event) const {

sycl/test/basic_tests/group.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,19 @@ int main() {
2929
assert(one_dim.get_group_range(0) == 2);
3030
assert(one_dim[0] == 1);
3131
assert(one_dim.get_linear_id() == 1);
32+
assert(one_dim.get_group_linear_id() == 1);
33+
34+
try {
35+
one_dim.get_local_id();
36+
assert(0); // get_local_id() is not implemented on host device
37+
} catch (cl::sycl::runtime_error) {
38+
}
39+
40+
try {
41+
one_dim.get_local_linear_id();
42+
assert(0); // get_local_id() is not implemented on host device
43+
} catch (cl::sycl::runtime_error) {
44+
}
3245

3346
// two dimension group
3447
cl::sycl::group<2> two_dim = Builder::createGroup<2>({8, 4}, {4, 2}, {1, 1});
@@ -47,6 +60,19 @@ int main() {
4760
assert(two_dim[0] == 1);
4861
assert(two_dim[1] == 1);
4962
assert(two_dim.get_linear_id() == 3);
63+
assert(two_dim.get_group_linear_id() == 3);
64+
65+
try {
66+
two_dim.get_local_id();
67+
assert(0); // get_local_id() is not implemented on host device
68+
} catch (cl::sycl::runtime_error) {
69+
}
70+
71+
try {
72+
two_dim.get_local_linear_id();
73+
assert(0); // get_local_id() is not implemented on host device
74+
} catch (cl::sycl::runtime_error) {
75+
}
5076

5177
// three dimension group
5278
cl::sycl::group<3> three_dim =
@@ -71,4 +97,17 @@ int main() {
7197
assert(three_dim[1] == 1);
7298
assert(three_dim[2] == 1);
7399
assert(three_dim.get_linear_id() == 7);
100+
assert(three_dim.get_group_linear_id() == 7);
101+
102+
try {
103+
three_dim.get_local_id();
104+
assert(0); // get_local_id() is not implemented on host device
105+
} catch (cl::sycl::runtime_error) {
106+
}
107+
108+
try {
109+
three_dim.get_local_linear_id();
110+
assert(0); // get_local_id() is not implemented on host device
111+
} catch (cl::sycl::runtime_error) {
112+
}
74113
}

sycl/test/warnings/sycl_2020_deprecations.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,5 +179,14 @@ int main() {
179179
// expected-warning@+1 {{'atomic<int, sycl::access::address_space::global_space>' is deprecated: sycl::atomic is deprecated since SYCL 2020}}
180180
cl::sycl::atomic<int> b(a);
181181

182+
cl::sycl::group<1> group =
183+
cl::sycl::detail::Builder::createGroup<1>({8}, {4}, {1});
184+
// expected-warning@+1{{'get_id' is deprecated: use sycl::group::get_group_id() instead}}
185+
group.get_id();
186+
// expected-warning@+1{{'get_id' is deprecated: use sycl::group::get_group_id() instead}}
187+
group.get_id(1);
188+
// expected-warning@+1{{'get_linear_id' is deprecated: use sycl::group::get_group_linear_id() instead}}
189+
group.get_linear_id();
190+
182191
return 0;
183192
}

0 commit comments

Comments
 (0)