Skip to content

Commit 5245c75

Browse files
authored
[SYCL][Matrix tests] Make multi_ptr access direct inside the kernel (#12351)
This change explicitly defines the multi_ptr variables inside the kernel instead of outside it. Previously, we relied on the runtime to handle the lambda capture clause correctly and to create correct private copies of each of these variables. With this change, we ensure that these variables are correct private copies inside the kernel.
1 parent 6b1021c commit 5245c75

5 files changed

+48
-39
lines changed

sycl/test-e2e/Matrix/joint_matrix_annotated_ptr_impl.hpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,21 @@ template <typename T1, typename T2, size_t M, size_t N, size_t K,
66
void matrix_multiply(T1 *C, T2 *A, T2 *B, queue &q) {
77
size_t NDRangeM = M / TM;
88
size_t NDRangeN = N / TN;
9-
auto pA = address_space_cast<sycl::access::address_space::global_space,
10-
sycl::access::decorated::no>(A);
11-
auto pB = address_space_cast<sycl::access::address_space::global_space,
12-
sycl::access::decorated::no>(B);
13-
auto pC = address_space_cast<sycl::access::address_space::global_space,
14-
sycl::access::decorated::no>(C);
159
q.submit([&](handler &cgh) {
1610
cgh.parallel_for(
1711
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
1812
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]
1913

2014
{
15+
auto pA =
16+
address_space_cast<sycl::access::address_space::global_space,
17+
sycl::access::decorated::no>(A);
18+
auto pB =
19+
address_space_cast<sycl::access::address_space::global_space,
20+
sycl::access::decorated::no>(B);
21+
auto pC =
22+
address_space_cast<sycl::access::address_space::global_space,
23+
sycl::access::decorated::no>(C);
2124
const auto global_idx = spmd_item.get_global_id(0);
2225
const auto global_idy = spmd_item.get_global_id(1);
2326
const auto sg_startx = global_idx - spmd_item.get_local_id(0);

sycl/test-e2e/Matrix/joint_matrix_bf16_fill_k_cache_impl.hpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,6 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
7676
assert(rowsA % tM == 0);
7777
assert(colsA % tK == 0);
7878
assert(colsB % tN == 0);
79-
80-
auto pA = address_space_cast<sycl::access::address_space::global_space,
81-
sycl::access::decorated::no>(A);
82-
auto pB = address_space_cast<sycl::access::address_space::global_space,
83-
sycl::access::decorated::no>(B);
84-
auto pC = address_space_cast<sycl::access::address_space::global_space,
85-
sycl::access::decorated::no>(C);
86-
8779
// submit main kernel
8880
std::chrono::high_resolution_clock::time_point start =
8981
std::chrono::high_resolution_clock::now();
@@ -94,6 +86,15 @@ double joint_matmul(TOperand *A, TOperand *B, TResult *C, queue &q, int i) {
9486
// loop global
9587
// loop localrange
9688
[=](nd_item<2> it) [[intel::reqd_sub_group_size(sgSize)]] {
89+
auto pA =
90+
address_space_cast<sycl::access::address_space::global_space,
91+
sycl::access::decorated::no>(A);
92+
auto pB =
93+
address_space_cast<sycl::access::address_space::global_space,
94+
sycl::access::decorated::no>(B);
95+
auto pC =
96+
address_space_cast<sycl::access::address_space::global_space,
97+
sycl::access::decorated::no>(C);
9798
auto m2 = it.get_group(0);
9899
auto n2 = it.get_group(1);
99100
auto m1 = it.get_local_id(0);

sycl/test-e2e/Matrix/joint_matrix_colA_rowB_colC_impl.hpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,19 +19,22 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q) {
1919
size_t NDRangeM = M / TM;
2020
size_t NDRangeN = N / TN;
2121

22-
auto pA = address_space_cast<sycl::access::address_space::global_space,
23-
sycl::access::decorated::no>(A);
24-
auto pB = address_space_cast<sycl::access::address_space::global_space,
25-
sycl::access::decorated::no>(B);
26-
auto pC = address_space_cast<sycl::access::address_space::global_space,
27-
sycl::access::decorated::no>(C);
28-
2922
q.submit([&](handler &cgh) {
3023
cgh.parallel_for(
3124
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
3225
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]
3326

3427
{
28+
auto pA =
29+
address_space_cast<sycl::access::address_space::global_space,
30+
sycl::access::decorated::no>(A);
31+
auto pB =
32+
address_space_cast<sycl::access::address_space::global_space,
33+
sycl::access::decorated::no>(B);
34+
auto pC =
35+
address_space_cast<sycl::access::address_space::global_space,
36+
sycl::access::decorated::no>(C);
37+
3538
// The submatrix API has to be accessed by all the workitems in a
3639
// subgroup these functions will be called once by the subgroup no
3740
// code divergence between the workitems

sycl/test-e2e/Matrix/joint_matrix_out_bounds_impl.hpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,21 @@ void matrix_multiply(T1 *C, T2 *A, T2 *B, queue q, unsigned int vnniFactor) {
1818
// Add one iteration for the out of bounds dpas instruction
1919
size_t NDRangeM = M / TM + (((M % TM) != 0) ? 1 : 0);
2020
size_t NDRangeN = N / TN;
21-
22-
auto pA = address_space_cast<sycl::access::address_space::global_space,
23-
sycl::access::decorated::no>(A);
24-
auto pB = address_space_cast<sycl::access::address_space::global_space,
25-
sycl::access::decorated::no>(B);
26-
auto pC = address_space_cast<sycl::access::address_space::global_space,
27-
sycl::access::decorated::no>(C);
28-
2921
q.submit([&](handler &cgh) {
3022
cgh.parallel_for(
3123
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
3224
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]]
3325

3426
{
27+
auto pA =
28+
address_space_cast<sycl::access::address_space::global_space,
29+
sycl::access::decorated::no>(A);
30+
auto pB =
31+
address_space_cast<sycl::access::address_space::global_space,
32+
sycl::access::decorated::no>(B);
33+
auto pC =
34+
address_space_cast<sycl::access::address_space::global_space,
35+
sycl::access::decorated::no>(C);
3536
// The submatrix API has to be accessed by all the workitems in a
3637
// subgroup these functions will be called once by the subgroup no
3738
// code divergence between the workitems

sycl/test-e2e/Matrix/joint_matrix_transposeC_impl.hpp

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,21 @@ void matrix_load_and_store(T1 *input, T1 *out_col_major, T1 *out_row_major,
1313
size_t NDRangeM = M / TM;
1414
size_t NDRangeN = N / TN;
1515

16-
auto p_input = address_space_cast<sycl::access::address_space::global_space,
17-
sycl::access::decorated::no>(input);
18-
19-
auto p_out_col_major =
20-
address_space_cast<sycl::access::address_space::global_space,
21-
sycl::access::decorated::no>(out_col_major);
22-
auto p_out_row_major =
23-
address_space_cast<sycl::access::address_space::global_space,
24-
sycl::access::decorated::no>(out_row_major);
25-
2616
q.submit([&](handler &cgh) {
2717
cgh.parallel_for(
2818
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
2919
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
20+
auto p_input =
21+
address_space_cast<sycl::access::address_space::global_space,
22+
sycl::access::decorated::no>(input);
23+
24+
auto p_out_col_major =
25+
address_space_cast<sycl::access::address_space::global_space,
26+
sycl::access::decorated::no>(out_col_major);
27+
auto p_out_row_major =
28+
address_space_cast<sycl::access::address_space::global_space,
29+
sycl::access::decorated::no>(out_row_major);
30+
3031
const auto global_idx = spmd_item.get_global_id(0);
3132
const auto global_idy = spmd_item.get_global_id(1);
3233
const auto sg_startx = global_idx - spmd_item.get_local_id(0);

0 commit comments

Comments
 (0)