Skip to content

Commit 2cfdfa4

Browse files
authored
[SYCL][JointMatrix] Added missing required sub-group size to VNNI tests (#10565)
Because the required sub-group size attribute was missing, the tests were failing. The only functional change is adding "[[intel::reqd_sub_group_size(SG_SZ)]]" to the kernel lambdas. The rest is clang-format reformatting.
1 parent 57098b0 commit 2cfdfa4

File tree

2 files changed

+37
-39
lines changed

2 files changed

+37
-39
lines changed

sycl/test-e2e/Matrix/Legacy/joint_matrix_int8_vnni_impl.hpp

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -37,42 +37,41 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
3737
cgh.parallel_for<class imatrix>(
3838
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
3939
[accA, accB, accC, M, N, K](nd_item<2> spmd_item)
40-
41-
{
42-
// The submatrix API has to be accessed by all the workitems in a
43-
// subgroup these functions will be called once by the subgroup no
44-
// code divergence between the workitems
45-
const auto global_idx = spmd_item.get_global_id(0);
46-
const auto global_idy = spmd_item.get_global_id(1);
47-
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
48-
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
49-
50-
sub_group sg = spmd_item.get_sub_group();
51-
joint_matrix<int8_t, TM, TK> sub_a(sg);
52-
joint_matrix<int8_t, TK, TN, matrix_layout::packed_b> sub_b(sg);
53-
joint_matrix<int32_t, TM, TN> sub_c(sg);
54-
55-
joint_matrix_fill(sg, sub_c, 0);
56-
for (int k = 0; k < K / TK; k += 1) {
57-
joint_matrix_load(
58-
sg, sub_a,
59-
accA.template get_multi_ptr<access::decorated::no>() +
60-
(sg_startx * TM) * K + k * TK,
61-
K, matrix_layout::row_major);
62-
// VNNI transform is done automatically at this level
63-
joint_matrix_load(
64-
sg, sub_b,
65-
accB.template get_multi_ptr<access::decorated::no>() +
66-
(k * TK) * N + sg_starty / SG_SZ * TN,
67-
N, matrix_layout::row_major);
68-
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
69-
}
70-
joint_matrix_store(
71-
sg, sub_c,
72-
accC.template get_multi_ptr<access::decorated::no>() +
73-
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
74-
N, matrix_layout::row_major);
75-
}); // parallel for
40+
[[intel::reqd_sub_group_size(SG_SZ)]] {
41+
// The submatrix API has to be accessed by all the workitems in a
42+
// subgroup these functions will be called once by the subgroup
43+
// no code divergence between the workitems
44+
const auto global_idx = spmd_item.get_global_id(0);
45+
const auto global_idy = spmd_item.get_global_id(1);
46+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
47+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
48+
49+
sub_group sg = spmd_item.get_sub_group();
50+
joint_matrix<int8_t, TM, TK> sub_a(sg);
51+
joint_matrix<int8_t, TK, TN, matrix_layout::packed_b> sub_b(sg);
52+
joint_matrix<int32_t, TM, TN> sub_c(sg);
53+
54+
joint_matrix_fill(sg, sub_c, 0);
55+
for (int k = 0; k < K / TK; k += 1) {
56+
joint_matrix_load(
57+
sg, sub_a,
58+
accA.template get_multi_ptr<access::decorated::no>() +
59+
(sg_startx * TM) * K + k * TK,
60+
K, matrix_layout::row_major);
61+
// VNNI transform is done automatically at this level
62+
joint_matrix_load(
63+
sg, sub_b,
64+
accB.template get_multi_ptr<access::decorated::no>() +
65+
(k * TK) * N + sg_starty / SG_SZ * TN,
66+
N, matrix_layout::row_major);
67+
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
68+
}
69+
joint_matrix_store(
70+
sg, sub_c,
71+
accC.template get_multi_ptr<access::decorated::no>() +
72+
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
73+
N, matrix_layout::row_major);
74+
}); // parallel for
7675
}).wait();
7776
}
7877

sycl/test-e2e/Matrix/joint_matrix_int8_vnni_impl.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,8 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
3636

3737
cgh.parallel_for<class imatrix>(
3838
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
39-
[accA, accB, accC, M, N, K](nd_item<2> spmd_item)
40-
41-
{
39+
[accA, accB, accC, M, N,
40+
K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
4241
// The submatrix API has to be accessed by all the workitems in a
4342
// subgroup these functions will be called once by the subgroup no
4443
// code divergence between the workitems

0 commit comments

Comments
 (0)