@@ -27,6 +27,8 @@ struct sub_group;
27
27
namespace experimental {
// Forward declarations of the group class types that the overloads in this
// file are written against. Only the names are introduced here; the class
// definitions are not part of this file section.

// Group type parameterized on the group it was derived from.
template <typename ParentGroup> class ballot_group;

// Group type parameterized on a compile-time partition size and on the group
// it was derived from.
template <size_t PartitionSize, typename ParentGroup> class fixed_size_group;

// Group type parameterized on the group it was derived from.
template <typename ParentGroup> class tangle_group;

// Non-template group type (no parent-group parameter).
class opportunistic_group;
} // namespace experimental
31
33
} // namespace oneapi
32
34
} // namespace ext
@@ -72,6 +74,16 @@ struct group_scope<sycl::ext::oneapi::experimental::fixed_size_group<
72
74
static constexpr __spv::Scope::Flag value = group_scope<ParentGroup>::value;
73
75
};
74
76
77
+ template <typename ParentGroup>
78
+ struct group_scope <sycl::ext::oneapi::experimental::tangle_group<ParentGroup>> {
79
+ static constexpr __spv::Scope::Flag value = group_scope<ParentGroup>::value;
80
+ };
81
+
82
+ template <>
83
+ struct group_scope <::sycl::ext::oneapi::experimental::opportunistic_group> {
84
+ static constexpr __spv::Scope::Flag value = __spv::Scope::Flag::Subgroup;
85
+ };
86
+
75
87
// Generic shuffles and broadcasts may require multiple calls to
76
88
// intrinsics, and should use the fewest broadcasts possible
77
89
// - Loop over chunks until remaining bytes < chunk size
@@ -135,6 +147,16 @@ bool GroupAll(
135
147
static_cast <uint32_t >(__spv::GroupOperation::ClusteredReduce),
136
148
static_cast <uint32_t >(pred), PartitionSize);
137
149
}
150
+ template <typename ParentGroup>
151
+ bool GroupAll (ext::oneapi::experimental::tangle_group<ParentGroup>, bool pred) {
152
+ return __spirv_GroupNonUniformAll (group_scope<ParentGroup>::value, pred);
153
+ }
154
+ template <typename Group>
155
+ bool GroupAll (const ext::oneapi::experimental::opportunistic_group &,
156
+ bool pred) {
157
+ return __spirv_GroupNonUniformAll (
158
+ group_scope<ext::oneapi::experimental::opportunistic_group>::value, pred);
159
+ }
138
160
139
161
template <typename Group> bool GroupAny (Group, bool pred) {
140
162
return __spirv_GroupAny (group_scope<Group>::value, pred);
@@ -161,6 +183,15 @@ bool GroupAny(
161
183
static_cast <uint32_t >(__spv::GroupOperation::ClusteredReduce),
162
184
static_cast <uint32_t >(pred), PartitionSize);
163
185
}
186
+ template <typename ParentGroup>
187
+ bool GroupAny (ext::oneapi::experimental::tangle_group<ParentGroup>, bool pred) {
188
+ return __spirv_GroupNonUniformAny (group_scope<ParentGroup>::value, pred);
189
+ }
190
+ bool GroupAny (const ext::oneapi::experimental::opportunistic_group &,
191
+ bool pred) {
192
+ return __spirv_GroupNonUniformAny (
193
+ group_scope<ext::oneapi::experimental::opportunistic_group>::value, pred);
194
+ }
164
195
165
196
// Native broadcasts map directly to a SPIR-V GroupBroadcast intrinsic
166
197
// FIXME: Do not special-case for half or vec once all backends support all data
@@ -281,6 +312,45 @@ EnableIfNativeBroadcast<T, IdT> GroupBroadcast(
281
312
return __spirv_GroupNonUniformShuffle (group_scope<ParentGroup>::value, OCLX,
282
313
OCLId);
283
314
}
315
+ template <typename ParentGroup, typename T, typename IdT>
316
+ EnableIfNativeBroadcast<T, IdT>
317
+ GroupBroadcast (ext::oneapi::experimental::tangle_group<ParentGroup> g, T x,
318
+ IdT local_id) {
319
+ // Remap local_id to its original numbering in ParentGroup.
320
+ auto LocalId = detail::IdToMaskPosition (g, local_id);
321
+
322
+ // TODO: Refactor to avoid duplication after design settles.
323
+ using GroupIdT = typename GroupId<ParentGroup>::type;
324
+ GroupIdT GroupLocalId = static_cast <GroupIdT>(LocalId);
325
+ using OCLT = detail::ConvertToOpenCLType_t<T>;
326
+ using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
327
+ using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
328
+ WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
329
+ OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
330
+
331
+ return __spirv_GroupNonUniformBroadcast (group_scope<ParentGroup>::value, OCLX,
332
+ OCLId);
333
+ }
334
+ template <typename T, typename IdT>
335
+ EnableIfNativeBroadcast<T, IdT>
336
+ GroupBroadcast (const ext::oneapi::experimental::opportunistic_group &g, T x,
337
+ IdT local_id) {
338
+ // Remap local_id to its original numbering in sub-group
339
+ auto LocalId = detail::IdToMaskPosition (g, local_id);
340
+
341
+ // TODO: Refactor to avoid duplication after design settles.
342
+ using GroupIdT = typename GroupId<sycl::ext::oneapi::sub_group>::type;
343
+ GroupIdT GroupLocalId = static_cast <GroupIdT>(LocalId);
344
+ using OCLT = detail::ConvertToOpenCLType_t<T>;
345
+ using WidenedT = WidenOpenCLTypeTo32_t<OCLT>;
346
+ using OCLIdT = detail::ConvertToOpenCLType_t<GroupIdT>;
347
+ WidenedT OCLX = detail::convertDataToType<T, OCLT>(x);
348
+ OCLIdT OCLId = detail::convertDataToType<GroupIdT, OCLIdT>(GroupLocalId);
349
+
350
+ return __spirv_GroupNonUniformBroadcast (
351
+ group_scope<ext::oneapi::experimental::opportunistic_group>::value, OCLX,
352
+ OCLId);
353
+ }
284
354
285
355
template <typename Group, typename T, typename IdT>
286
356
EnableIfBitcastBroadcast<T, IdT> GroupBroadcast (Group g, T x, IdT local_id) {
@@ -956,6 +1026,18 @@ ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
956
1026
#endif
957
1027
}
958
1028
1029
// Type trait, primary template: any Group type that has no dedicated
// specialization yields `value == false`.
template <typename Group>
struct is_tangle_or_opportunistic_group : std::false_type {
};
1031
+
1032
+ template <typename ParentGroup>
1033
+ struct is_tangle_or_opportunistic_group <
1034
+ sycl::ext::oneapi::experimental::tangle_group<ParentGroup>>
1035
+ : std::true_type {};
1036
+
1037
+ template <>
1038
+ struct is_tangle_or_opportunistic_group <
1039
+ sycl::ext::oneapi::experimental::opportunistic_group> : std::true_type {};
1040
+
959
1041
// TODO: Refactor to avoid duplication after design settles
960
1042
#define __SYCL_GROUP_COLLECTIVE_OVERLOAD (Instruction ) \
961
1043
template <__spv::GroupOperation Op, typename Group, typename T> \
@@ -1037,6 +1119,24 @@ ControlBarrier(Group, memory_scope FenceScope, memory_order Order) {
1037
1119
} \
1038
1120
return tmp; \
1039
1121
} \
1122
+ } \
1123
+ template <__spv::GroupOperation Op, typename Group, typename T> \
1124
+ inline typename std::enable_if_t < \
1125
+ is_tangle_or_opportunistic_group<Group>::value, T> \
1126
+ Group##Instruction(Group, T x) { \
1127
+ using ConvertedT = detail::ConvertToOpenCLType_t<T>; \
1128
+ \
1129
+ using OCLT = std::conditional_t < \
1130
+ std::is_same<ConvertedT, cl_char>() || \
1131
+ std::is_same<ConvertedT, cl_short>(), \
1132
+ cl_int, \
1133
+ std::conditional_t <std::is_same<ConvertedT, cl_uchar>() || \
1134
+ std::is_same<ConvertedT, cl_ushort>(), \
1135
+ cl_uint, ConvertedT>>; \
1136
+ OCLT Arg = x; \
1137
+ OCLT Ret = __spirv_GroupNonUniform##Instruction ( \
1138
+ group_scope<Group>::value, static_cast <unsigned int >(Op), Arg); \
1139
+ return Ret; \
1040
1140
}
1041
1141
1042
1142
__SYCL_GROUP_COLLECTIVE_OVERLOAD (SMin)
0 commit comments