@@ -10,12 +10,6 @@ static float make_fp32(uint16_t x) {
10
10
return *res;
11
11
}
12
12
13
- static uint16_t make_bf16 (float x) {
14
- int *res = reinterpret_cast <int *>(&x);
15
- *res = *res >> 16 ;
16
- return (uint16_t )*res;
17
- }
18
-
19
13
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
20
14
public:
21
15
T *mat;
@@ -40,7 +34,7 @@ void assert_ops_ref(
40
34
template <typename T, size_t M, size_t N>
41
35
void matrix_verify_add (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
42
36
const float ref) {
43
- buffer<unsigned short , 2 > bufA (A.get_data (), range<2 >(M, N));
37
+ buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
44
38
45
39
q.submit ([&](handler &cgh) {
46
40
auto accA = bufA.get_access <access::mode::read_write>(cgh);
@@ -55,12 +49,13 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
55
49
sub_group sg = spmd_item.get_sub_group ();
56
50
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
57
51
58
- joint_matrix_fill (sg, sub_a, make_bf16 (5.0 ));
52
+ joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
59
53
60
54
auto wi_slice_a = get_wi_data (sg, sub_a);
61
55
for (int i = 0 ; i < wi_slice_a.length (); i++) {
62
- wi_slice_a[i] = wi_slice_a[i] + make_bf16 (2 );
56
+ wi_slice_a[i] = wi_slice_a[i] + bfloat16 (2 );
63
57
}
58
+
64
59
ext::intel::experimental::matrix::joint_matrix_store (
65
60
sg, sub_a,
66
61
accA.get_pointer () + (sg_startx * TM) * N +
@@ -74,7 +69,7 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
74
69
template <typename T, size_t M, size_t N>
75
70
void matrix_verify_sub (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
76
71
const float ref) {
77
- buffer<unsigned short , 2 > bufA (A.get_data (), range<2 >(M, N));
72
+ buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
78
73
79
74
q.submit ([&](handler &cgh) {
80
75
auto accA = bufA.get_access <access::mode::read_write>(cgh);
@@ -89,11 +84,11 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
89
84
sub_group sg = spmd_item.get_sub_group ();
90
85
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
91
86
92
- joint_matrix_fill (sg, sub_a, make_bf16 (5.0 ));
87
+ joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
93
88
94
89
auto wi_slice_a = get_wi_data (sg, sub_a);
95
90
for (int i = 0 ; i < wi_slice_a.length (); i++) {
96
- wi_slice_a[i] = wi_slice_a[i] - make_bf16 (2 );
91
+ wi_slice_a[i] = wi_slice_a[i] - bfloat16 (2 );
97
92
}
98
93
ext::intel::experimental::matrix::joint_matrix_store (
99
94
sg, sub_a,
@@ -108,7 +103,7 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
108
103
template <typename T, size_t M, size_t N>
109
104
void matrix_verify_mul (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
110
105
const float ref) {
111
- buffer<unsigned short , 2 > bufA (A.get_data (), range<2 >(M, N));
106
+ buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
112
107
113
108
q.submit ([&](handler &cgh) {
114
109
auto accA = bufA.get_access <access::mode::read_write>(cgh);
@@ -122,11 +117,11 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
122
117
123
118
sub_group sg = spmd_item.get_sub_group ();
124
119
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
125
- joint_matrix_fill (sg, sub_a, make_bf16 (5.0 ));
120
+ joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
126
121
127
122
auto wi_slice_a = get_wi_data (sg, sub_a);
128
123
for (int i = 0 ; i < wi_slice_a.length (); i++) {
129
- wi_slice_a[i] = wi_slice_a[i] * make_bf16 (3.0 );
124
+ wi_slice_a[i] = wi_slice_a[i] * bfloat16 (3.0 );
130
125
}
131
126
ext::intel::experimental::matrix::joint_matrix_store (
132
127
sg, sub_a,
@@ -141,7 +136,7 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
141
136
template <typename T, size_t M, size_t N>
142
137
void matrix_verify_div (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
143
138
const float ref) {
144
- buffer<unsigned short , 2 > bufA (A.get_data (), range<2 >(M, N));
139
+ buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
145
140
146
141
q.submit ([&](handler &cgh) {
147
142
auto accA = bufA.get_access <access::mode::read_write>(cgh);
@@ -156,11 +151,11 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
156
151
sub_group sg = spmd_item.get_sub_group ();
157
152
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
158
153
159
- joint_matrix_fill (sg, sub_a, make_bf16 (4.0 ));
154
+ joint_matrix_fill (sg, sub_a, bfloat16 (4.0 ));
160
155
161
156
auto wi_slice_a = get_wi_data (sg, sub_a);
162
157
for (int i = 0 ; i < wi_slice_a.length (); i++) {
163
- wi_slice_a[i] = wi_slice_a[i] / make_bf16 (2.0 );
158
+ wi_slice_a[i] = wi_slice_a[i] / bfloat16 (2.0 );
164
159
}
165
160
ext::intel::experimental::matrix::joint_matrix_store (
166
161
sg, sub_a,
@@ -175,7 +170,7 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
175
170
template <typename T, size_t M, size_t N>
176
171
void matrix_verify_logic (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
177
172
const float ref) {
178
- buffer<unsigned short , 2 > bufA (A.get_data (), range<2 >(M, N));
173
+ buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
179
174
180
175
q.submit ([&](handler &cgh) {
181
176
auto accA = bufA.get_access <access::mode::read_write>(cgh);
@@ -189,26 +184,26 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
189
184
sub_group sg = spmd_item.get_sub_group ();
190
185
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
191
186
192
- joint_matrix_fill (sg, sub_a, make_bf16 (5.0 ));
187
+ joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
193
188
194
189
auto wi_slice_a = get_wi_data (sg, sub_a);
195
190
for (int i = 0 ; i < wi_slice_a.length (); i++) {
196
191
if (wi_slice_a[i]) {
197
- if (wi_slice_a[i] > make_bf16 (2.0 ) ||
198
- wi_slice_a[i] >= make_bf16 (2.0 ) ||
199
- wi_slice_a[i] < make_bf16 (2.0 ) ||
200
- wi_slice_a[i] <= make_bf16 (2.0 )) {
201
- T val = (wi_slice_a[i] != make_bf16 (2.0 )) ? wi_slice_a[i]
202
- : make_bf16 (2.0 );
203
- val = make_bf16 (make_fp32 (val) - static_cast <float >(1 ));
204
- val = make_bf16 (make_fp32 (val) + static_cast <float >(1 ));
205
- if (wi_slice_a[i] == make_bf16 (2.0 )) {
206
- val = make_bf16 (make_fp32 (val) - static_cast <float >(2 ));
207
- val = make_bf16 (make_fp32 (val) * static_cast <float >(3 ));
208
- val = make_bf16 (make_fp32 (val) / static_cast <float >(2 ));
192
+ if (wi_slice_a[i] > bfloat16 (2.0 ) ||
193
+ wi_slice_a[i] >= bfloat16 (2.0 ) ||
194
+ wi_slice_a[i] < bfloat16 (2.0 ) ||
195
+ wi_slice_a[i] <= bfloat16 (2.0 )) {
196
+ T val = (wi_slice_a[i] != bfloat16 (2.0 )) ? wi_slice_a[i]
197
+ : bfloat16 (2.0 );
198
+ val = bfloat16 (make_fp32 (val) - static_cast <float >(1 ));
199
+ val = bfloat16 (make_fp32 (val) + static_cast <float >(1 ));
200
+ if (wi_slice_a[i] == bfloat16 (2.0 )) {
201
+ val = bfloat16 (make_fp32 (val) - static_cast <float >(2 ));
202
+ val = bfloat16 (make_fp32 (val) * static_cast <float >(3 ));
203
+ val = bfloat16 (make_fp32 (val) / static_cast <float >(2 ));
209
204
210
205
} else {
211
- val = make_bf16 (make_fp32 (val) + static_cast <float >(2 ));
206
+ val = bfloat16 (make_fp32 (val) + static_cast <float >(2 ));
212
207
}
213
208
wi_slice_a[i] = val;
214
209
}
@@ -226,7 +221,7 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
226
221
227
222
static constexpr size_t MATRIX_M = TM * 2 ;
228
223
static constexpr size_t MATRIX_N = TN * 2 ;
229
- unsigned short A[MATRIX_M][MATRIX_N];
224
+ bfloat16 A[MATRIX_M][MATRIX_N];
230
225
float D[MATRIX_M][MATRIX_N];
231
226
232
227
void matrix_ops_ref (float *D, int M, int N) {
@@ -240,18 +235,18 @@ void matrix_ops_ref(float *D, int M, int N) {
240
235
int main () {
241
236
242
237
big_matrix<float , MATRIX_M, MATRIX_N> MD ((float *)&D);
243
- big_matrix<unsigned short , MATRIX_M, MATRIX_N> MA ((unsigned short *)&A);
238
+ big_matrix<bfloat16 , MATRIX_M, MATRIX_N> MA ((bfloat16 *)&A);
244
239
245
240
size_t NDRangeM = MATRIX_M / TM;
246
241
size_t NDRangeN = MATRIX_N / TN;
247
242
queue q;
248
243
nd_range<2 > r ({NDRangeM, NDRangeN * SG_SZ}, {1 , 1 * SG_SZ});
249
244
250
- matrix_verify_add<unsigned short , MATRIX_M, MATRIX_N>(q, MA, r, 7.0 );
251
- matrix_verify_sub<unsigned short , MATRIX_M, MATRIX_N>(q, MA, r, 3.0 );
252
- matrix_verify_mul<unsigned short , MATRIX_M, MATRIX_N>(q, MA, r, 15.0 );
253
- matrix_verify_div<unsigned short , MATRIX_M, MATRIX_N>(q, MA, r, 2.0 );
254
- matrix_verify_logic<unsigned short , MATRIX_M, MATRIX_N>(q, MA, r, 7.0 );
245
+ matrix_verify_add<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 7.0 );
246
+ matrix_verify_sub<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 3.0 );
247
+ matrix_verify_mul<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 15.0 );
248
+ matrix_verify_div<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 2.0 );
249
+ matrix_verify_logic<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 7.0 );
255
250
256
251
return 0 ;
257
252
}
0 commit comments