Skip to content

Commit ffeecb7

Browse files
authored
[SYCL][Matrix] Add tests for element wise ops on float type (#9679)
1 parent 05d3be6 commit ffeecb7

File tree

3 files changed

+125
-76
lines changed

3 files changed

+125
-76
lines changed

sycl/test-e2e/Matrix/XMX8/element_wise_all_ops_bf16.cpp renamed to sycl/test-e2e/Matrix/XMX8/element_wise_all_ops.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//==----------- element_wise_all_ops_bf16.cpp - DPC++ joint_matrix---------==//
1+
//==------------ element_wise_all_ops.cpp - DPC++ joint_matrix-------------==//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -16,9 +16,10 @@
1616

1717
using namespace sycl;
1818
using namespace sycl::ext::intel;
19+
using namespace sycl::ext::oneapi;
1920
using namespace sycl::ext::oneapi::experimental::matrix;
2021
using bfloat16 = sycl::ext::oneapi::bfloat16;
2122

2223
#define SG_SZ 8
2324

24-
#include "../element_wise_all_ops_bf16_impl.hpp"
25+
#include "../element_wise_all_ops_impl.hpp"

sycl/test-e2e/Matrix/element_wise_all_ops_bf16.cpp renamed to sycl/test-e2e/Matrix/element_wise_all_ops.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//==----------- element_wise_all_ops_bf16.cpp - DPC++ joint_matrix---------==//
1+
//==------------ element_wise_all_ops.cpp - DPC++ joint_matrix-------------==//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -22,4 +22,4 @@ using bfloat16 = sycl::ext::oneapi::bfloat16;
2222

2323
#define SG_SZ 16
2424

25-
#include "element_wise_all_ops_bf16_impl.hpp"
25+
#include "element_wise_all_ops_impl.hpp"

sycl/test-e2e/Matrix/element_wise_all_ops_bf16_impl.hpp renamed to sycl/test-e2e/Matrix/element_wise_all_ops_impl.hpp

Lines changed: 120 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -25,35 +25,42 @@ void assert_ops_ref(host_accessor<T, 2, access::mode::read> C,
2525
const float ref) {
2626
for (size_t i = 0; i < M; i++)
2727
for (size_t j = 0; j < N; j++) {
28-
auto diff = make_fp32(C[i][j]) - ref;
28+
float diff;
29+
if constexpr (std::is_same_v<T, bfloat16>)
30+
diff = make_fp32(C[i][j]) - ref;
31+
else
32+
diff = C[i][j] - ref;
2933
assert(std::fabs(static_cast<float>(diff)) <
3034
std::numeric_limits<float>::epsilon());
3135
}
3236
}
3337
template <typename T, size_t M, size_t N>
3438
void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
3539
const float ref) {
36-
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, N));
40+
buffer<T, 2> bufA(A.get_data(), range<2>(M, N));
3741

3842
q.submit([&](handler &cgh) {
39-
auto accA = bufA.get_access<access::mode::read_write>(cgh);
40-
41-
cgh.parallel_for<class add_matrix>(
42-
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
43+
sycl::accessor accA{bufA, cgh, sycl::read_write};
44+
cgh.parallel_for(
45+
r, [accA](nd_item<2> spmd_item)[[intel::reqd_sub_group_size(SG_SZ)]] {
4346
const auto global_idx = spmd_item.get_global_id(0);
4447
const auto global_idy = spmd_item.get_global_id(1);
4548
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
4649
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
4750

4851
sub_group sg = spmd_item.get_sub_group();
4952
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
50-
51-
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
52-
53+
if constexpr (std::is_same_v<T, bfloat16>)
54+
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
55+
else
56+
joint_matrix_fill(sg, sub_a, 5);
5357
auto wi_slice_a =
5458
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
5559
for (int i = 0; i < wi_slice_a.length(); i++) {
56-
wi_slice_a[i] = wi_slice_a[i] + bfloat16(2);
60+
if constexpr (std::is_same_v<T, bfloat16>)
61+
wi_slice_a[i] = wi_slice_a[i] + bfloat16(2);
62+
else
63+
wi_slice_a[i] = wi_slice_a[i] + 2;
5764
}
5865

5966
ext::intel::experimental::matrix::joint_matrix_store(
@@ -62,154 +69,188 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
6269
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
6370
N);
6471
}); // parallel for
65-
}).wait();
72+
})
73+
.wait();
6674
assert_ops_ref<T, M, N>(bufA.get_host_access(read_only), ref);
6775
}
6876

