Skip to content

Commit 8bfa107

Browse files
authored
[SYCL] Add prototype of group algorithms (#1236)
Exposes group collectives: - all_of - any_of - none_of - reduce - exclusive_scan - inclusive_scan This prototype does not support the host device. Co-Authored-By: Roland Schulz <[email protected]> Co-Authored-By: Alexey Sachkov <[email protected]> Signed-off-by: John Pennycook <[email protected]>
1 parent d3b00d0 commit 8bfa107

File tree

16 files changed

+1583
-140
lines changed

16 files changed

+1583
-140
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ extern SYCL_EXTERNAL bool __spirv_GroupAny(__spv::Scope Execution,
191191
template <typename dataT>
192192
extern SYCL_EXTERNAL dataT __spirv_GroupBroadcast(__spv::Scope Execution,
193193
dataT Value,
194-
uint32_t LocalId) noexcept;
194+
size_t LocalId) noexcept;
195195

196196
template <typename dataT>
197197
extern SYCL_EXTERNAL dataT

sycl/include/CL/sycl.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,22 @@
2323
#include <CL/sycl/image.hpp>
2424
#include <CL/sycl/intel/builtins.hpp>
2525
#include <CL/sycl/intel/function_pointer.hpp>
26+
#include <CL/sycl/intel/group_algorithm.hpp>
2627
#include <CL/sycl/intel/sub_group.hpp>
2728
#include <CL/sycl/item.hpp>
2829
#include <CL/sycl/kernel.hpp>
2930
#include <CL/sycl/multi_ptr.hpp>
3031
#include <CL/sycl/nd_item.hpp>
3132
#include <CL/sycl/nd_range.hpp>
33+
#include <CL/sycl/ordered_queue.hpp>
3234
#include <CL/sycl/pipes.hpp>
3335
#include <CL/sycl/platform.hpp>
3436
#include <CL/sycl/pointers.hpp>
3537
#include <CL/sycl/program.hpp>
3638
#include <CL/sycl/queue.hpp>
37-
#include <CL/sycl/ordered_queue.hpp>
3839
#include <CL/sycl/range.hpp>
3940
#include <CL/sycl/sampler.hpp>
4041
#include <CL/sycl/stream.hpp>
4142
#include <CL/sycl/types.hpp>
4243
#include <CL/sycl/usm.hpp>
4344
#include <CL/sycl/version.hpp>
44-

sycl/include/CL/sycl/detail/spirv.hpp

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===-- spirv.hpp - Helpers to generate SPIR-V instructions ----*- C++ -*--===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#pragma once
10+
#include <CL/__spirv/spirv_ops.hpp>
11+
#include <CL/__spirv/spirv_types.hpp>
12+
#include <CL/__spirv/spirv_vars.hpp>
13+
#include <CL/sycl/detail/generic_type_traits.hpp>
14+
#include <CL/sycl/detail/type_traits.hpp>
15+
16+
#ifdef __SYCL_DEVICE_ONLY__
17+
__SYCL_INLINE_NAMESPACE(cl) {
18+
namespace sycl {
19+
namespace detail {
20+
namespace spirv {
21+
22+
// Broadcast with scalar local index
23+
template <__spv::Scope S, typename T, typename IdT>
24+
detail::enable_if_t<std::is_integral<IdT>::value, T>
25+
GroupBroadcast(T x, IdT local_id) {
26+
using OCLT = detail::ConvertToOpenCLType_t<T>;
27+
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
28+
OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
29+
OCLIdT ocl_id = detail::convertDataToType<IdT, OCLIdT>(local_id);
30+
return __spirv_GroupBroadcast(S, ocl_x, ocl_id);
31+
}
32+
33+
// Broadcast with vector local index
34+
template <__spv::Scope S, typename T, int Dimensions>
35+
T GroupBroadcast(T x, id<Dimensions> local_id) {
36+
if (Dimensions == 1) {
37+
return GroupBroadcast<S>(x, local_id[0]);
38+
}
39+
using IdT = vec<size_t, Dimensions>;
40+
using OCLT = detail::ConvertToOpenCLType_t<T>;
41+
using OCLIdT = detail::ConvertToOpenCLType_t<IdT>;
42+
IdT vec_id;
43+
for (int i = 0; i < Dimensions; ++i) {
44+
vec_id[i] = local_id[Dimensions - i - 1];
45+
}
46+
OCLT ocl_x = detail::convertDataToType<T, OCLT>(x);
47+
OCLIdT ocl_id = detail::convertDataToType<IdT, OCLIdT>(vec_id);
48+
return __spirv_GroupBroadcast(S, ocl_x, ocl_id);
49+
}
50+
51+
} // namespace spirv
52+
} // namespace detail
53+
} // namespace sycl
54+
} // __SYCL_INLINE_NAMESPACE(cl)
55+
#endif // __SYCL_DEVICE_ONLY__

sycl/include/CL/sycl/detail/type_traits.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,14 @@ template <typename T>
195195
struct is_arithmetic
196196
: bool_constant<is_integral<T>::value || is_floating_point<T>::value> {};
197197

198+
template <typename T>
199+
struct is_scalar_arithmetic
200+
: bool_constant<!is_vec<T>::value && is_arithmetic<T>::value> {};
201+
202+
template <typename T>
203+
struct is_vector_arithmetic
204+
: bool_constant<is_vec<T>::value && is_arithmetic<T>::value> {};
205+
198206
// is_pointer
199207
template <typename T> struct is_pointer_impl : std::false_type {};
200208

sycl/include/CL/sycl/group.hpp

Lines changed: 73 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -81,38 +81,45 @@ template <typename T, int Dimensions = 1> class private_memory {
8181
#endif // #ifdef __SYCL_DEVICE_ONLY__
8282
};
8383

84-
template <int dimensions = 1> class group {
84+
template <int Dimensions = 1> class group {
8585
public:
86+
#ifndef __DISABLE_SYCL_INTEL_GROUP_ALGORITHMS__
87+
using id_type = id<Dimensions>;
88+
using range_type = range<Dimensions>;
89+
using linear_id_type = size_t;
90+
static constexpr int dimensions = Dimensions;
91+
#endif // __DISABLE_SYCL_INTEL_GROUP_ALGORITHMS__
92+
8693
group() = delete;
8794

88-
id<dimensions> get_id() const { return index; }
95+
id<Dimensions> get_id() const { return index; }
8996

9097
size_t get_id(int dimension) const { return index[dimension]; }
9198

92-
range<dimensions> get_global_range() const { return globalRange; }
99+
range<Dimensions> get_global_range() const { return globalRange; }
93100

94101
size_t get_global_range(int dimension) const {
95102
return globalRange[dimension];
96103
}
97104

98-
range<dimensions> get_local_range() const { return localRange; }
105+
range<Dimensions> get_local_range() const { return localRange; }
99106

100107
size_t get_local_range(int dimension) const { return localRange[dimension]; }
101108

102-
range<dimensions> get_group_range() const { return groupRange; }
109+
range<Dimensions> get_group_range() const { return groupRange; }
103110

104111
size_t get_group_range(int dimension) const {
105112
return get_group_range()[dimension];
106113
}
107114

108115
size_t operator[](int dimension) const { return index[dimension]; }
109116

110-
template <int dims = dimensions>
117+
template <int dims = Dimensions>
111118
typename std::enable_if<(dims == 1), size_t>::type get_linear_id() const {
112119
return index[0];
113120
}
114121

115-
template <int dims = dimensions>
122+
template <int dims = Dimensions>
116123
typename std::enable_if<(dims == 2), size_t>::type get_linear_id() const {
117124
return index[0] * groupRange[1] + index[1];
118125
}
@@ -127,7 +134,7 @@ template <int dimensions = 1> class group {
127134
// size_t get_linear_id()const
128135
// Get a linearized version of the work-group id. Calculating a linear
129136
// work-group id from a multi-dimensional index follows the equation 4.3.
130-
template <int dims = dimensions>
137+
template <int dims = Dimensions>
131138
typename std::enable_if<(dims == 3), size_t>::type get_linear_id() const {
132139
return (index[0] * groupRange[1] * groupRange[2]) +
133140
(index[1] * groupRange[2]) + index[2];
@@ -139,41 +146,41 @@ template <int dimensions = 1> class group {
139146
// compilers are expected to optimize when possible
140147
detail::workGroupBarrier();
141148
#ifdef __SYCL_DEVICE_ONLY__
142-
range<dimensions> GlobalSize{
143-
__spirv::initGlobalSize<dimensions, range<dimensions>>()};
144-
range<dimensions> LocalSize{
145-
__spirv::initWorkgroupSize<dimensions, range<dimensions>>()};
146-
id<dimensions> GlobalId{
147-
__spirv::initGlobalInvocationId<dimensions, id<dimensions>>()};
148-
id<dimensions> LocalId{
149-
__spirv::initLocalInvocationId<dimensions, id<dimensions>>()};
149+
range<Dimensions> GlobalSize{
150+
__spirv::initGlobalSize<Dimensions, range<Dimensions>>()};
151+
range<Dimensions> LocalSize{
152+
__spirv::initWorkgroupSize<Dimensions, range<Dimensions>>()};
153+
id<Dimensions> GlobalId{
154+
__spirv::initGlobalInvocationId<Dimensions, id<Dimensions>>()};
155+
id<Dimensions> LocalId{
156+
__spirv::initLocalInvocationId<Dimensions, id<Dimensions>>()};
150157

151158
// no 'iterate' in the device code variant, because
152159
// (1) this code is already invoked by each work item as a part of the
153160
// enclosing parallel_for_work_group kernel
154161
// (2) the range this pfwi iterates over matches work group size exactly
155-
item<dimensions, false> GlobalItem =
156-
detail::Builder::createItem<dimensions, false>(GlobalSize, GlobalId);
157-
item<dimensions, false> LocalItem =
158-
detail::Builder::createItem<dimensions, false>(LocalSize, LocalId);
159-
h_item<dimensions> HItem =
160-
detail::Builder::createHItem<dimensions>(GlobalItem, LocalItem);
162+
item<Dimensions, false> GlobalItem =
163+
detail::Builder::createItem<Dimensions, false>(GlobalSize, GlobalId);
164+
item<Dimensions, false> LocalItem =
165+
detail::Builder::createItem<Dimensions, false>(LocalSize, LocalId);
166+
h_item<Dimensions> HItem =
167+
detail::Builder::createHItem<Dimensions>(GlobalItem, LocalItem);
161168

162169
Func(HItem);
163170
#else
164-
id<dimensions> GroupStartID = index * localRange;
171+
id<Dimensions> GroupStartID = index * localRange;
165172

166173
// ... host variant needs explicit 'iterate' because it is serial
167-
detail::NDLoop<dimensions>::iterate(
168-
localRange, [&](const id<dimensions> &LocalID) {
169-
item<dimensions, false> GlobalItem =
170-
detail::Builder::createItem<dimensions, false>(
174+
detail::NDLoop<Dimensions>::iterate(
175+
localRange, [&](const id<Dimensions> &LocalID) {
176+
item<Dimensions, false> GlobalItem =
177+
detail::Builder::createItem<Dimensions, false>(
171178
globalRange, GroupStartID + LocalID);
172-
item<dimensions, false> LocalItem =
173-
detail::Builder::createItem<dimensions, false>(localRange,
179+
item<Dimensions, false> LocalItem =
180+
detail::Builder::createItem<Dimensions, false>(localRange,
174181
LocalID);
175-
h_item<dimensions> HItem =
176-
detail::Builder::createHItem<dimensions>(GlobalItem, LocalItem);
182+
h_item<Dimensions> HItem =
183+
detail::Builder::createHItem<Dimensions>(GlobalItem, LocalItem);
177184
Func(HItem);
178185
});
179186
#endif // __SYCL_DEVICE_ONLY__
@@ -185,52 +192,52 @@ template <int dimensions = 1> class group {
185192
}
186193

187194
template <typename WorkItemFunctionT>
188-
void parallel_for_work_item(range<dimensions> flexibleRange,
195+
void parallel_for_work_item(range<Dimensions> flexibleRange,
189196
WorkItemFunctionT Func) const {
190197
detail::workGroupBarrier();
191198
#ifdef __SYCL_DEVICE_ONLY__
192-
range<dimensions> GlobalSize{
193-
__spirv::initGlobalSize<dimensions, range<dimensions>>()};
194-
range<dimensions> LocalSize{
195-
__spirv::initWorkgroupSize<dimensions, range<dimensions>>()};
196-
id<dimensions> GlobalId{
197-
__spirv::initGlobalInvocationId<dimensions, id<dimensions>>()};
198-
id<dimensions> LocalId{
199-
__spirv::initLocalInvocationId<dimensions, id<dimensions>>()};
200-
201-
item<dimensions, false> GlobalItem =
202-
detail::Builder::createItem<dimensions, false>(GlobalSize, GlobalId);
203-
item<dimensions, false> LocalItem =
204-
detail::Builder::createItem<dimensions, false>(LocalSize, LocalId);
205-
h_item<dimensions> HItem = detail::Builder::createHItem<dimensions>(
199+
range<Dimensions> GlobalSize{
200+
__spirv::initGlobalSize<Dimensions, range<Dimensions>>()};
201+
range<Dimensions> LocalSize{
202+
__spirv::initWorkgroupSize<Dimensions, range<Dimensions>>()};
203+
id<Dimensions> GlobalId{
204+
__spirv::initGlobalInvocationId<Dimensions, id<Dimensions>>()};
205+
id<Dimensions> LocalId{
206+
__spirv::initLocalInvocationId<Dimensions, id<Dimensions>>()};
207+
208+
item<Dimensions, false> GlobalItem =
209+
detail::Builder::createItem<Dimensions, false>(GlobalSize, GlobalId);
210+
item<Dimensions, false> LocalItem =
211+
detail::Builder::createItem<Dimensions, false>(LocalSize, LocalId);
212+
h_item<Dimensions> HItem = detail::Builder::createHItem<Dimensions>(
206213
GlobalItem, LocalItem, flexibleRange);
207214

208215
// iterate over flexible range with work group size stride; each item
209216
// performs flexibleRange/LocalSize iterations (if the former is divisible
210217
// by the latter)
211-
detail::NDLoop<dimensions>::iterate(
218+
detail::NDLoop<Dimensions>::iterate(
212219
LocalId, LocalSize, flexibleRange,
213-
[&](const id<dimensions> &LogicalLocalID) {
220+
[&](const id<Dimensions> &LogicalLocalID) {
214221
HItem.setLogicalLocalID(LogicalLocalID);
215222
Func(HItem);
216223
});
217224
#else
218-
id<dimensions> GroupStartID = index * localRange;
225+
id<Dimensions> GroupStartID = index * localRange;
219226

220-
detail::NDLoop<dimensions>::iterate(
221-
localRange, [&](const id<dimensions> &LocalID) {
222-
item<dimensions, false> GlobalItem =
223-
detail::Builder::createItem<dimensions, false>(
227+
detail::NDLoop<Dimensions>::iterate(
228+
localRange, [&](const id<Dimensions> &LocalID) {
229+
item<Dimensions, false> GlobalItem =
230+
detail::Builder::createItem<Dimensions, false>(
224231
globalRange, GroupStartID + LocalID);
225-
item<dimensions, false> LocalItem =
226-
detail::Builder::createItem<dimensions, false>(localRange,
232+
item<Dimensions, false> LocalItem =
233+
detail::Builder::createItem<Dimensions, false>(localRange,
227234
LocalID);
228-
h_item<dimensions> HItem = detail::Builder::createHItem<dimensions>(
235+
h_item<Dimensions> HItem = detail::Builder::createHItem<Dimensions>(
229236
GlobalItem, LocalItem, flexibleRange);
230237

231-
detail::NDLoop<dimensions>::iterate(
238+
detail::NDLoop<Dimensions>::iterate(
232239
LocalID, localRange, flexibleRange,
233-
[&](const id<dimensions> &LogicalLocalID) {
240+
[&](const id<Dimensions> &LogicalLocalID) {
234241
HItem.setLogicalLocalID(LogicalLocalID);
235242
Func(HItem);
236243
});
@@ -311,23 +318,23 @@ template <int dimensions = 1> class group {
311318
waitForHelper(Events...);
312319
}
313320

314-
bool operator==(const group<dimensions> &rhs) const {
321+
bool operator==(const group<Dimensions> &rhs) const {
315322
bool Result = (rhs.globalRange == globalRange) &&
316323
(rhs.localRange == localRange) && (rhs.index == index);
317324
__SYCL_ASSERT(rhs.groupRange == groupRange &&
318325
"inconsistent group class fields");
319326
return Result;
320327
}
321328

322-
bool operator!=(const group<dimensions> &rhs) const {
329+
bool operator!=(const group<Dimensions> &rhs) const {
323330
return !((*this) == rhs);
324331
}
325332

326333
private:
327-
range<dimensions> globalRange;
328-
range<dimensions> localRange;
329-
range<dimensions> groupRange;
330-
id<dimensions> index;
334+
range<Dimensions> globalRange;
335+
range<Dimensions> localRange;
336+
range<Dimensions> groupRange;
337+
id<Dimensions> index;
331338

332339
void waitForHelper() const {}
333340

@@ -343,8 +350,8 @@ template <int dimensions = 1> class group {
343350

344351
protected:
345352
friend class detail::Builder;
346-
group(const range<dimensions> &G, const range<dimensions> &L,
347-
const range<dimensions> GroupRange, const id<dimensions> &I)
353+
group(const range<Dimensions> &G, const range<Dimensions> &L,
354+
const range<Dimensions> GroupRange, const id<Dimensions> &I)
348355
: globalRange(G), localRange(L), groupRange(GroupRange), index(I) {
349356
// Make sure local range divides global without remainder:
350357
__SYCL_ASSERT(((G % L).size() == 0) &&

0 commit comments

Comments
 (0)