Skip to content

Commit 1060b20

Browse files
authored
[SYCL][ESIMD] Support sycl::address_space_cast (#11972)
It works fine. Signed-off-by: Sarnie, Nick <[email protected]>
1 parent a57a96c commit 1060b20

File tree

2 files changed

+119
-0
lines changed

2 files changed

+119
-0
lines changed

llvm/lib/SYCLLowerIR/ESIMD/ESIMDVerifier.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ static const char *LegalSYCLFunctions[] = {
3636
"^sycl::_V1::accessor<.+>::~accessor",
3737
"^sycl::_V1::accessor<.+>::getQualifiedPtr",
3838
"^sycl::_V1::accessor<.+>::__init_esimd",
39+
"^sycl::_V1::address_space_cast",
3940
"^sycl::_V1::local_accessor<.+>::local_accessor",
4041
"^sycl::_V1::local_accessor<.+>::__init_esimd",
4142
"^sycl::_V1::local_accessor<.+>::get_pointer",
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// The tests a basic E2E invoke_simd test checking that
2+
// sycl::address_space_cast works.
3+
4+
// RUN: %{build} -fno-sycl-device-code-split-esimd -Xclang -fsycl-allow-func-ptr -o %t.out
5+
// RUN: env IGC_VCSaveStackCallLinkage=1 IGC_VCDirectCallsOnly=1 %{run} %t.out
6+
7+
#include <sycl/ext/intel/esimd.hpp>
8+
#include <sycl/ext/oneapi/experimental/invoke_simd.hpp>
9+
#include <sycl/ext/oneapi/experimental/uniform.hpp>
10+
#include <sycl/sycl.hpp>
11+
12+
#include <functional>
13+
#include <iostream>
14+
#include <type_traits>
15+
16+
#include "../../ESIMD/esimd_test_utils.hpp"
17+
18+
using namespace esimd_test;
19+
using namespace sycl::ext::oneapi::experimental;
20+
using namespace sycl;
21+
namespace esimd = sycl::ext::intel::esimd;
22+
23+
constexpr int VL = 16;
24+
25+
__attribute__((always_inline)) esimd::simd<float, VL>
26+
ESIMD_CALLEE(float *A, esimd::simd<float, VL> b, int i) SYCL_ESIMD_FUNCTION {
27+
esimd::simd<float, VL> a;
28+
global_ptr<float, access::decorated::yes> ptr =
29+
sycl::address_space_cast<access::address_space::global_space,
30+
access::decorated::yes, float>(A);
31+
a.copy_from(ptr + i);
32+
return a + b;
33+
}
34+
35+
[[intel::device_indirectly_callable]] SYCL_EXTERNAL
36+
simd<float, VL> __regcall SIMD_CALLEE(float *A, simd<float, VL> b,
37+
int i) SYCL_ESIMD_FUNCTION {
38+
esimd::simd<float, VL> res = ESIMD_CALLEE(A, b, i);
39+
return res;
40+
}
41+
42+
bool test() {
43+
constexpr unsigned Size = 1024;
44+
constexpr unsigned GroupSize = 4 * VL;
45+
46+
queue q(ESIMDSelector, createExceptionHandler());
47+
48+
printTestLabel(q);
49+
50+
float *A = malloc_shared<float>(Size, q);
51+
float *B = malloc_shared<float>(Size, q);
52+
float *C = malloc_shared<float>(Size, q);
53+
54+
for (unsigned i = 0; i < Size; ++i) {
55+
A[i] = B[i] = i;
56+
C[i] = -1;
57+
}
58+
59+
sycl::range<1> GlobalRange{Size};
60+
// Number of workitems in each workgroup.
61+
sycl::range<1> LocalRange{GroupSize};
62+
63+
sycl::nd_range<1> Range(GlobalRange, LocalRange);
64+
65+
try {
66+
auto e = q.submit([&](handler &cgh) {
67+
cgh.parallel_for(
68+
Range, [=](nd_item<1> ndi) [[intel::reqd_sub_group_size(VL)]] {
69+
sub_group sg = ndi.get_sub_group();
70+
group<1> g = ndi.get_group();
71+
uint32_t i =
72+
sg.get_group_linear_id() * VL + g.get_linear_id() * GroupSize;
73+
uint32_t wi_id = i + sg.get_local_id();
74+
float res = 0;
75+
76+
res =
77+
invoke_simd(sg, SIMD_CALLEE, uniform{A}, B[wi_id], uniform{i});
78+
79+
C[wi_id] = res;
80+
});
81+
});
82+
e.wait();
83+
} catch (sycl::exception const &e) {
84+
std::cout << "SYCL exception caught: " << e.what() << '\n';
85+
sycl::free(A, q);
86+
sycl::free(B, q);
87+
sycl::free(C, q);
88+
return false;
89+
}
90+
91+
int err_cnt = 0;
92+
93+
for (unsigned i = 0; i < Size; ++i) {
94+
if ((A[i] + B[i]) != C[i]) {
95+
if (++err_cnt < 10) {
96+
std::cout << "failed at index " << i << ", " << C[i] << " != 3*("
97+
<< A[i] << " + " << B[i] << ")\n";
98+
}
99+
}
100+
}
101+
if (err_cnt > 0) {
102+
std::cout << " pass rate: "
103+
<< ((float)(Size - err_cnt) / (float)Size) * 100.0f << "% ("
104+
<< (Size - err_cnt) << "/" << Size << ")\n";
105+
}
106+
sycl::free(A, q);
107+
sycl::free(B, q);
108+
sycl::free(C, q);
109+
110+
std::cout << (err_cnt > 0 ? "FAILED\n" : "Passed\n");
111+
return err_cnt == 0;
112+
}
113+
114+
int main() {
115+
bool Passed = true;
116+
Passed &= test();
117+
return Passed ? 0 : 1;
118+
}

0 commit comments

Comments
 (0)