6977
template <typename T, size_t M, size_t N>
7078
void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
7179
const float ref) {
72-
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, N));
80+
buffer<T, 2> bufA(A.get_data(), range<2>(M, N));
7381

7482
q.submit([&](handler &cgh) {
75-
auto accA = bufA.get_access<access::mode::read_write>(cgh);
76-
77-
cgh.parallel_for<class sub_matrix>(
78-
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
83+
sycl::accessor accA{bufA, cgh, sycl::read_write};
84+
cgh.parallel_for(
85+
r, [accA](nd_item<2> spmd_item)[[intel::reqd_sub_group_size(SG_SZ)]] {
7986
const auto global_idx = spmd_item.get_global_id(0);
8087
const auto global_idy = spmd_item.get_global_id(1);
8188
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
8289
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
8390

8491
sub_group sg = spmd_item.get_sub_group();
8592
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
86-
87-
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
88-
93+
if constexpr (std::is_same_v<T, bfloat16>)
94+
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
95+
else
96+
joint_matrix_fill(sg, sub_a, 5);
8997
auto wi_slice_a =
9098
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
9199
for (int i = 0; i < wi_slice_a.length(); i++) {
92-
wi_slice_a[i] = wi_slice_a[i] - bfloat16(2);
100+
if constexpr (std::is_same_v<T, bfloat16>)
101+
wi_slice_a[i] = wi_slice_a[i] - bfloat16(2);
102+
else
103+
wi_slice_a[i] = wi_slice_a[i] - 2;
93104
}
94105
ext::intel::experimental::matrix::joint_matrix_store(
95106
sg, sub_a,
96107
accA.template get_multi_ptr<access::decorated::no>() +
97108
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
98109
N);
99110
}); // parallel for
100-
}).wait();
111+
})
112+
.wait();
101113
assert_ops_ref<T, M, N>(bufA.get_host_access(read_only), ref);
102114
}
103115

104116
template <typename T, size_t M, size_t N>
105117
void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
106118
const float ref) {
107-
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, N));
119+
buffer<T, 2> bufA(A.get_data(), range<2>(M, N));
108120

109121
q.submit([&](handler &cgh) {
110-
auto accA = bufA.get_access<access::mode::read_write>(cgh);
111-
112-
cgh.parallel_for<class mul_matrix>(
113-
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
122+
sycl::accessor accA{bufA, cgh, sycl::read_write};
123+
cgh.parallel_for(
124+
r, [accA](nd_item<2> spmd_item)[[intel::reqd_sub_group_size(SG_SZ)]] {
114125
const auto global_idx = spmd_item.get_global_id(0);
115126
const auto global_idy = spmd_item.get_global_id(1);
116127
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
117128
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
118129

119130
sub_group sg = spmd_item.get_sub_group();
120131
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
121-
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
122-
132+
if constexpr (std::is_same_v<T, bfloat16>)
133+
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
134+
else
135+
joint_matrix_fill(sg, sub_a, 5);
123136
auto wi_slice_a =
124137
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
125138
for (int i = 0; i < wi_slice_a.length(); i++) {
126-
wi_slice_a[i] = wi_slice_a[i] * bfloat16(3.0);
139+
if constexpr (std::is_same_v<T, bfloat16>)
140+
wi_slice_a[i] = wi_slice_a[i] * bfloat16(3.0);
141+
else
142+
wi_slice_a[i] = wi_slice_a[i] * 3.0;
127143
}
128144
ext::intel::experimental::matrix::joint_matrix_store(
129145
sg, sub_a,
130146
accA.template get_multi_ptr<access::decorated::no>() +
131147
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
132148
N);
133149
}); // parallel for
134-
}).wait();
150+
})
151+
.wait();
135152
assert_ops_ref<T, M, N>(bufA.get_host_access(read_only), ref);
136153
}
137154

138155
template <typename T, size_t M, size_t N>
139156
void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
140157
const float ref) {
141-
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, N));
158+
buffer<T, 2> bufA(A.get_data(), range<2>(M, N));
142159

