Skip to content

Commit b35a931

Browse files
dkhaldibb-sycl
authored andcommitted
[SYCL][Matrix] Correct a test case that redefines a class name (intel#757)
1 parent 0b24c1c commit b35a931

File tree

1 file changed

+121
-0
lines changed

1 file changed

+121
-0
lines changed

SYCL/Matrix/element_wise_all_ops_half.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
5959
q.submit([&](handler &cgh) {
6060
auto accA = bufA.get_access<access::mode::read_write>(cgh);
6161

62+
<<<<<<< HEAD
6263
cgh.parallel_for<class add_matrix>(
6364
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
6465
const auto global_idx = spmd_item.get_global_id(0);
@@ -80,6 +81,28 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
8081
sg_starty / SG_SZ * TN,
8182
N, matrix_layout::row_major);
8283
}); // parallel for
84+
=======
85+
cgh.parallel_for<class add_matrix>(r, [accA](nd_item<2> spmd_item) {
86+
const auto global_idx = spmd_item.get_global_id(0);
87+
const auto global_idy = spmd_item.get_global_id(1);
88+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
89+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
90+
91+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
92+
joint_matrix<T, TM, TK> sub_a(sg);
93+
94+
joint_matrix_fill(sg, sub_a, 5.0);
95+
96+
auto wi_slice_a = sub_a.get_wi_data();
97+
for (int i = 0; i < wi_slice_a.length(); i++) {
98+
wi_slice_a[i] = wi_slice_a[i] + 2;
99+
}
100+
joint_matrix_store(sg, sub_a,
101+
accA.get_pointer() + (sg_startx * TM) * N +
102+
sg_starty / SG_SZ * TN,
103+
N, matrix_layout::row_major);
104+
}); // parallel for
105+
>>>>>>> 62e420f44 ([SYCL][Matrix] Correct a test case that redefines a class name (#757))
83106
}).wait();
84107
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
85108
}
@@ -92,6 +115,7 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
92115
q.submit([&](handler &cgh) {
93116
auto accA = bufA.get_access<access::mode::read_write>(cgh);
94117

118+
<<<<<<< HEAD
95119
cgh.parallel_for<class sub_matrix>(
96120
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
97121
const auto global_idx = spmd_item.get_global_id(0);
@@ -113,6 +137,28 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
113137
sg_starty / SG_SZ * TN,
114138
N, matrix_layout::row_major);
115139
}); // parallel for
140+
=======
141+
cgh.parallel_for<class sub_matrix>(r, [accA](nd_item<2> spmd_item) {
142+
const auto global_idx = spmd_item.get_global_id(0);
143+
const auto global_idy = spmd_item.get_global_id(1);
144+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
145+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
146+
147+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
148+
joint_matrix<T, TM, TK> sub_a(sg);
149+
150+
joint_matrix_fill(sg, sub_a, 5.0);
151+
152+
auto wi_slice_a = sub_a.get_wi_data();
153+
for (int i = 0; i < wi_slice_a.length(); i++) {
154+
wi_slice_a[i] = wi_slice_a[i] - 2;
155+
}
156+
joint_matrix_store(sg, sub_a,
157+
accA.get_pointer() + (sg_startx * TM) * N +
158+
sg_starty / SG_SZ * TN,
159+
N, matrix_layout::row_major);
160+
}); // parallel for
161+
>>>>>>> 62e420f44 ([SYCL][Matrix] Correct a test case that redefines a class name (#757))
116162
}).wait();
117163
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
118164
}
@@ -125,6 +171,7 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
125171
q.submit([&](handler &cgh) {
126172
auto accA = bufA.get_access<access::mode::read_write>(cgh);
127173

174+
<<<<<<< HEAD
128175
cgh.parallel_for<class mul_matrix>(
129176
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
130177
const auto global_idx = spmd_item.get_global_id(0);
@@ -146,6 +193,28 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
146193
sg_starty / SG_SZ * TN,
147194
N, matrix_layout::row_major);
148195
}); // parallel for
196+
=======
197+
cgh.parallel_for<class mul_matrix>(r, [accA](nd_item<2> spmd_item) {
198+
const auto global_idx = spmd_item.get_global_id(0);
199+
const auto global_idy = spmd_item.get_global_id(1);
200+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
201+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
202+
203+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
204+
joint_matrix<T, TM, TK> sub_a(sg);
205+
206+
joint_matrix_fill(sg, sub_a, 5.0);
207+
208+
auto wi_slice_a = sub_a.get_wi_data();
209+
for (int i = 0; i < wi_slice_a.length(); i++) {
210+
wi_slice_a[i] = wi_slice_a[i] * 3.0;
211+
}
212+
joint_matrix_store(sg, sub_a,
213+
accA.get_pointer() + (sg_startx * TM) * N +
214+
sg_starty / SG_SZ * TN,
215+
N, matrix_layout::row_major);
216+
}); // parallel for
217+
>>>>>>> 62e420f44 ([SYCL][Matrix] Correct a test case that redefines a class name (#757))
149218
}).wait();
150219
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
151220
}
@@ -158,6 +227,7 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
158227
q.submit([&](handler &cgh) {
159228
auto accA = bufA.get_access<access::mode::read_write>(cgh);
160229

230+
<<<<<<< HEAD
161231
cgh.parallel_for<class div_matrix>(
162232
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
163233
const auto global_idx = spmd_item.get_global_id(0);
@@ -179,6 +249,28 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
179249
sg_starty / SG_SZ * TN,
180250
N, matrix_layout::row_major);
181251
}); // parallel for
252+
=======
253+
cgh.parallel_for<class div_matrix>(r, [accA](nd_item<2> spmd_item) {
254+
const auto global_idx = spmd_item.get_global_id(0);
255+
const auto global_idy = spmd_item.get_global_id(1);
256+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
257+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
258+
259+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
260+
joint_matrix<T, TM, TK> sub_a(sg);
261+
262+
joint_matrix_fill(sg, sub_a, 4.0);
263+
264+
auto wi_slice_a = sub_a.get_wi_data();
265+
for (int i = 0; i < wi_slice_a.length(); i++) {
266+
wi_slice_a[i] = wi_slice_a[i] / 2.0;
267+
}
268+
joint_matrix_store(sg, sub_a,
269+
accA.get_pointer() + (sg_startx * TM) * N +
270+
sg_starty / SG_SZ * TN,
271+
N, matrix_layout::row_major);
272+
}); // parallel for
273+
>>>>>>> 62e420f44 ([SYCL][Matrix] Correct a test case that redefines a class name (#757))
182274
}).wait();
183275
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
184276
}
@@ -191,6 +283,7 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
191283
q.submit([&](handler &cgh) {
192284
auto accA = bufA.get_access<access::mode::read_write>(cgh);
193285

286+
<<<<<<< HEAD
194287
cgh.parallel_for<class logic_matrix>(
195288
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
196289
const auto global_idx = spmd_item.get_global_id(0);
@@ -221,6 +314,34 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
221314
}
222315
wi_slice_a[i] = val;
223316
}
317+
=======
318+
cgh.parallel_for<class logic_matrix>(r, [accA](nd_item<2> spmd_item) {
319+
const auto global_idx = spmd_item.get_global_id(0);
320+
const auto global_idy = spmd_item.get_global_id(1);
321+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
322+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
323+
324+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
325+
joint_matrix<T, TM, TK> sub_a(sg);
326+
327+
joint_matrix_fill(sg, sub_a, 5.0);
328+
329+
auto wi_slice_a = sub_a.get_wi_data();
330+
for (int i = 0; i < wi_slice_a.length(); i++) {
331+
if (wi_slice_a[i]) {
332+
if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 ||
333+
wi_slice_a[i] < 2.0 || wi_slice_a[i] <= 2.0) {
334+
T val = (wi_slice_a[i] != 2.0) ? wi_slice_a[i]
335+
: static_cast<half>(2.0);
336+
val--;
337+
val++;
338+
if (wi_slice_a[i] == 2.0) {
339+
val -= 2;
340+
val *= 3.0;
341+
val /= 2.0;
342+
} else {
343+
val += 2;
344+
>>>>>>> 62e420f44 ([SYCL][Matrix] Correct a test case that redefines a class name (#757))
224345
}
225346
}
226347
joint_matrix_store(sg, sub_a,

0 commit comments

Comments
 (0)