@@ -81,38 +81,45 @@ template <typename T, int Dimensions = 1> class private_memory {
81
81
#endif // #ifdef __SYCL_DEVICE_ONLY__
82
82
};
83
83
84
- template <int dimensions = 1 > class group {
84
+ template <int Dimensions = 1 > class group {
85
85
public:
86
+ #ifndef __DISABLE_SYCL_INTEL_GROUP_ALGORITHMS__
87
+ using id_type = id<Dimensions>;
88
+ using range_type = range<Dimensions>;
89
+ using linear_id_type = size_t ;
90
+ static constexpr int dimensions = Dimensions;
91
+ #endif // __DISABLE_SYCL_INTEL_GROUP_ALGORITHMS__
92
+
86
93
group () = delete ;
87
94
88
- id<dimensions > get_id () const { return index; }
95
+ id<Dimensions > get_id () const { return index; }
89
96
90
97
size_t get_id (int dimension) const { return index[dimension]; }
91
98
92
- range<dimensions > get_global_range () const { return globalRange; }
99
+ range<Dimensions > get_global_range () const { return globalRange; }
93
100
94
101
size_t get_global_range (int dimension) const {
95
102
return globalRange[dimension];
96
103
}
97
104
98
- range<dimensions > get_local_range () const { return localRange; }
105
+ range<Dimensions > get_local_range () const { return localRange; }
99
106
100
107
size_t get_local_range (int dimension) const { return localRange[dimension]; }
101
108
102
- range<dimensions > get_group_range () const { return groupRange; }
109
+ range<Dimensions > get_group_range () const { return groupRange; }
103
110
104
111
size_t get_group_range (int dimension) const {
105
112
return get_group_range ()[dimension];
106
113
}
107
114
108
115
size_t operator [](int dimension) const { return index[dimension]; }
109
116
110
- template <int dims = dimensions >
117
+ template <int dims = Dimensions >
111
118
typename std::enable_if<(dims == 1 ), size_t >::type get_linear_id () const {
112
119
return index[0 ];
113
120
}
114
121
115
- template <int dims = dimensions >
122
+ template <int dims = Dimensions >
116
123
typename std::enable_if<(dims == 2 ), size_t >::type get_linear_id () const {
117
124
return index[0 ] * groupRange[1 ] + index[1 ];
118
125
}
@@ -127,7 +134,7 @@ template <int dimensions = 1> class group {
127
134
// size_t get_linear_id()const
128
135
// Get a linearized version of the work-group id. Calculating a linear
129
136
// work-group id from a multi-dimensional index follows the equation 4.3.
130
- template <int dims = dimensions >
137
+ template <int dims = Dimensions >
131
138
typename std::enable_if<(dims == 3 ), size_t >::type get_linear_id () const {
132
139
return (index[0 ] * groupRange[1 ] * groupRange[2 ]) +
133
140
(index[1 ] * groupRange[2 ]) + index[2 ];
@@ -139,41 +146,41 @@ template <int dimensions = 1> class group {
139
146
// compilers are expected to optimize when possible
140
147
detail::workGroupBarrier ();
141
148
#ifdef __SYCL_DEVICE_ONLY__
142
- range<dimensions > GlobalSize{
143
- __spirv::initGlobalSize<dimensions , range<dimensions >>()};
144
- range<dimensions > LocalSize{
145
- __spirv::initWorkgroupSize<dimensions , range<dimensions >>()};
146
- id<dimensions > GlobalId{
147
- __spirv::initGlobalInvocationId<dimensions , id<dimensions >>()};
148
- id<dimensions > LocalId{
149
- __spirv::initLocalInvocationId<dimensions , id<dimensions >>()};
149
+ range<Dimensions > GlobalSize{
150
+ __spirv::initGlobalSize<Dimensions , range<Dimensions >>()};
151
+ range<Dimensions > LocalSize{
152
+ __spirv::initWorkgroupSize<Dimensions , range<Dimensions >>()};
153
+ id<Dimensions > GlobalId{
154
+ __spirv::initGlobalInvocationId<Dimensions , id<Dimensions >>()};
155
+ id<Dimensions > LocalId{
156
+ __spirv::initLocalInvocationId<Dimensions , id<Dimensions >>()};
150
157
151
158
// no 'iterate' in the device code variant, because
152
159
// (1) this code is already invoked by each work item as a part of the
153
160
// enclosing parallel_for_work_group kernel
154
161
// (2) the range this pfwi iterates over matches work group size exactly
155
- item<dimensions , false > GlobalItem =
156
- detail::Builder::createItem<dimensions , false >(GlobalSize, GlobalId);
157
- item<dimensions , false > LocalItem =
158
- detail::Builder::createItem<dimensions , false >(LocalSize, LocalId);
159
- h_item<dimensions > HItem =
160
- detail::Builder::createHItem<dimensions >(GlobalItem, LocalItem);
162
+ item<Dimensions , false > GlobalItem =
163
+ detail::Builder::createItem<Dimensions , false >(GlobalSize, GlobalId);
164
+ item<Dimensions , false > LocalItem =
165
+ detail::Builder::createItem<Dimensions , false >(LocalSize, LocalId);
166
+ h_item<Dimensions > HItem =
167
+ detail::Builder::createHItem<Dimensions >(GlobalItem, LocalItem);
161
168
162
169
Func (HItem);
163
170
#else
164
- id<dimensions > GroupStartID = index * localRange;
171
+ id<Dimensions > GroupStartID = index * localRange;
165
172
166
173
// ... host variant needs explicit 'iterate' because it is serial
167
- detail::NDLoop<dimensions >::iterate(
168
- localRange, [&](const id<dimensions > &LocalID) {
169
- item<dimensions , false > GlobalItem =
170
- detail::Builder::createItem<dimensions , false >(
174
+ detail::NDLoop<Dimensions >::iterate(
175
+ localRange, [&](const id<Dimensions > &LocalID) {
176
+ item<Dimensions , false > GlobalItem =
177
+ detail::Builder::createItem<Dimensions , false >(
171
178
globalRange, GroupStartID + LocalID);
172
- item<dimensions , false > LocalItem =
173
- detail::Builder::createItem<dimensions , false >(localRange,
179
+ item<Dimensions , false > LocalItem =
180
+ detail::Builder::createItem<Dimensions , false >(localRange,
174
181
LocalID);
175
- h_item<dimensions > HItem =
176
- detail::Builder::createHItem<dimensions >(GlobalItem, LocalItem);
182
+ h_item<Dimensions > HItem =
183
+ detail::Builder::createHItem<Dimensions >(GlobalItem, LocalItem);
177
184
Func (HItem);
178
185
});
179
186
#endif // __SYCL_DEVICE_ONLY__
@@ -185,52 +192,52 @@ template <int dimensions = 1> class group {
185
192
}
186
193
187
194
template <typename WorkItemFunctionT>
188
- void parallel_for_work_item (range<dimensions > flexibleRange,
195
+ void parallel_for_work_item (range<Dimensions > flexibleRange,
189
196
WorkItemFunctionT Func) const {
190
197
detail::workGroupBarrier ();
191
198
#ifdef __SYCL_DEVICE_ONLY__
192
- range<dimensions > GlobalSize{
193
- __spirv::initGlobalSize<dimensions , range<dimensions >>()};
194
- range<dimensions > LocalSize{
195
- __spirv::initWorkgroupSize<dimensions , range<dimensions >>()};
196
- id<dimensions > GlobalId{
197
- __spirv::initGlobalInvocationId<dimensions , id<dimensions >>()};
198
- id<dimensions > LocalId{
199
- __spirv::initLocalInvocationId<dimensions , id<dimensions >>()};
200
-
201
- item<dimensions , false > GlobalItem =
202
- detail::Builder::createItem<dimensions , false >(GlobalSize, GlobalId);
203
- item<dimensions , false > LocalItem =
204
- detail::Builder::createItem<dimensions , false >(LocalSize, LocalId);
205
- h_item<dimensions > HItem = detail::Builder::createHItem<dimensions >(
199
+ range<Dimensions > GlobalSize{
200
+ __spirv::initGlobalSize<Dimensions , range<Dimensions >>()};
201
+ range<Dimensions > LocalSize{
202
+ __spirv::initWorkgroupSize<Dimensions , range<Dimensions >>()};
203
+ id<Dimensions > GlobalId{
204
+ __spirv::initGlobalInvocationId<Dimensions , id<Dimensions >>()};
205
+ id<Dimensions > LocalId{
206
+ __spirv::initLocalInvocationId<Dimensions , id<Dimensions >>()};
207
+
208
+ item<Dimensions , false > GlobalItem =
209
+ detail::Builder::createItem<Dimensions , false >(GlobalSize, GlobalId);
210
+ item<Dimensions , false > LocalItem =
211
+ detail::Builder::createItem<Dimensions , false >(LocalSize, LocalId);
212
+ h_item<Dimensions > HItem = detail::Builder::createHItem<Dimensions >(
206
213
GlobalItem, LocalItem, flexibleRange);
207
214
208
215
// iterate over flexible range with work group size stride; each item
209
216
// performs flexibleRange/LocalSize iterations (if the former is divisible
210
217
// by the latter)
211
- detail::NDLoop<dimensions >::iterate(
218
+ detail::NDLoop<Dimensions >::iterate(
212
219
LocalId, LocalSize, flexibleRange,
213
- [&](const id<dimensions > &LogicalLocalID) {
220
+ [&](const id<Dimensions > &LogicalLocalID) {
214
221
HItem.setLogicalLocalID (LogicalLocalID);
215
222
Func (HItem);
216
223
});
217
224
#else
218
- id<dimensions > GroupStartID = index * localRange;
225
+ id<Dimensions > GroupStartID = index * localRange;
219
226
220
- detail::NDLoop<dimensions >::iterate(
221
- localRange, [&](const id<dimensions > &LocalID) {
222
- item<dimensions , false > GlobalItem =
223
- detail::Builder::createItem<dimensions , false >(
227
+ detail::NDLoop<Dimensions >::iterate(
228
+ localRange, [&](const id<Dimensions > &LocalID) {
229
+ item<Dimensions , false > GlobalItem =
230
+ detail::Builder::createItem<Dimensions , false >(
224
231
globalRange, GroupStartID + LocalID);
225
- item<dimensions , false > LocalItem =
226
- detail::Builder::createItem<dimensions , false >(localRange,
232
+ item<Dimensions , false > LocalItem =
233
+ detail::Builder::createItem<Dimensions , false >(localRange,
227
234
LocalID);
228
- h_item<dimensions > HItem = detail::Builder::createHItem<dimensions >(
235
+ h_item<Dimensions > HItem = detail::Builder::createHItem<Dimensions >(
229
236
GlobalItem, LocalItem, flexibleRange);
230
237
231
- detail::NDLoop<dimensions >::iterate(
238
+ detail::NDLoop<Dimensions >::iterate(
232
239
LocalID, localRange, flexibleRange,
233
- [&](const id<dimensions > &LogicalLocalID) {
240
+ [&](const id<Dimensions > &LogicalLocalID) {
234
241
HItem.setLogicalLocalID (LogicalLocalID);
235
242
Func (HItem);
236
243
});
@@ -311,23 +318,23 @@ template <int dimensions = 1> class group {
311
318
waitForHelper (Events...);
312
319
}
313
320
314
- bool operator ==(const group<dimensions > &rhs) const {
321
+ bool operator ==(const group<Dimensions > &rhs) const {
315
322
bool Result = (rhs.globalRange == globalRange) &&
316
323
(rhs.localRange == localRange) && (rhs.index == index);
317
324
__SYCL_ASSERT (rhs.groupRange == groupRange &&
318
325
" inconsistent group class fields" );
319
326
return Result;
320
327
}
321
328
322
- bool operator !=(const group<dimensions > &rhs) const {
329
+ bool operator !=(const group<Dimensions > &rhs) const {
323
330
return !((*this ) == rhs);
324
331
}
325
332
326
333
private:
327
- range<dimensions > globalRange;
328
- range<dimensions > localRange;
329
- range<dimensions > groupRange;
330
- id<dimensions > index;
334
+ range<Dimensions > globalRange;
335
+ range<Dimensions > localRange;
336
+ range<Dimensions > groupRange;
337
+ id<Dimensions > index;
331
338
332
339
void waitForHelper () const {}
333
340
@@ -343,8 +350,8 @@ template <int dimensions = 1> class group {
343
350
344
351
protected:
345
352
friend class detail ::Builder;
346
- group (const range<dimensions > &G, const range<dimensions > &L,
347
- const range<dimensions > GroupRange, const id<dimensions > &I)
353
+ group (const range<Dimensions > &G, const range<Dimensions > &L,
354
+ const range<Dimensions > GroupRange, const id<Dimensions > &I)
348
355
: globalRange(G), localRange(L), groupRange(GroupRange), index(I) {
349
356
// Make sure local range divides global without remainder:
350
357
__SYCL_ASSERT (((G % L).size () == 0 ) &&
0 commit comments