@@ -290,6 +290,188 @@ AtomicMax(multi_ptr<T, AddressSpace> MPtr, intel::memory_scope Scope,
290
290
return __spirv_AtomicMax (Ptr, SPIRVScope, SPIRVOrder, Value);
291
291
}
292
292
293
+ // Native shuffles map directly to a SPIR-V SubgroupShuffle intrinsic
294
+ template <typename T>
295
+ using EnableIfNativeShuffle =
296
+ detail::enable_if_t <detail::is_arithmetic<T>::value, T>;
297
+
298
+ template <typename T>
299
+ EnableIfNativeShuffle<T> SubgroupShuffle (T x, id<1 > local_id) {
300
+ using OCLT = detail::ConvertToOpenCLType_t<T>;
301
+ return __spirv_SubgroupShuffleINTEL (OCLT (x),
302
+ static_cast <uint32_t >(local_id.get (0 )));
303
+ }
304
+
305
+ template <typename T>
306
+ EnableIfNativeShuffle<T> SubgroupShuffleXor (T x, id<1 > local_id) {
307
+ using OCLT = detail::ConvertToOpenCLType_t<T>;
308
+ return __spirv_SubgroupShuffleXorINTEL (
309
+ OCLT (x), static_cast <uint32_t >(local_id.get (0 )));
310
+ }
311
+
312
+ template <typename T>
313
+ EnableIfNativeShuffle<T> SubgroupShuffleDown (T x, T y, id<1 > local_id) {
314
+ using OCLT = detail::ConvertToOpenCLType_t<T>;
315
+ return __spirv_SubgroupShuffleDownINTEL (
316
+ OCLT (x), OCLT (y), static_cast <uint32_t >(local_id.get (0 )));
317
+ }
318
+
319
+ template <typename T>
320
+ EnableIfNativeShuffle<T> SubgroupShuffleUp (T x, T y, id<1 > local_id) {
321
+ using OCLT = detail::ConvertToOpenCLType_t<T>;
322
+ return __spirv_SubgroupShuffleUpINTEL (OCLT (x), OCLT (y),
323
+ static_cast <uint32_t >(local_id.get (0 )));
324
+ }
325
+
326
+ // Bitcast shuffles can be implemented using a single SPIR-V SubgroupShuffle
327
+ // intrinsic, but require type-punning via an appropriate integer type
328
+ template <typename T>
329
+ using EnableIfBitcastShuffle =
330
+ detail::enable_if_t <!detail::is_arithmetic<T>::value &&
331
+ (std::is_trivially_copyable<T>::value &&
332
+ (sizeof (T) == 1 || sizeof (T) == 2 ||
333
+ sizeof (T) == 4 || sizeof (T) == 8 )),
334
+ T>;
335
+
336
+ template <typename T>
337
+ using ConvertToNativeShuffleType_t = select_cl_scalar_integral_unsigned_t <T>;
338
+
339
+ template <typename T>
340
+ EnableIfBitcastShuffle<T> SubgroupShuffle (T x, id<1 > local_id) {
341
+ using ShuffleT = ConvertToNativeShuffleType_t<T>;
342
+ auto ShuffleX = detail::bit_cast<ShuffleT>(x);
343
+ ShuffleT Result = __spirv_SubgroupShuffleINTEL (
344
+ ShuffleX, static_cast <uint32_t >(local_id.get (0 )));
345
+ return detail::bit_cast<T>(Result);
346
+ }
347
+
348
+ template <typename T>
349
+ EnableIfBitcastShuffle<T> SubgroupShuffleXor (T x, id<1 > local_id) {
350
+ using ShuffleT = ConvertToNativeShuffleType_t<T>;
351
+ auto ShuffleX = detail::bit_cast<ShuffleT>(x);
352
+ ShuffleT Result = __spirv_SubgroupShuffleXorINTEL (
353
+ ShuffleX, static_cast <uint32_t >(local_id.get (0 )));
354
+ return detail::bit_cast<T>(Result);
355
+ }
356
+
357
+ template <typename T>
358
+ EnableIfBitcastShuffle<T> SubgroupShuffleDown (T x, T y, id<1 > local_id) {
359
+ using ShuffleT = ConvertToNativeShuffleType_t<T>;
360
+ auto ShuffleX = detail::bit_cast<ShuffleT>(x);
361
+ auto ShuffleY = detail::bit_cast<ShuffleT>(y);
362
+ ShuffleT Result = __spirv_SubgroupShuffleDownINTEL (
363
+ ShuffleX, ShuffleY, static_cast <uint32_t >(local_id.get (0 )));
364
+ return detail::bit_cast<T>(Result);
365
+ }
366
+
367
+ template <typename T>
368
+ EnableIfBitcastShuffle<T> SubgroupShuffleUp (T x, T y, id<1 > local_id) {
369
+ using ShuffleT = ConvertToNativeShuffleType_t<T>;
370
+ auto ShuffleX = detail::bit_cast<ShuffleT>(x);
371
+ auto ShuffleY = detail::bit_cast<ShuffleT>(y);
372
+ ShuffleT Result = __spirv_SubgroupShuffleUpINTEL (
373
+ ShuffleX, ShuffleY, static_cast <uint32_t >(local_id.get (0 )));
374
+ return detail::bit_cast<T>(Result);
375
+ }
376
+
377
+ // Generic shuffles may require multiple calls to SPIR-V SubgroupShuffle
378
+ // intrinsics, and should use the fewest shuffles possible:
379
+ // - Loop over 64-bit chunks until remaining bytes < 64-bit
380
+ // - At most one 32-bit, 16-bit and 8-bit chunk left over
381
+ template <typename T>
382
+ using EnableIfGenericShuffle =
383
+ detail::enable_if_t <!detail::is_arithmetic<T>::value &&
384
+ !(std::is_trivially_copyable<T>::value &&
385
+ (sizeof (T) == 1 || sizeof (T) == 2 ||
386
+ sizeof (T) == 4 || sizeof (T) == 8 )),
387
+ T>;
388
+
389
/// Decomposes an object of type T into the fewest power-of-two sized chunks
/// and calls \p ShuffleBytes(Offset, Size) once per chunk, in ascending offset
/// order.
///
/// The chunks are disjoint, lie entirely inside [0, sizeof(T)) and together
/// cover every byte of T exactly once:
///  - as many whole 64-bit chunks as fit, then
///  - at most one 32-bit, one 16-bit and one 8-bit chunk for the remainder.
///
/// \tparam T type whose size determines the chunking; no object of T is
///         touched by this function itself.
/// \param ShuffleBytes callable taking (size_t Offset, size_t Size) in bytes.
template <typename T, typename ShuffleFunctor>
void GenericShuffle(const ShuffleFunctor &ShuffleBytes) {
  // Only whole 64-bit chunks are issued here. The bound is written as
  // "Offset + 8 <= sizeof(T)" so that a size which is not a multiple of eight
  // never produces a chunk reaching past the end of the object; the remainder
  // is handled by the narrower steps below.
#pragma unroll
  for (size_t Offset = 0; Offset + sizeof(uint64_t) <= sizeof(T);
       Offset += sizeof(uint64_t)) {
    ShuffleBytes(Offset, sizeof(uint64_t));
  }
  // Each remaining step starts where the next-wider chunk size stopped, i.e.
  // at sizeof(T) rounded down to a multiple of that wider size.
  if (sizeof(T) % sizeof(uint64_t) >= sizeof(uint32_t)) {
    size_t Offset = sizeof(T) / sizeof(uint64_t) * sizeof(uint64_t);
    ShuffleBytes(Offset, sizeof(uint32_t));
  }
  if (sizeof(T) % sizeof(uint32_t) >= sizeof(uint16_t)) {
    size_t Offset = sizeof(T) / sizeof(uint32_t) * sizeof(uint32_t);
    ShuffleBytes(Offset, sizeof(uint16_t));
  }
  if (sizeof(T) % sizeof(uint16_t) >= sizeof(uint8_t)) {
    size_t Offset = sizeof(T) / sizeof(uint16_t) * sizeof(uint16_t);
    ShuffleBytes(Offset, sizeof(uint8_t));
  }
}
410
+
411
+ template <typename T>
412
+ EnableIfGenericShuffle<T> SubgroupShuffle (T x, id<1 > local_id) {
413
+ T Result;
414
+ char *XBytes = reinterpret_cast <char *>(&x);
415
+ char *ResultBytes = reinterpret_cast <char *>(&Result);
416
+ auto ShuffleBytes = [=](size_t Offset, size_t Size) {
417
+ uint64_t ShuffleX, ShuffleResult;
418
+ detail::memcpy (&ShuffleX, XBytes + Offset, Size);
419
+ ShuffleResult = SubgroupShuffle (ShuffleX, local_id);
420
+ detail::memcpy (ResultBytes + Offset, &ShuffleResult, Size);
421
+ };
422
+ GenericShuffle<T>(ShuffleBytes);
423
+ return Result;
424
+ }
425
+
426
+ template <typename T>
427
+ EnableIfGenericShuffle<T> SubgroupShuffleXor (T x, id<1 > local_id) {
428
+ T Result;
429
+ char *XBytes = reinterpret_cast <char *>(&x);
430
+ char *ResultBytes = reinterpret_cast <char *>(&Result);
431
+ auto ShuffleBytes = [=](size_t Offset, size_t Size) {
432
+ uint64_t ShuffleX, ShuffleResult;
433
+ detail::memcpy (&ShuffleX, XBytes + Offset, Size);
434
+ ShuffleResult = SubgroupShuffleXor (ShuffleX, local_id);
435
+ detail::memcpy (ResultBytes + Offset, &ShuffleResult, Size);
436
+ };
437
+ GenericShuffle<T>(ShuffleBytes);
438
+ return Result;
439
+ }
440
+
441
+ template <typename T>
442
+ EnableIfGenericShuffle<T> SubgroupShuffleDown (T x, T y, id<1 > local_id) {
443
+ T Result;
444
+ char *XBytes = reinterpret_cast <char *>(&x);
445
+ char *YBytes = reinterpret_cast <char *>(&y);
446
+ char *ResultBytes = reinterpret_cast <char *>(&Result);
447
+ auto ShuffleBytes = [=](size_t Offset, size_t Size) {
448
+ uint64_t ShuffleX, ShuffleY, ShuffleResult;
449
+ detail::memcpy (&ShuffleX, XBytes + Offset, Size);
450
+ detail::memcpy (&ShuffleY, YBytes + Offset, Size);
451
+ ShuffleResult = SubgroupShuffleDown (ShuffleX, ShuffleY, local_id);
452
+ detail::memcpy (ResultBytes + Offset, &ShuffleResult, Size);
453
+ };
454
+ GenericShuffle<T>(ShuffleBytes);
455
+ return Result;
456
+ }
457
+
458
+ template <typename T>
459
+ EnableIfGenericShuffle<T> SubgroupShuffleUp (T x, T y, id<1 > local_id) {
460
+ T Result;
461
+ char *XBytes = reinterpret_cast <char *>(&x);
462
+ char *YBytes = reinterpret_cast <char *>(&y);
463
+ char *ResultBytes = reinterpret_cast <char *>(&Result);
464
+ auto ShuffleBytes = [=](size_t Offset, size_t Size) {
465
+ uint64_t ShuffleX, ShuffleY, ShuffleResult;
466
+ detail::memcpy (&ShuffleX, XBytes + Offset, Size);
467
+ detail::memcpy (&ShuffleY, YBytes + Offset, Size);
468
+ ShuffleResult = SubgroupShuffleUp (ShuffleX, ShuffleY, local_id);
469
+ detail::memcpy (ResultBytes + Offset, &ShuffleResult, Size);
470
+ };
471
+ GenericShuffle<T>(ShuffleBytes);
472
+ return Result;
473
+ }
474
+
293
475
} // namespace spirv
294
476
} // namespace detail
295
477
} // namespace sycl
0 commit comments