@@ -182,20 +182,49 @@ DefineBitwiseGroupOp(uint64_t, int64_t, i64)
182
182
return Type (); /* todo: add support for other flags as they are tested*/ \
183
183
}
184
184
185
- #define DefineBroadCast (Type, Sfx, MuxType )\
186
- DefineBroadCastImpl (Type, Sfx, MuxType, uint32_t )
185
// Forward-declares the two mux broadcast builtins selected by the suffix
// `Sfx`, both operating on values of type `MuxType`:
//   * __mux_sub_group_broadcast_<Sfx>(value, sub-group local id)
//   * __mux_work_group_broadcast_<Sfx>(id, value, x, y, z local ids)
// `Type` and `IDType` do not appear in the expansion; they are accepted so
// the macro can be invoked with the same argument list as
// DefineBroadCastImpl.
#define DefineBroadcastMuxType(Type, Sfx, MuxType, IDType)                     \
  DEVICE_EXTERN_C MuxType __mux_sub_group_broadcast_##Sfx(MuxType value,       \
                                                          int32_t sg_lid);    \
  DEVICE_EXTERN_C MuxType __mux_work_group_broadcast_##Sfx(                    \
      int32_t id, MuxType value, uint64_t lid_x, uint64_t lid_y,               \
      uint64_t lid_z);
190
+
191
// Defines the three __spirv_GroupBroadcast overloads for element type `Type`:
// one taking a scalar local id and two taking a 2- or 3-component id vector
// (sycl::vec<IDType, N>::vector_t).
//
//   Type    - element type of the value being broadcast and of the result.
//   Sfx     - suffix selecting the __mux_*_broadcast_<Sfx> builtins.
//   MuxType - not referenced in this expansion; kept so the argument list
//             matches DefineBroadcastMuxType.
//   IDType  - integer type of the local id (or of each id component).
//
// In every overload, when the scope `g` equals __spv::Scope::Flag::Subgroup
// the call is forwarded to __mux_sub_group_broadcast_<Sfx> using only the
// first id component. Otherwise it is forwarded to
// __mux_work_group_broadcast_<Sfx> with 0 as its first argument and with any
// id component the overload does not have filled in as 0.
//
// Fix: the 2-component overload used to pass l[0] for both the x and the y
// coordinate; it now passes l[0] and l[1].
#define DefineBroadCastImpl(Type, Sfx, MuxType, IDType)                        \
  DEVICE_EXTERNAL Type __spirv_GroupBroadcast(uint32_t g, Type v, IDType l) {  \
    if (__spv::Scope::Flag::Subgroup == g)                                     \
      return __mux_sub_group_broadcast_##Sfx(v, l);                            \
    else                                                                       \
      return __mux_work_group_broadcast_##Sfx(0, v, l, 0, 0);                  \
  }                                                                            \
                                                                               \
  DEVICE_EXTERNAL Type __spirv_GroupBroadcast(                                 \
      uint32_t g, Type v, sycl::vec<IDType, 2>::vector_t l) {                  \
    if (__spv::Scope::Flag::Subgroup == g)                                     \
      return __mux_sub_group_broadcast_##Sfx(v, l[0]);                         \
    else /* forward both coordinates: x = l[0], y = l[1] */                    \
      return __mux_work_group_broadcast_##Sfx(0, v, l[0], l[1], 0);            \
  }                                                                            \
                                                                               \
  DEVICE_EXTERNAL Type __spirv_GroupBroadcast(                                 \
      uint32_t g, Type v, sycl::vec<IDType, 3>::vector_t l) {                  \
    if (__spv::Scope::Flag::Subgroup == g)                                     \
      return __mux_sub_group_broadcast_##Sfx(v, l[0]);                         \
    else                                                                       \
      return __mux_work_group_broadcast_##Sfx(0, v, l[0], l[1], l[2]);         \
  }
215
+
216
// Emits everything needed to broadcast values of `Type`: the declarations of
// the mux broadcast builtins for suffix `Sfx`, followed by the
// __spirv_GroupBroadcast overloads for both 32-bit and 64-bit local ids.
//
// DefineBroadcastMuxType is expanded only once because its expansion does not
// use the id type, so a second expansion would only repeat the same
// declarations. The final line deliberately has no line continuation, so the
// macro does not depend on the line that follows it being blank.
#define DefineBroadCast(Type, Sfx, MuxType)                                    \
  DefineBroadcastMuxType(Type, Sfx, MuxType, uint32_t)                         \
  DefineBroadCastImpl(Type, Sfx, MuxType, uint32_t)                            \
  DefineBroadCastImpl(Type, Sfx, MuxType, uint64_t)
187
221
188
- DefineBroadCast(int64_t , i64 , int64_t )
189
- DefineBroadCast(uint64_t , i64 , int64_t )
190
- DefineBroadCast(int32_t , i32 , int32_t )
191
222
// Instantiate the broadcast overloads for the 32-bit integer types and for
// float/double. uint32_t and int32_t share the i32 suffix and the int32_t
// mux type.
DefineBroadCast(uint32_t , i32 , int32_t )
223
+ DefineBroadCast(int32_t , i32 , int32_t )
192
224
DefineBroadCast(float , f32 , float )
193
225
DefineBroadCast(double , f64 , double )
194
-
195
- DefineBroadCastImpl(int32_t , i32 , int32_t , uint64_t )
196
- DefineBroadCastImpl(float , f32 , float , uint64_t )
197
- DefineBroadCastImpl(double , f64 , double , uint64_t )
198
- DefineBroadCastImpl(uint64_t , i64 , int64_t , uint64_t )
226
// Instantiate the broadcast overloads for the 64-bit integer types; both use
// the i64 suffix and the int64_t mux type.
+ DefineBroadCast(uint64_t , i64 , int64_t )
227
+ DefineBroadCast(int64_t , i64 , int64_t )
199
228
200
229
201
230
#define DefShuffleINTEL (Type, Sfx, MuxType ) \
@@ -248,6 +277,8 @@ DefShuffleINTEL_All(int32_t, i32, int32_t)
248
277
// Instantiate DefShuffleINTEL_All (defined outside this excerpt) for 32-, 16-
// and 8-bit integers and for double/float. Each unsigned type reuses the
// suffix and third argument of the signed type of the same width.
DefShuffleINTEL_All(uint32_t , i32 , int32_t )
249
278
DefShuffleINTEL_All(int16_t , i16 , int16_t )
250
279
DefShuffleINTEL_All(uint16_t , i16 , int16_t )
280
+ DefShuffleINTEL_All(int8_t , i8 , int8_t )
281
+ DefShuffleINTEL_All(uint8_t , i8 , int8_t )
251
282
DefShuffleINTEL_All(double , f64 , double )
252
283
DefShuffleINTEL_All(float , f32 , float )
253
284
0 commit comments