@@ -22,4 +22,164 @@ using bfloat16 = sycl::ext::oneapi::bfloat16;
 
 #define SG_SZ 16
 
+<<<<<<< HEAD
 #include "joint_matrix_bfloat16_32x64_impl.hpp"
+=======
+#define TM 32
+#define TN 64
+#define TK 16
+
+#define BF16_EPSILON 0.00781250
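+// Note: 0.00781250 == 2^-7, the bfloat16 machine epsilon (7 explicit
+// significand bits); it is used as the comparison tolerance in main().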
+
+template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
+private:
+  T *mat;
+
+public:
+  T *get_data() { return mat; }
+  void set_data(T *data) { mat = data; }
+  big_matrix(T *data) : mat(data) {}
+};
+
+template <typename T1, typename T2, size_t M, size_t N, size_t K>
+void matrix_multiply(big_matrix<T1, M, N> &C, big_matrix<T2, M, K> &A,
+                     big_matrix<T2, K / 2, N * 2> &B) {
+  size_t NDRangeM = M / TM;
+  size_t NDRangeN = N / TN;
+  buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, K));
+  buffer<bfloat16, 2> bufB(B.get_data(), range<2>(K, N));
+  buffer<float, 2> bufC((float *)C.get_data(), range<2>(M, N));
+
+  queue q;
+  q.submit([&](handler &cgh) {
+    auto accC = bufC.get_access<access::mode::read_write>(cgh);
+    auto accA = bufA.get_access<access::mode::read_write>(cgh);
+    auto accB = bufB.get_access<access::mode::read_write>(cgh);
+
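+    // Each sub-group of SG_SZ work-items cooperatively computes one TM x TN
+    // tile of C, so the launch grid is (M / TM) x ((N / TN) * SG_SZ).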
+    cgh.parallel_for<class imatrix>(
+        nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
+        [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
+          // The submatrix API has to be accessed by all work-items in a
+          // sub-group; these functions are called once per sub-group, so
+          // there is no code divergence between the work-items.
+          const auto global_idx = spmd_item.get_global_id(0);
+          const auto global_idy = spmd_item.get_global_id(1);
+          const auto sg_startx = global_idx - spmd_item.get_local_id(0);
+          const auto sg_starty = global_idy - spmd_item.get_local_id(1);
+
+          ext::oneapi::sub_group sg = spmd_item.get_sub_group();
+          joint_matrix<bfloat16, TM, TK> sub_a(sg);
+          // For B, since the current implementation does not support a
+          // non-packed layout, users need to specify the VNNI-updated sizes
+          // along with the packed_b layout. By default, the layout is
+          // row_major and the size is (TK, TN).
+          joint_matrix<bfloat16, TK, TN, matrix_layout::packed_b> sub_b(sg);
+          joint_matrix<float, TM, TN> sub_c(sg);
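+          // Illustration of the VNNI packing this test assumes: logical
+          // element B(k, n) of the (K, N) matrix is stored at
+          // B_packed[k / 2][2 * n + k % 2] in the (K / 2, N * 2) buffer,
+          // i.e. pairs of adjacent rows are interleaved element-wise.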
+
+          joint_matrix_load(sg, sub_c,
+                            accC.get_pointer() + (sg_startx * TM) * N +
+                                sg_starty / SG_SZ * TN,
+                            N, matrix_layout::row_major);
+          for (int k = 0; k < K / TK; k += 1) {
+            joint_matrix_load(
+                sg, sub_a, accA.get_pointer() + (sg_startx * TM) * K + k * TK,
+                K, matrix_layout::row_major);
+            // Assuming B data is already in VNNI format.
+            joint_matrix_load(sg, sub_b,
+                              accB.get_pointer() + (k * TK / 2) * (N * 2) +
+                                  sg_starty / SG_SZ * TN * 2,
+                              N * 2, matrix_layout::packed_b);
+            sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
+          }
+          joint_matrix_store(sg, sub_c,
+                             accC.get_pointer() + (sg_startx * TM) * N +
+                                 sg_starty / SG_SZ * TN,
+                             N, matrix_layout::row_major);
+        }); // parallel for
+  }).wait();
+}
+
+static constexpr size_t MATRIX_M = TM * 2;
+static constexpr size_t MATRIX_N = TN * 2;
+static constexpr size_t MATRIX_K = TK * 2;
+bfloat16 A[MATRIX_M][MATRIX_K];
+bfloat16 B[MATRIX_K / 2][MATRIX_N * 2];
+unsigned short Aref[MATRIX_M][MATRIX_K];
+unsigned short Bref[MATRIX_K / 2][MATRIX_N * 2];
+float C[MATRIX_M][MATRIX_N];
+float D[MATRIX_M][MATRIX_N];
+
+// Widen a bfloat16 bit pattern into a float by placing it in the upper 16
+// bits of the binary32 representation.
+float make_fp32(short x) {
+  unsigned int y = x;
+  y = y << 16;
+  float *res = reinterpret_cast<float *>(&y);
+  return *res;
+}
+
+// Truncate a float to bfloat16 by keeping only the upper 16 bits of its
+// binary32 representation (round-toward-zero).
+unsigned short make_bf16(float x) {
+  int *res = reinterpret_cast<int *>(&x);
+  *res = *res >> 16;
+  return (unsigned short)*res;
+}
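+// Worked example: 1.0f has bit pattern 0x3F800000; make_bf16 keeps the high
+// half 0x3F80, and make_fp32(0x3F80) restores 0x3F800000, i.e. exactly 1.0f.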
+
+// Reference multiply over the VNNI-packed data: A_mem is (M, K) ints and
+// B_mem is (K, N) ints, where each int packs two adjacent bfloat16 values.
+void matrix_multiply_ref(int *A_mem, int *B_mem, int *C_mem, int M, int N,
+                         int K) {
+  for (int m = 0; m < M; m++)
+    for (int n = 0; n < N; n++) {
+      for (int k = 0; k < K; k++) {
+        short *va = (short *)(A_mem + m * K + k);
+        short *vb = (short *)(B_mem + k * N + n);
+        float acc = *((float *)(C_mem + m * N + n));
+        // FIXME: Should we do reduce-add in another version?
+        for (int i = 0; i < 2; i++) {
+          acc += (make_fp32(va[i]) * make_fp32(vb[i]));
+        }
+        *((float *)(C_mem + m * N + n)) = acc;
+      }
+    }
+}
+
+int main() {
+  for (int i = 0; i < MATRIX_M; i++) {
+    for (int j = 0; j < MATRIX_K; j++) {
+      // bfloat16 is created using unsigned short since conversion from float
+      // to bfloat16 is not supported on the host side yet.
+      A[i][j] = make_bf16(1.0f * (i + j));
+      Aref[i][j] = make_bf16(1.0f * (i + j));
+    }
+  }
+  for (int i = 0; i < MATRIX_K / 2; i++) {
+    for (int j = 0; j < MATRIX_N * 2; j++) {
+      B[i][j] = make_bf16(2.0f * i + 3.0f * j);
+      Bref[i][j] = make_bf16(2.0f * i + 3.0f * j);
+    }
+  }
+  for (int i = 0; i < MATRIX_M; i++) {
+    for (int j = 0; j < MATRIX_N; j++) {
+      C[i][j] = 1.0;
+      D[i][j] = 1.0;
+    }
+  }
+
+  big_matrix<float, MATRIX_M, MATRIX_N> MC((float *)&C);
+  big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
+  big_matrix<bfloat16, MATRIX_M, MATRIX_K> MA((bfloat16 *)&A);
+  big_matrix<bfloat16, MATRIX_K / 2, MATRIX_N * 2> MB((bfloat16 *)&B);
+  matrix_multiply(MC, MA, MB);
+  // The reference walks int-sized (bf16 pair) elements, hence MATRIX_K / 2.
+  matrix_multiply_ref((int32_t *)Aref, (int32_t *)Bref, (int32_t *)D, MATRIX_M,
+                      MATRIX_N, MATRIX_K / 2);
+
+  bool res = true;
+  for (int i = 0; i < MATRIX_M; i++) {
+    for (int j = 0; j < MATRIX_N; j++) {
+      if ((fabs(C[i][j]) - fabs(D[i][j])) > BF16_EPSILON)
+        res = false;
+    }
+  }
+  std::cout << (res ? "passed" : "failed") << std::endl;
+  return !res;
+}
+>>>>>>> 45daa9b4d ([SYCL] Fix result check in joint matrix tests (#1432))
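For reference, a minimal host-side sketch of the VNNI repacking the test assumes (the helper name is hypothetical; the test itself initializes B directly in packed form rather than repacking):

    // Pack a row-major (K, N) bf16 matrix into the (K / 2, N * 2) VNNI layout
    // expected by matrix_layout::packed_b: source element (k, n) lands at
    // packed[k / 2][2 * n + k % 2].
    void vnni_pack(const unsigned short *src, unsigned short *dst, size_t K,
                   size_t N) {
      for (size_t k = 0; k < K; k++)
        for (size_t n = 0; n < N; n++)
          dst[(k / 2) * (N * 2) + 2 * n + k % 2] = src[k * N + n];
    }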