@@ -25,35 +25,42 @@ void assert_ops_ref(host_accessor<T, 2, access::mode::read> C,
25
25
const float ref) {
26
26
for (size_t i = 0 ; i < M; i++)
27
27
for (size_t j = 0 ; j < N; j++) {
28
- auto diff = make_fp32 (C[i][j]) - ref;
28
+ float diff;
29
+ if constexpr (std::is_same_v<T, bfloat16>)
30
+ diff = make_fp32 (C[i][j]) - ref;
31
+ else
32
+ diff = C[i][j] - ref;
29
33
assert (std::fabs (static_cast <float >(diff)) <
30
34
std::numeric_limits<float >::epsilon ());
31
35
}
32
36
}
33
37
template <typename T, size_t M, size_t N>
34
38
void matrix_verify_add (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
35
39
const float ref) {
36
- buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
40
+ buffer<T , 2 > bufA (A.get_data (), range<2 >(M, N));
37
41
38
42
q.submit ([&](handler &cgh) {
39
- auto accA = bufA.get_access <access::mode::read_write>(cgh);
40
-
41
- cgh.parallel_for <class add_matrix >(
42
- r, [accA](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
43
+ sycl::accessor accA{bufA, cgh, sycl::read_write};
44
+ cgh.parallel_for (
45
+ r, [accA](nd_item<2 > spmd_item)[[intel::reqd_sub_group_size (SG_SZ)]] {
43
46
const auto global_idx = spmd_item.get_global_id (0 );
44
47
const auto global_idy = spmd_item.get_global_id (1 );
45
48
const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
46
49
const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
47
50
48
51
sub_group sg = spmd_item.get_sub_group ();
49
52
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
50
-
51
- joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
52
-
53
+ if constexpr (std::is_same_v<T, bfloat16>)
54
+ joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
55
+ else
56
+ joint_matrix_fill (sg, sub_a, 5 );
53
57
auto wi_slice_a =
54
58
sycl::ext::intel::experimental::matrix::get_wi_data (sg, sub_a);
55
59
for (int i = 0 ; i < wi_slice_a.length (); i++) {
56
- wi_slice_a[i] = wi_slice_a[i] + bfloat16 (2 );
60
+ if constexpr (std::is_same_v<T, bfloat16>)
61
+ wi_slice_a[i] = wi_slice_a[i] + bfloat16 (2 );
62
+ else
63
+ wi_slice_a[i] = wi_slice_a[i] + 2 ;
57
64
}
58
65
59
66
ext::intel::experimental::matrix::joint_matrix_store (
@@ -62,154 +69,188 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
62
69
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
63
70
N);
64
71
}); // parallel for
65
- }).wait ();
72
+ })
73
+ .wait ();
66
74
assert_ops_ref<T, M, N>(bufA.get_host_access (read_only), ref);
67
75
}
68
76
69
77
template <typename T, size_t M, size_t N>
70
78
void matrix_verify_sub (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
71
79
const float ref) {
72
- buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
80
+ buffer<T , 2 > bufA (A.get_data (), range<2 >(M, N));
73
81
74
82
q.submit ([&](handler &cgh) {
75
- auto accA = bufA.get_access <access::mode::read_write>(cgh);
76
-
77
- cgh.parallel_for <class sub_matrix >(
78
- r, [accA](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
83
+ sycl::accessor accA{bufA, cgh, sycl::read_write};
84
+ cgh.parallel_for (
85
+ r, [accA](nd_item<2 > spmd_item)[[intel::reqd_sub_group_size (SG_SZ)]] {
79
86
const auto global_idx = spmd_item.get_global_id (0 );
80
87
const auto global_idy = spmd_item.get_global_id (1 );
81
88
const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
82
89
const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
83
90
84
91
sub_group sg = spmd_item.get_sub_group ();
85
92
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
86
-
87
- joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
88
-
93
+ if constexpr (std::is_same_v<T, bfloat16>)
94
+ joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
95
+ else
96
+ joint_matrix_fill (sg, sub_a, 5 );
89
97
auto wi_slice_a =
90
98
sycl::ext::intel::experimental::matrix::get_wi_data (sg, sub_a);
91
99
for (int i = 0 ; i < wi_slice_a.length (); i++) {
92
- wi_slice_a[i] = wi_slice_a[i] - bfloat16 (2 );
100
+ if constexpr (std::is_same_v<T, bfloat16>)
101
+ wi_slice_a[i] = wi_slice_a[i] - bfloat16 (2 );
102
+ else
103
+ wi_slice_a[i] = wi_slice_a[i] - 2 ;
93
104
}
94
105
ext::intel::experimental::matrix::joint_matrix_store (
95
106
sg, sub_a,
96
107
accA.template get_multi_ptr <access::decorated::no>() +
97
108
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
98
109
N);
99
110
}); // parallel for
100
- }).wait ();
111
+ })
112
+ .wait ();
101
113
assert_ops_ref<T, M, N>(bufA.get_host_access (read_only), ref);
102
114
}
103
115
104
116
template <typename T, size_t M, size_t N>
105
117
void matrix_verify_mul (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
106
118
const float ref) {
107
- buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
119
+ buffer<T , 2 > bufA (A.get_data (), range<2 >(M, N));
108
120
109
121
q.submit ([&](handler &cgh) {
110
- auto accA = bufA.get_access <access::mode::read_write>(cgh);
111
-
112
- cgh.parallel_for <class mul_matrix >(
113
- r, [accA](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
122
+ sycl::accessor accA{bufA, cgh, sycl::read_write};
123
+ cgh.parallel_for (
124
+ r, [accA](nd_item<2 > spmd_item)[[intel::reqd_sub_group_size (SG_SZ)]] {
114
125
const auto global_idx = spmd_item.get_global_id (0 );
115
126
const auto global_idy = spmd_item.get_global_id (1 );
116
127
const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
117
128
const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
118
129
119
130
sub_group sg = spmd_item.get_sub_group ();
120
131
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
121
- joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
122
-
132
+ if constexpr (std::is_same_v<T, bfloat16>)
133
+ joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
134
+ else
135
+ joint_matrix_fill (sg, sub_a, 5 );
123
136
auto wi_slice_a =
124
137
sycl::ext::intel::experimental::matrix::get_wi_data (sg, sub_a);
125
138
for (int i = 0 ; i < wi_slice_a.length (); i++) {
126
- wi_slice_a[i] = wi_slice_a[i] * bfloat16 (3.0 );
139
+ if constexpr (std::is_same_v<T, bfloat16>)
140
+ wi_slice_a[i] = wi_slice_a[i] * bfloat16 (3.0 );
141
+ else
142
+ wi_slice_a[i] = wi_slice_a[i] * 3.0 ;
127
143
}
128
144
ext::intel::experimental::matrix::joint_matrix_store (
129
145
sg, sub_a,
130
146
accA.template get_multi_ptr <access::decorated::no>() +
131
147
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
132
148
N);
133
149
}); // parallel for
134
- }).wait ();
150
+ })
151
+ .wait ();
135
152
assert_ops_ref<T, M, N>(bufA.get_host_access (read_only), ref);
136
153
}
137
154
138
155
template <typename T, size_t M, size_t N>
139
156
void matrix_verify_div (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
140
157
const float ref) {
141
- buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
158
+ buffer<T , 2 > bufA (A.get_data (), range<2 >(M, N));
142
159
143
160
q.submit ([&](handler &cgh) {
144
- auto accA = bufA.get_access <access::mode::read_write>(cgh);
145
-
146
- cgh.parallel_for <class div_matrix >(
147
- r, [accA](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
161
+ sycl::accessor accA{bufA, cgh, sycl::read_write};
162
+ cgh.parallel_for (
163
+ r, [accA](nd_item<2 > spmd_item)[[intel::reqd_sub_group_size (SG_SZ)]] {
148
164
const auto global_idx = spmd_item.get_global_id (0 );
149
165
const auto global_idy = spmd_item.get_global_id (1 );
150
166
const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
151
167
const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
152
168
153
169
sub_group sg = spmd_item.get_sub_group ();
154
170
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
155
-
156
- joint_matrix_fill (sg, sub_a, bfloat16 (4.0 ));
157
-
171
+ if constexpr (std::is_same_v<T, bfloat16>)
172
+ joint_matrix_fill (sg, sub_a, bfloat16 (4.0 ));
173
+ else
174
+ joint_matrix_fill (sg, sub_a, 4 );
158
175
auto wi_slice_a =
159
176
sycl::ext::intel::experimental::matrix::get_wi_data (sg, sub_a);
160
177
for (int i = 0 ; i < wi_slice_a.length (); i++) {
161
- wi_slice_a[i] = wi_slice_a[i] / bfloat16 (2.0 );
178
+ if constexpr (std::is_same_v<T, bfloat16>)
179
+ wi_slice_a[i] = wi_slice_a[i] / bfloat16 (2.0 );
180
+ else
181
+ wi_slice_a[i] = wi_slice_a[i] / 2.0 ;
162
182
}
163
183
ext::intel::experimental::matrix::joint_matrix_store (
164
184
sg, sub_a,
165
185
accA.template get_multi_ptr <access::decorated::no>() +
166
186
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
167
187
N);
168
188
}); // parallel for
169
- }).wait ();
189
+ })
190
+ .wait ();
170
191
assert_ops_ref<T, M, N>(bufA.get_host_access (read_only), ref);
171
192
}
172
193
173
194
template <typename T, size_t M, size_t N>
174
195
void matrix_verify_logic (queue q, big_matrix<T, M, N> &A, nd_range<2 > &r,
175
196
const float ref) {
176
- buffer<bfloat16 , 2 > bufA (A.get_data (), range<2 >(M, N));
197
+ buffer<T , 2 > bufA (A.get_data (), range<2 >(M, N));
177
198
178
199
q.submit ([&](handler &cgh) {
179
- auto accA = bufA. get_access <access::mode:: read_write>(cgh) ;
180
- cgh.parallel_for < class logic_matrix > (
181
- r, [accA](nd_item<2 > spmd_item) [[intel::reqd_sub_group_size (SG_SZ)]] {
200
+ sycl::accessor accA{bufA, cgh, sycl:: read_write} ;
201
+ cgh.parallel_for (
202
+ r, [accA](nd_item<2 > spmd_item)[[intel::reqd_sub_group_size (SG_SZ)]] {
182
203
const auto global_idx = spmd_item.get_global_id (0 );
183
204
const auto global_idy = spmd_item.get_global_id (1 );
184
205
const auto sg_startx = global_idx - spmd_item.get_local_id (0 );
185
206
const auto sg_starty = global_idy - spmd_item.get_local_id (1 );
186
207
187
208
sub_group sg = spmd_item.get_sub_group ();
188
209
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
189
-
190
- joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
191
-
210
+ if constexpr (std::is_same_v<T, bfloat16>)
211
+ joint_matrix_fill (sg, sub_a, bfloat16 (5.0 ));
212
+ else
213
+ joint_matrix_fill (sg, sub_a, 5 );
192
214
auto wi_slice_a =
193
215
sycl::ext::intel::experimental::matrix::get_wi_data (sg, sub_a);
194
216
for (int i = 0 ; i < wi_slice_a.length (); i++) {
195
217
if (wi_slice_a[i]) {
196
- if (wi_slice_a[i] > bfloat16 (2.0 ) ||
197
- wi_slice_a[i] >= bfloat16 (2.0 ) ||
198
- wi_slice_a[i] < bfloat16 (2.0 ) ||
199
- wi_slice_a[i] <= bfloat16 (2.0 )) {
200
- T val = (wi_slice_a[i] != bfloat16 (2.0 )) ? wi_slice_a[i]
201
- : bfloat16 (2.0 );
202
- val = bfloat16 (make_fp32 (val) - static_cast <float >(1 ));
203
- val = bfloat16 (make_fp32 (val) + static_cast <float >(1 ));
204
- if (wi_slice_a[i] == bfloat16 (2.0 )) {
205
- val = bfloat16 (make_fp32 (val) - static_cast <float >(2 ));
206
- val = bfloat16 (make_fp32 (val) * static_cast <float >(3 ));
207
- val = bfloat16 (make_fp32 (val) / static_cast <float >(2 ));
208
-
209
- } else {
210
- val = bfloat16 (make_fp32 (val) + static_cast <float >(2 ));
218
+ if constexpr (std::is_same_v<T, bfloat16>) {
219
+ if (wi_slice_a[i] > bfloat16 (2.0 ) ||
220
+ wi_slice_a[i] >= bfloat16 (2.0 ) ||
221
+ wi_slice_a[i] < bfloat16 (2.0 ) ||
222
+ wi_slice_a[i] <= bfloat16 (2.0 )) {
223
+ T val = (wi_slice_a[i] != bfloat16 (2.0 )) ? wi_slice_a[i]
224
+ : bfloat16 (2.0 );
225
+ val = bfloat16 (make_fp32 (val) - static_cast <float >(1 ));
226
+ val = bfloat16 (make_fp32 (val) + static_cast <float >(1 ));
227
+ if (wi_slice_a[i] == bfloat16 (2.0 )) {
228
+ val = bfloat16 (make_fp32 (val) - static_cast <float >(2 ));
229
+ val = bfloat16 (make_fp32 (val) * static_cast <float >(3 ));
230
+ val = bfloat16 (make_fp32 (val) / static_cast <float >(2 ));
231
+
232
+ } else {
233
+ val = bfloat16 (make_fp32 (val) + static_cast <float >(2 ));
234
+ }
235
+ wi_slice_a[i] = val;
236
+ }
237
+ } else {
238
+ if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 ||
239
+ wi_slice_a[i] < 2.0 || wi_slice_a[i] <= 2.0 ) {
240
+ T val = (wi_slice_a[i] != 2.0 ) ? wi_slice_a[i]
241
+ : static_cast <T>(2.0 );
242
+ val = val - 1 ;
243
+ val = val + 1 ;
244
+ if (wi_slice_a[i] == 2.0 ) {
245
+ val = val - 2 ;
246
+ val = val * 3 ;
247
+ val = val / 2 ;
248
+
249
+ } else {
250
+ val = val + 2 ;
251
+ }
252
+ wi_slice_a[i] = val;
211
253
}
212
- wi_slice_a[i] = val;
213
254
}
214
255
}
215
256
}
@@ -219,7 +260,8 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
219
260
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
220
261
N);
221
262
}); // parallel for
222
- }).wait ();
263
+ })
264
+ .wait ();
223
265
assert_ops_ref<T, M, N>(bufA.get_host_access (read_only), ref);
224
266
}
225
267
@@ -236,21 +278,27 @@ void matrix_ops_ref(float *D, int M, int N) {
236
278
}
237
279
}
238
280
239
- int main () {
281
+ template < typename T, typename Tref> int test_ewops () {
240
282
241
- big_matrix<float , MATRIX_M, MATRIX_N> MD ((float *)&D);
242
- big_matrix<bfloat16 , MATRIX_M, MATRIX_N> MA ((bfloat16 *)&A);
283
+ big_matrix<Tref , MATRIX_M, MATRIX_N> MD ((Tref *)&D);
284
+ big_matrix<T , MATRIX_M, MATRIX_N> MA ((T *)&A);
243
285
244
286
size_t NDRangeM = MATRIX_M / TM;
245
287
size_t NDRangeN = MATRIX_N / TN;
246
288
queue q;
247
289
nd_range<2 > r ({NDRangeM, NDRangeN * SG_SZ}, {1 , 1 * SG_SZ});
248
290
249
- matrix_verify_add<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 7.0 );
250
- matrix_verify_sub<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 3.0 );
251
- matrix_verify_mul<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 15.0 );
252
- matrix_verify_div<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 2.0 );
253
- matrix_verify_logic<bfloat16 , MATRIX_M, MATRIX_N>(q, MA, r, 7.0 );
291
+ matrix_verify_add<T , MATRIX_M, MATRIX_N>(q, MA, r, 7.0 );
292
+ matrix_verify_sub<T , MATRIX_M, MATRIX_N>(q, MA, r, 3.0 );
293
+ matrix_verify_mul<T , MATRIX_M, MATRIX_N>(q, MA, r, 15.0 );
294
+ matrix_verify_div<T , MATRIX_M, MATRIX_N>(q, MA, r, 2.0 );
295
+ matrix_verify_logic<T , MATRIX_M, MATRIX_N>(q, MA, r, 7.0 );
254
296
255
297
return 0 ;
256
298
}
299
+
300
+ int main () {
301
+ test_ewops<bfloat16, float >();
302
+ test_ewops<float , float >();
303
+ return 0 ;
304
+ }
0 commit comments