143160
q.submit([&](handler &cgh) {
144-
auto accA = bufA.get_access<access::mode::read_write>(cgh);
145-
146-
cgh.parallel_for<class div_matrix>(
147-
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
161+
sycl::accessor accA{bufA, cgh, sycl::read_write};
162+
cgh.parallel_for(
163+
r, [accA](nd_item<2> spmd_item)[[intel::reqd_sub_group_size(SG_SZ)]] {
148164
const auto global_idx = spmd_item.get_global_id(0);
149165
const auto global_idy = spmd_item.get_global_id(1);
150166
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
151167
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
152168

153169
sub_group sg = spmd_item.get_sub_group();
154170
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
155-
156-
joint_matrix_fill(sg, sub_a, bfloat16(4.0));
157-
171+
if constexpr (std::is_same_v<T, bfloat16>)
172+
joint_matrix_fill(sg, sub_a, bfloat16(4.0));
173+
else
174+
joint_matrix_fill(sg, sub_a, 4);
158175
auto wi_slice_a =
159176
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
160177
for (int i = 0; i < wi_slice_a.length(); i++) {
161-
wi_slice_a[i] = wi_slice_a[i] / bfloat16(2.0);
178+
if constexpr (std::is_same_v<T, bfloat16>)
179+
wi_slice_a[i] = wi_slice_a[i] / bfloat16(2.0);
180+
else
181+
wi_slice_a[i] = wi_slice_a[i] / 2.0;
162182
}
163183
ext::intel::experimental::matrix::joint_matrix_store(
164184
sg, sub_a,
165185
accA.template get_multi_ptr<access::decorated::no>() +
166186
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
167187
N);
168188
}); // parallel for
169-
}).wait();
189+
})
190+
.wait();
170191
assert_ops_ref<T, M, N>(bufA.get_host_access(read_only), ref);
171192
}
172193

173194
template <typename T, size_t M, size_t N>
174195
void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
175196
const float ref) {
176-
buffer<bfloat16, 2> bufA(A.get_data(), range<2>(M, N));
197+
buffer<T, 2> bufA(A.get_data(), range<2>(M, N));
177198

178199
q.submit([&](handler &cgh) {
179-
auto accA = bufA.get_access<access::mode::read_write>(cgh);
180-
cgh.parallel_for<class logic_matrix>(
181-
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
200+
sycl::accessor accA{bufA, cgh, sycl::read_write};
201+
cgh.parallel_for(
202+
r, [accA](nd_item<2> spmd_item)[[intel::reqd_sub_group_size(SG_SZ)]] {
182203
const auto global_idx = spmd_item.get_global_id(0);
183204
const auto global_idy = spmd_item.get_global_id(1);
184205
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
185206
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
186207

187208
sub_group sg = spmd_item.get_sub_group();
188209
joint_matrix<sub_group, T, use::a, TM, TK, layout::row_major> sub_a;
189-
190-
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
191-
210+
if constexpr (std::is_same_v<T, bfloat16>)
211+
joint_matrix_fill(sg, sub_a, bfloat16(5.0));
212+
else
213+
joint_matrix_fill(sg, sub_a, 5);
192214
auto wi_slice_a =
193215
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_a);
194216
for (int i = 0; i < wi_slice_a.length(); i++) {
195217
if (wi_slice_a[i]) {
196-
if (wi_slice_a[i] > bfloat16(2.0) ||
197-
wi_slice_a[i] >= bfloat16(2.0) ||
198-
wi_slice_a[i] < bfloat16(2.0) ||
199-
wi_slice_a[i] <= bfloat16(2.0)) {
200-
T val = (wi_slice_a[i] != bfloat16(2.0)) ? wi_slice_a[i]
201-
: bfloat16(2.0);
202-
val = bfloat16(make_fp32(val) - static_cast<float>(1));
203-
val = bfloat16(make_fp32(val) + static_cast<float>(1));
204-
if (wi_slice_a[i] == bfloat16(2.0)) {
205-
val = bfloat16(make_fp32(val) - static_cast<float>(2));
206-
val = bfloat16(make_fp32(val) * static_cast<float>(3));
207-
val = bfloat16(make_fp32(val) / static_cast<float>(2));
208-
209-
} else {
210-
val = bfloat16(make_fp32(val) + static_cast<float>(2));
218+
if constexpr (std::is_same_v<T, bfloat16>) {
219+
if (wi_slice_a[i] > bfloat16(2.0) ||
220+
wi_slice_a[i] >= bfloat16(2.0) ||
221+
wi_slice_a[i] < bfloat16(2.0) ||
222+
wi_slice_a[i] <= bfloat16(2.0)) {
223+
T val = (wi_slice_a[i] != bfloat16(2.0)) ? wi_slice_a[i]
224+
: bfloat16(2.0);
225+
val = bfloat16(make_fp32(val) - static_cast<float>(1));
226+
val = bfloat16(make_fp32(val) + static_cast<float>(1));
227+
if (wi_slice_a[i] == bfloat16(2.0)) {
228+
val = bfloat16(make_fp32(val) - static_cast<float>(2));
229+
val = bfloat16(make_fp32(val) * static_cast<float>(3));
230+
val = bfloat16(make_fp32(val) / static_cast<float>(2));
231+
232+
} else {
233+
val = bfloat16(make_fp32(val) + static_cast<float>(2));
234+
}
235+
wi_slice_a[i] = val;
236+
}
237+
} else {
238+
if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 ||
239+
wi_slice_a[i] < 2.0 || wi_slice_a[i] <= 2.0) {
240+
T val = (wi_slice_a[i] != 2.0) ? wi_slice_a[i]
241+
: static_cast<T>(2.0);
242+
val = val - 1;
243+
val = val + 1;
244+
if (wi_slice_a[i] == 2.0) {
245+
val = val - 2;
246+
val = val * 3;
247+
val = val / 2;
248+
249+
} else {
250+
val = val + 2;
251+
}
252+
wi_slice_a[i] = val;
211253
}
212-
wi_slice_a[i] = val;
213254
}
214255
}
215256
}
@@ -219,7 +260,8 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
219260
(sg_startx * TM) * N + sg_starty / SG_SZ * TN,
220261
N);
221262
}); // parallel for
222-
}).wait();
263+
})
264+
.wait();
223265
assert_ops_ref<T, M, N>(bufA.get_host_access(read_only), ref);
224266
}
225267

