1
- // ==----------- element_wise_abc_impl.hpp - DPC++ joint_matrix-------------
2
- // ----==//
1
+ // ==----------- element_wise_abc_impl.hpp - DPC++ joint_matrix-------------==//
3
2
//
4
3
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
5
4
// See https://llvm.org/LICENSE.txt for license information.
@@ -27,36 +26,25 @@ template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
27
26
big_matrix (T *data) : mat(data) {}
28
27
};
29
28
30
- template <typename T1, typename T2, size_t NUM_ROWS_A, size_t NUM_COLS_A,
31
- size_t NUM_ROWS_B, size_t NUM_COLS_B, size_t NUM_ROWS_C,
32
- size_t NUM_COLS_C>
33
- void matrix_elem_wise_ops (big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
34
- big_matrix<T2, NUM_ROWS_A, NUM_COLS_A> &A,
35
- big_matrix<T2, NUM_ROWS_B, NUM_COLS_B> &B) {
36
- size_t M = NUM_ROWS_C;
37
- size_t N = NUM_COLS_C;
38
- size_t K = NUM_COLS_A;
39
-
40
- // B => K/4 x N*4, A => M x K, C => M, N
41
- // stride should be X's cols, e.g., B's stride = N*4
42
- assert (NUM_ROWS_C == NUM_ROWS_A && NUM_COLS_A == NUM_ROWS_B * 4 );
43
-
29
+ template <typename T1, typename T2, size_t M, size_t N, size_t K,
30
+ int vnniFactor>
31
+ void matrix_elem_wise_ops (big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
32
+ big_matrix<T2, K / vnniFactor, N * vnniFactor> &B) {
44
33
size_t NDRangeM = M / TM;
45
34
size_t NDRangeN = N / TN;
46
- buffer<int8_t , 2 > bufA (A.get_data (), range<2 >(M, K));
47
- buffer<int8_t , 2 > bufB (B.get_data (), range<2 >(K, N));
48
- buffer<int32_t , 2 > bufC (C.get_data (), range<2 >(M, N));
35
+ buffer<T2 , 2 > bufA (A.get_data (), range<2 >(M, K));
36
+ buffer<T2 , 2 > bufB (B.get_data (), range<2 >(K, N));
37
+ buffer<T1 , 2 > bufC (C.get_data (), range<2 >(M, N));
49
38
50
39
queue q;
51
40
q.submit ([&](handler &cgh) {
52
- auto accC = bufC. get_access <access::mode:: read_write>(cgh) ;
53
- auto accA = bufA. get_access <access::mode:: read_write>(cgh) ;
54
- auto accB = bufB. get_access <access::mode:: read_write>(cgh) ;
41
+ accessor accC{bufC, cgh, read_write} ;
42
+ accessor accA{bufA, cgh, read_write} ;
43
+ accessor accB{bufB, cgh, read_write} ;
55
44
56
45
cgh.parallel_for (
57
46
nd_range<2 >({NDRangeM, NDRangeN * SG_SZ}, {1 , 1 * SG_SZ}),
58
- [accA, accB, accC, M, N,
59
- K](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
47
+ [=](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
60
48
// The submatrix API has to be accessed by all the workitems in a
61
49
// subgroup; these functions will be called once by the subgroup, with no
62
50
// code divergence between the workitems.
@@ -66,21 +54,12 @@ void matrix_elem_wise_ops(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
66
54
const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
67
55
68
56
ext::oneapi::sub_group sg = spmd_item.get_sub_group ();
69
- joint_matrix<sub_group, int8_t , use::a, TM, TK, layout::row_major>
70
- sub_a;
71
-
57
+ joint_matrix<sub_group, T2, use::a, TM, TK, layout::row_major> sub_a;
72
58
// For B, we assume B has been already VNNIed.
73
- joint_matrix<sub_group, int8_t , use::b, TK, TN,
59
+ joint_matrix<sub_group, T2 , use::b, TK, TN,
74
60
ext::intel::experimental::matrix::layout::packed>
75
61
sub_b;
76
-
77
- joint_matrix<sub_group, int32_t , use::accumulator, TM, TN> sub_c;
78
-
79
- joint_matrix_load (
80
- sg, sub_c,
81
- accC.template get_multi_ptr <access::decorated::no>() +
82
- (sg_startx * TM) * N + sg_starty / SG_SZ * TN,
83
- N, layout::row_major);
62
+ joint_matrix<sub_group, T1, use::accumulator, TM, TN> sub_c;
84
63
85
64
joint_matrix_load (
86
65
sg, sub_a,
@@ -96,14 +75,19 @@ void matrix_elem_wise_ops(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
96
75
joint_matrix_load (
97
76
sg, sub_b,
98
77
accB.template get_multi_ptr <access::decorated::no>() +
99
- + sg_starty / SG_SZ * TN * 4 ,
100
- N * 4 );
78
+ sg_starty / SG_SZ * TN * vnniFactor ,
79
+ N * vnniFactor );
101
80
auto wi_slice_b =
102
81
sycl::ext::intel::experimental::matrix::get_wi_data (sg, sub_b);
103
82
for (int i = 0 ; i < wi_slice_b.length (); i++) {
104
83
wi_slice_b[i] += 1 ;
105
84
}
106
85
86
+ joint_matrix_load (
87
+ sg, sub_c,
88
+ accC.template get_multi_ptr <access::decorated::no>() +
89
+ (sg_startx * TM) * N + sg_starty / SG_SZ * TN,
90
+ N, layout::row_major);
107
91
auto wi_slice_c =
108
92
sycl::ext::intel::experimental::matrix::get_wi_data (sg, sub_c);
109
93
for (int i = 0 ; i < wi_slice_c.length (); i++) {
@@ -113,15 +97,18 @@ void matrix_elem_wise_ops(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
113
97
}).wait ();
114
98
}
115
99
116
- int8_t A[TM][TK];
117
- int8_t B[TK / 4 ][TN * 4 ];
118
- int32_t C[TM][TN];
119
-
120
100
int main () {
101
+ static constexpr unsigned vnniFactor = 4 ;
102
+
103
+ int8_t A[TM][TK];
104
+ int8_t B[TK / vnniFactor][TN * vnniFactor];
105
+ int32_t C[TM][TN];
106
+
121
107
big_matrix<int32_t , TM, TN> MC ((int32_t *)&C);
122
108
big_matrix<int8_t , TM, TK> MA ((int8_t *)&A);
123
- big_matrix<int8_t , TK / 4 , TN * 4 > MB ((int8_t *)&B);
124
- matrix_elem_wise_ops (MC, MA, MB);
109
+ big_matrix<int8_t , TK / vnniFactor, TN * vnniFactor> MB ((int8_t *)&B);
110
+
111
+ matrix_elem_wise_ops<int32_t , int8_t , TM, TN, TK, vnniFactor>(MC, MA, MB);
125
112
126
113
return 0 ;
127
114
}
0 commit comments