Skip to content

Commit e49a0c3

Browse files
authored
[SYCL][Matrix]Align the parameter Group of APIs {joint_matrix_store/joint_matrix_load} with matrix spec (#12041)
[SYCL][Matrix]Align the parameter Group of APIs {joint_matrix_store/joint_matrix_load} with matrix spec --------- Signed-off-by: Ni, Wenhui <[email protected]>
1 parent 88bdd76 commit e49a0c3

File tree

2 files changed

+9
-9
lines changed

2 files changed

+9
-9
lines changed

sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ template <
146146
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value, bool> =
147147
true>
148148
inline __SYCL_ALWAYS_INLINE void joint_matrix_load(
149-
Group &sg,
149+
Group sg,
150150
joint_matrix<Group, S, use::accumulator, NumRows, NumCols,
151151
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &res,
152152
multi_ptr<T, Space, IsDecorated> src, size_t stride,
@@ -214,7 +214,7 @@ template <
214214
std::is_same<std::remove_const_t<T>, float>::value),
215215
bool> = true>
216216
inline __SYCL_ALWAYS_INLINE void
217-
joint_matrix_load(Group &sg,
217+
joint_matrix_load(Group sg,
218218
joint_matrix<Group, S, Use, NumRows, NumCols, Layout> &res,
219219
multi_ptr<T, Space, IsDecorated> src, size_t stride) {
220220
#if defined(__SYCL_DEVICE_ONLY__)
@@ -253,7 +253,7 @@ joint_matrix_load(Group &sg,
253253
template <typename Group, typename T, size_t NumRows, size_t NumCols,
254254
access::address_space Space, access::decorated IsDecorated>
255255
inline __SYCL_ALWAYS_INLINE void joint_matrix_store(
256-
Group &sg,
256+
Group sg,
257257
const joint_matrix<Group, T, use::accumulator, NumRows, NumCols,
258258
sycl::ext::oneapi::experimental::matrix::layout::dynamic>
259259
&src,

sycl/test/matrix/matrix-bfloat16-test.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
5959
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
6060
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
6161

62-
sycl::sub_group sg = spmd_item.get_sub_group();
6362
joint_matrix<sycl::sub_group, bfloat16, use::a, TM, TK,
6463
layout::row_major>
6564
sub_a;
@@ -73,26 +72,27 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
7372
joint_matrix<sycl::sub_group, float, use::accumulator, TM, TN> sub_c;
7473

7574
joint_matrix_load(
76-
sg, sub_c,
75+
spmd_item.get_sub_group(), sub_c,
7776
accC.template get_multi_ptr<sycl::access::decorated::no>() +
7877
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
7978
N, layout::row_major);
8079
for (int k = 0; k < K / TK; k += 1) { //
8180
joint_matrix_load(
82-
sg, sub_a,
81+
spmd_item.get_sub_group(), sub_a,
8382
accA.template get_multi_ptr<sycl::access::decorated::no>() +
8483
(sg_startx * TM) * K + k * TK,
8584
K);
8685
// Assuming B data is already in VNNI format.
8786
joint_matrix_load(
88-
sg, sub_b,
87+
spmd_item.get_sub_group(), sub_b,
8988
accB.template get_multi_ptr<sycl::access::decorated::no>() +
9089
(k * TK / 2) * (N * 2) + sg_starty / SG_SZ * TN * 2,
9190
N * 2);
92-
joint_matrix_mad(sg, sub_c, sub_a, sub_b, sub_c);
91+
joint_matrix_mad(spmd_item.get_sub_group(), sub_c, sub_a, sub_b,
92+
sub_c);
9393
}
9494
joint_matrix_store(
95-
sg, sub_c,
95+
spmd_item.get_sub_group(), sub_c,
9696
accC.template get_multi_ptr<sycl::access::decorated::no>() +
9797
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
9898
N, layout::row_major);

0 commit comments

Comments
 (0)