@@ -236,21 +278,27 @@ void matrix_ops_ref(float *D, int M, int N) {
236278
}
237279
}
238280

239-
int main() {
281+
template <typename T, typename Tref> int test_ewops() {
240282

241-
big_matrix<float, MATRIX_M, MATRIX_N> MD((float *)&D);
242-
big_matrix<bfloat16, MATRIX_M, MATRIX_N> MA((bfloat16 *)&A);
283+
big_matrix<Tref, MATRIX_M, MATRIX_N> MD((Tref *)&D);
284+
big_matrix<T, MATRIX_M, MATRIX_N> MA((T *)&A);
243285

244286
size_t NDRangeM = MATRIX_M / TM;
245287
size_t NDRangeN = MATRIX_N / TN;
246288
queue q;
247289
nd_range<2> r({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});
248290

249-
matrix_verify_add<bfloat16, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
250-
matrix_verify_sub<bfloat16, MATRIX_M, MATRIX_N>(q, MA, r, 3.0);
251-
matrix_verify_mul<bfloat16, MATRIX_M, MATRIX_N>(q, MA, r, 15.0);
252-
matrix_verify_div<bfloat16, MATRIX_M, MATRIX_N>(q, MA, r, 2.0);
253-
matrix_verify_logic<bfloat16, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
291+
matrix_verify_add<T, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
292+
matrix_verify_sub<T, MATRIX_M, MATRIX_N>(q, MA, r, 3.0);
293+
matrix_verify_mul<T, MATRIX_M, MATRIX_N>(q, MA, r, 15.0);
294+
matrix_verify_div<T, MATRIX_M, MATRIX_N>(q, MA, r, 2.0);
295+
matrix_verify_logic<T, MATRIX_M, MATRIX_N>(q, MA, r, 7.0);
254296

255297
return 0;
256298
}
299+
300+
int main() {
301+
test_ewops<bfloat16, float>();
302+
test_ewops<float, float>();
303+
return 0;
304+
}

0 commit comments

Comments
 (0)