Skip to content

Commit e5079f3

Browse files
committed
simplified code
1 parent 7f21c27 commit e5079f3

File tree

1 file changed

+31
-44
lines changed

1 file changed

+31
-44
lines changed
Lines changed: 31 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
//==----------- element_wise_abc_impl.hpp - DPC++ joint_matrix-------------
2-
//----==//
1+
//==----------- element_wise_abc_impl.hpp - DPC++ joint_matrix-------------==//
32
//
43
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
54
// See https://llvm.org/LICENSE.txt for license information.
@@ -27,36 +26,25 @@ template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
2726
big_matrix(T *data) : mat(data) {}
2827
};
2928

30-
template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
31-
size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
32-
size_t NUM_COLS_C>
33-
void matrix_elem_wise_ops(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
34-
big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
35-
big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
36-
size_t M = NUM_ROWS_C;
37-
size_t N = NUM_COLS_C;
38-
size_t K = NUM_COLS_A;
39-
40-
// B => K/4 x N*4, A => M x K, C => M, N
41-
// stride should be X's cols, e.g., B's stride = N*4
42-
assert(NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4);
43-
29+
template <typename T1, typename T2, size_t M, size_t N, size_t K,
30+
int vnniFactor>
31+
void matrix_elem_wise_ops(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
32+
big_matrix<T2, K / vnniFactor, N * vnniFactor> &B) {
4433
size_t NDRangeM = M / TM;
4534
size_t NDRangeN = N / TN;
46-
buffer<int8_t, 2> bufA(A.get_data(), range<2>(M, K));
47-
buffer<int8_t, 2> bufB(B.get_data(), range<2>(K, N));
48-
buffer<int32_t, 2> bufC(C.get_data(), range<2>(M, N));
35+
buffer<T2, 2> bufA(A.get_data(), range<2>(M, K));
36+
buffer<T2, 2> bufB(B.get_data(), range<2>(K, N));
37+
buffer<T1, 2> bufC(C.get_data(), range<2>(M, N));
4938

5039
queue q;
5140
q.submit([&](handler &cgh) {
52-
auto accC = bufC.get_access<access::mode::read_write>(cgh);
53-
auto accA = bufA.get_access<access::mode::read_write>(cgh);
54-
auto accB = bufB.get_access<access::mode::read_write>(cgh);
41+
accessor accC{bufC, cgh, read_write};
42+
accessor accA{bufA, cgh, read_write};
43+
accessor accB{bufB, cgh, read_write};
5544

5645
cgh.parallel_for(
5746
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
58-
[accA, accB, accC, M, N,
59-
K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
47+
[=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
6048
// The submatrix API has to be accessed by all the workitems in a
6149
// subgroup; these functions will be called once by the subgroup, with no
6250
// code divergence between the workitems.
@@ -66,21 +54,12 @@ void matrix_elem_wise_ops(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
6654
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
6755

6856
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
69-
joint_matrix<sub_group, int8_t, use::a, TM, TK, layout::row_major>
70-
sub_a;
71-
57+
joint_matrix<sub_group, T2, use::a, TM, TK, layout::row_major> sub_a;
7258
// For B, we assume B has been already VNNIed.
73-
joint_matrix<sub_group, int8_t, use::b, TK, TN,
59+
joint_matrix<sub_group, T2, use::b, TK, TN,
7460
ext::intel::experimental::matrix::layout::packed>
7561
sub_b;
76-
77-
joint_matrix<sub_group, int32_t, use::accumulator, TM, TN> sub_c;
78-
79-
joint_matrix_load(
80-
sg, sub_c,
81-
accC.template get_multi_ptr<access::decorated::no>() +
82-
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
83-
N, layout::row_major);
62+
joint_matrix<sub_group, T1, use::accumulator, TM, TN> sub_c;
8463

8564
joint_matrix_load(
8665
sg, sub_a,
@@ -96,14 +75,19 @@ void matrix_elem_wise_ops(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
9675
joint_matrix_load(
9776
sg, sub_b,
9877
accB.template get_multi_ptr<access::decorated::no>() +
99-
+sg_starty / SG_SZ * TN * 4,
100-
N * 4);
78+
sg_starty / SG_SZ * TN * vnniFactor,
79+
N * vnniFactor);
10180
auto wi_slice_b =
10281
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);
10382
for (int i = 0; i < wi_slice_b.length(); i++) {
10483
wi_slice_b[i] += 1;
10584
}
10685

86+
joint_matrix_load(
87+
sg, sub_c,
88+
accC.template get_multi_ptr<access::decorated::no>() +
89+
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
90+
N, layout::row_major);
10791
auto wi_slice_c =
10892
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c);
10993
for (int i = 0; i < wi_slice_c.length(); i++) {
@@ -113,15 +97,18 @@ void matrix_elem_wise_ops(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
11397
}).wait();
11498
}
11599

116-
int8_t A[TM][TK];
117-
int8_t B[TK / 4][TN * 4];
118-
int32_t C[TM][TN];
119-
120100
int main() {
101+
static constexpr unsigned vnniFactor = 4;
102+
103+
int8_t A[TM][TK];
104+
int8_t B[TK / vnniFactor][TN * vnniFactor];
105+
int32_t C[TM][TN];
106+
121107
big_matrix<int32_t, TM, TN> MC((int32_t *)&C);
122108
big_matrix<int8_t, TM, TK> MA((int8_t *)&A);
123-
big_matrix<int8_t, TK / 4, TN * 4> MB((int8_t *)&B);
124-
matrix_elem_wise_ops(MC, MA, MB);
109+
big_matrix<int8_t, TK / vnniFactor, TN * vnniFactor> MB((int8_t *)&B);
110+
111+
matrix_elem_wise_ops<int32_t, int8_t, TM, TN, TK, vnniFactor>(MC, MA, MB);
125112

126113
return 0;
127114
}

0 commit comments

Comments
 (0)