Skip to content

Commit d04ebb0

Browse files
authored
[SYCL][ESIMD] Support 64-bit offsets for stateless accessor gather/scatter (#9462)
Signed-off-by: Sarnie, Nick <[email protected]>
1 parent 3028d82 commit d04ebb0

File tree

2 files changed

+102
-12
lines changed

2 files changed

+102
-12
lines changed

sycl/include/sycl/ext/intel/esimd/memory.hpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,7 @@ gather_impl(AccessorTy acc, simd<uint32_t, N> offsets, uint32_t glob_offset,
521521
/// \c float.
522522
/// @tparam N The number of vector elements. Can be \c 1, \c 8, \c 16 or \c 32.
523523
/// @tparam AccessorTy The accessor type.
524+
/// @tparam Toffset The per-element offset type. Must be an integral type.
524525
/// @param acc The accessor to gather from.
525526
/// @param offsets Per-element offsets in bytes.
526527
/// @param glob_offset Offset in bytes added to each individual element's offset
@@ -529,12 +530,17 @@ gather_impl(AccessorTy acc, simd<uint32_t, N> offsets, uint32_t glob_offset,
529530
/// predicate are not accessed, their values in the resulting vector are
530531
/// undefined.
531532
///
532-
template <typename T, int N, typename AccessorTy>
533-
__ESIMD_API std::enable_if_t<(sizeof(T) <= 4) &&
534-
(N == 1 || N == 8 || N == 16 || N == 32) &&
535-
!std::is_pointer<AccessorTy>::value,
536-
simd<T, N>>
537-
gather(AccessorTy acc, simd<uint32_t, N> offsets, uint32_t glob_offset = 0,
533+
template <typename T, int N, typename AccessorTy, typename Toffset>
534+
__ESIMD_API std::enable_if_t<
535+
(sizeof(T) <= 4) && (N == 1 || N == 8 || N == 16 || N == 32) &&
536+
!std::is_pointer<AccessorTy>::value && std::is_integral_v<Toffset>,
537+
simd<T, N>>
538+
gather(AccessorTy acc, simd<Toffset, N> offsets,
539+
#ifdef __ESIMD_FORCE_STATELESS_MEM
540+
uint64_t glob_offset = 0,
541+
#else
542+
uint32_t glob_offset = 0,
543+
#endif
538544
simd_mask<N> mask = 1) {
539545
#ifdef __ESIMD_FORCE_STATELESS_MEM
540546
return gather<T, N>(__ESIMD_DNS::accessorToPointer<T>(acc, glob_offset),
@@ -554,6 +560,7 @@ gather(AccessorTy acc, simd<uint32_t, N> offsets, uint32_t glob_offset = 0,
554560
/// \c float.
555561
/// @tparam N The number of vector elements. Can be \c 1, \c 8, \c 16 or \c 32.
556562
/// @tparam AccessorTy The accessor type.
563+
/// @tparam Toffset The per-element offset type. Must be an integral type.
557564
/// @param acc The accessor to scatter to.
558565
/// @param offsets Per-element offsets in bytes.
559566
/// @param vals Values to write.
@@ -563,12 +570,17 @@ gather(AccessorTy acc, simd<uint32_t, N> offsets, uint32_t glob_offset = 0,
563570
/// predicate are not accessed.
564571
///
565572
///
566-
template <typename T, int N, typename AccessorTy>
567-
__ESIMD_API std::enable_if_t<(sizeof(T) <= 4) &&
568-
(N == 1 || N == 8 || N == 16 || N == 32) &&
569-
!std::is_pointer<AccessorTy>::value>
570-
scatter(AccessorTy acc, simd<uint32_t, N> offsets, simd<T, N> vals,
571-
uint32_t glob_offset = 0, simd_mask<N> mask = 1) {
573+
template <typename T, int N, typename AccessorTy, typename Toffset>
574+
__ESIMD_API std::enable_if_t<
575+
(sizeof(T) <= 4) && (N == 1 || N == 8 || N == 16 || N == 32) &&
576+
!std::is_pointer<AccessorTy>::value && std::is_integral_v<Toffset>>
577+
scatter(AccessorTy acc, simd<Toffset, N> offsets, simd<T, N> vals,
578+
#ifdef __ESIMD_FORCE_STATELESS_MEM
579+
uint64_t glob_offset = 0,
580+
#else
581+
uint32_t glob_offset = 0,
582+
#endif
583+
simd_mask<N> mask = 1) {
572584
#ifdef __ESIMD_FORCE_STATELESS_MEM
573585
scatter<T, N>(__ESIMD_DNS::accessorToPointer<T>(acc, glob_offset), offsets,
574586
vals, mask);
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//=-accessor_gather_scatter_stateless_64.cpp - DPC++ ESIMD on-device test-=//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//=-------------------------------------------------------------------=//
8+
// REQUIRES: gpu-intel-pvc
9+
// RUN: %{build} -o %t.out -fsycl-esimd-force-stateless-mem
10+
// RUN: %{run} %t.out
11+
12+
#include <iostream>
13+
#include <sycl/ext/intel/esimd.hpp>
14+
#include <sycl/sycl.hpp>
15+
16+
using namespace sycl;
17+
using namespace sycl::ext::intel::esimd;
18+
19+
int main(void) {
20+
constexpr unsigned VL = 16;
21+
constexpr uint64_t Gig = 1024 * 1024 * 1024;
22+
23+
constexpr uint64_t Size = Gig / 2 + 16;
24+
uint64_t *A = new uint64_t[Size];
25+
26+
for (uint64_t i = 0; i < Size; ++i)
27+
A[i] = i;
28+
29+
buffer<uint64_t, 1> bufa(A, range<1>(Size));
30+
queue q;
31+
32+
auto dev = q.get_device();
33+
std::cout << "Running on " << dev.get_info<info::device::name>() << "\n";
34+
try {
35+
q.submit([&](handler &cgh) {
36+
auto PA = bufa.get_access<access::mode::read_write>(cgh);
37+
cgh.single_task<class Test>([=]() SYCL_ESIMD_KERNEL {
38+
uint64_t offsetStart = (Size - VL) * sizeof(uint64_t);
39+
simd<uint64_t, VL> offset(offsetStart, sizeof(uint64_t));
40+
simd<uint64_t, VL> beginning(0, sizeof(uint64_t));
41+
simd<uint32_t, VL> va = gather<uint32_t, VL>(PA, beginning);
42+
simd<uint32_t, VL> vb = gather<uint32_t, VL>(PA, offset);
43+
va *= 2;
44+
vb *= 5;
45+
scatter<uint32_t, VL>(PA, beginning, va);
46+
scatter<uint32_t, VL>(PA, offset, vb);
47+
});
48+
}).wait();
49+
} catch (sycl::exception const &e) {
50+
std::cout << "SYCL exception caught: " << e.what() << '\n';
51+
delete[] A;
52+
return 1;
53+
}
54+
55+
bool failed = false;
56+
host_accessor A_acc(bufa);
57+
for (uint64_t I = 0; I < VL; I++) {
58+
uint64_t Expected = static_cast<uint32_t>(I) * 2;
59+
if (A_acc[I] != Expected) {
60+
std::cout << "FAILED: " << A_acc[I] << " != " << Expected << std::endl;
61+
failed = true;
62+
}
63+
}
64+
for (uint64_t I = Size - VL; I < Size; I++) {
65+
uint64_t Expected = static_cast<uint32_t>(I) * 5;
66+
if (A_acc[I] != Expected) {
67+
std::cout << "FAILED: " << A_acc[I] << " != " << Expected << std::endl;
68+
failed = true;
69+
}
70+
}
71+
72+
if (!failed)
73+
std::cout << "PASSED" << std::endl;
74+
75+
delete[] A;
76+
77+
return failed;
78+
}

0 commit comments

Comments
 (0)