Skip to content

Commit de8aefb

Browse files
authored
[SYCL][NATIVECPU] Fix missing declarations for broadcast and shuffle operations (#15140)
Multiple declarations were missing for shuffle and broadcast operations, in particular work group broadcast ones. This adds them.
1 parent 6a02407 commit de8aefb

File tree

1 file changed

+41
-10
lines changed

1 file changed

+41
-10
lines changed

libdevice/nativecpu_utils.cpp

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -182,20 +182,49 @@ DefineBitwiseGroupOp(uint64_t, int64_t, i64)
182182
return Type(); /*todo: add support for other flags as they are tested*/ \
183183
}
184184

185-
#define DefineBroadCast(Type, Sfx, MuxType)\
186-
DefineBroadCastImpl(Type, Sfx, MuxType, uint32_t)
185+
// Forward-declares the two mux broadcast builtins selected by Sfx so that the
// __spirv_GroupBroadcast definitions can call them.
//
//   Sfx     - suffix pasted onto the builtin names
//             (__mux_sub_group_broadcast_<Sfx>, __mux_work_group_broadcast_<Sfx>).
//   MuxType - value type taken and returned by both builtins.
//   Type, IDType - accepted for signature parity with the sibling macros; not
//                  referenced by this macro's expansion.
#define DefineBroadcastMuxType(Type, Sfx, MuxType, IDType)                     \
  DEVICE_EXTERN_C MuxType __mux_sub_group_broadcast_##Sfx(MuxType val,         \
                                                          int32_t sg_lid);     \
  DEVICE_EXTERN_C MuxType __mux_work_group_broadcast_##Sfx(                    \
      int32_t id, MuxType val, uint64_t lidx, uint64_t lidy, uint64_t lidz);
190+
191+
// Defines the three __spirv_GroupBroadcast overloads (scalar, 2-component and
// 3-component local id) for one value type / id type combination.
//
//   Type    - value type taken and returned by __spirv_GroupBroadcast.
//   Sfx     - suffix selecting the __mux_*_broadcast_<Sfx> builtins to call.
//   MuxType - not referenced by this macro's expansion.
//   IDType  - scalar type of the local id (also the element type of the
//             sycl::vec<IDType, N>::vector_t overloads).
//
// When the scope `g` equals __spv::Scope::Flag::Subgroup the call forwards to
// __mux_sub_group_broadcast_<Sfx> using only the first id component. Any other
// scope forwards to __mux_work_group_broadcast_<Sfx> with 0 as its first
// argument and 0 for every id component the overload does not carry.
#define DefineBroadCastImpl(Type, Sfx, MuxType, IDType)                        \
  DEVICE_EXTERNAL Type __spirv_GroupBroadcast(uint32_t g, Type v,              \
                                              IDType l) {                      \
    if (__spv::Scope::Flag::Subgroup == g)                                     \
      return __mux_sub_group_broadcast_##Sfx(v, l);                            \
    else                                                                       \
      return __mux_work_group_broadcast_##Sfx(0, v, l, 0, 0);                  \
  }                                                                            \
                                                                               \
  DEVICE_EXTERNAL Type __spirv_GroupBroadcast(                                 \
      uint32_t g, Type v, sycl::vec<IDType, 2>::vector_t l) {                  \
    if (__spv::Scope::Flag::Subgroup == g)                                     \
      return __mux_sub_group_broadcast_##Sfx(v, l[0]);                         \
    else /* forward both components: l[0] is x, l[1] is y */                   \
      return __mux_work_group_broadcast_##Sfx(0, v, l[0], l[1], 0);            \
  }                                                                            \
                                                                               \
  DEVICE_EXTERNAL Type __spirv_GroupBroadcast(                                 \
      uint32_t g, Type v, sycl::vec<IDType, 3>::vector_t l) {                  \
    if (__spv::Scope::Flag::Subgroup == g)                                     \
      return __mux_sub_group_broadcast_##Sfx(v, l[0]);                         \
    else                                                                       \
      return __mux_work_group_broadcast_##Sfx(0, v, l[0], l[1], l[2]);         \
  }
215+
216+
// Emits everything needed for one broadcast value type:
//   1. declarations of the mux broadcast builtins for Sfx, and
//   2. __spirv_GroupBroadcast definitions for both uint32_t and uint64_t
//      local id types.
//
// The builtin declarations do not depend on the id type, so they are expanded
// once rather than once per id type. The macro deliberately ends without a
// trailing line continuation so the line that follows it is not absorbed into
// the definition.
#define DefineBroadCast(Type, Sfx, MuxType)                                    \
  DefineBroadcastMuxType(Type, Sfx, MuxType, uint32_t)                         \
  DefineBroadCastImpl(Type, Sfx, MuxType, uint32_t)                            \
  DefineBroadCastImpl(Type, Sfx, MuxType, uint64_t)
187221

188-
DefineBroadCast(int64_t, i64, int64_t)
189-
DefineBroadCast(uint64_t, i64, int64_t)
190-
DefineBroadCast(int32_t, i32, int32_t)
191222
DefineBroadCast(uint32_t, i32, int32_t)
223+
DefineBroadCast(int32_t, i32, int32_t)
192224
DefineBroadCast(float, f32, float)
193225
DefineBroadCast(double, f64, double)
194-
195-
DefineBroadCastImpl(int32_t, i32, int32_t, uint64_t)
196-
DefineBroadCastImpl(float, f32, float, uint64_t)
197-
DefineBroadCastImpl(double, f64, double, uint64_t)
198-
DefineBroadCastImpl(uint64_t, i64, int64_t, uint64_t)
226+
DefineBroadCast(uint64_t, i64, int64_t)
227+
DefineBroadCast(int64_t, i64, int64_t)
199228

200229

201230
#define DefShuffleINTEL(Type, Sfx, MuxType) \
@@ -248,6 +277,8 @@ DefShuffleINTEL_All(int32_t, i32, int32_t)
248277
DefShuffleINTEL_All(uint32_t, i32, int32_t)
249278
DefShuffleINTEL_All(int16_t, i16, int16_t)
250279
DefShuffleINTEL_All(uint16_t, i16, int16_t)
280+
DefShuffleINTEL_All(int8_t, i8, int8_t)
281+
DefShuffleINTEL_All(uint8_t, i8, int8_t)
251282
DefShuffleINTEL_All(double, f64, double)
252283
DefShuffleINTEL_All(float, f32, float)
253284

0 commit comments

Comments
 (0)