Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

Commit 7bb961a

Browse files
authored
[SYCL][Matrix] Add missing explicit SG size statement (#764)
1 parent 1349fce commit 7bb961a

File tree

2 files changed

+121
-117
lines changed

2 files changed

+121
-117
lines changed

SYCL/Matrix/element_wise_all_ops_half.cpp

Lines changed: 119 additions & 114 deletions
Original file line number | Diff line number | Diff line change
@@ -59,26 +59,27 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
5959
q.submit([&](handler &cgh) {
6060
auto accA = bufA.get_access<access::mode::read_write>(cgh);
6161

62-
cgh.parallel_for<class add_matrix>(r, [accA](nd_item<2> spmd_item) {
63-
const auto global_idx = spmd_item.get_global_id(0);
64-
const auto global_idy = spmd_item.get_global_id(1);
65-
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
66-
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
67-
68-
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
69-
joint_matrix<T, TM, TK> sub_a(sg);
70-
71-
joint_matrix_fill(sg, sub_a, 5.0);
72-
73-
auto wi_slice_a = sub_a.get_wi_data();
74-
for (int i = 0; i < wi_slice_a.length(); i++) {
75-
wi_slice_a[i] = wi_slice_a[i] + 2;
76-
}
77-
joint_matrix_store(sg, sub_a,
78-
accA.get_pointer() + (sg_startx * TM) * N +
79-
sg_starty / SG_SZ * TN,
80-
N, matrix_layout::row_major);
81-
}); // parallel for
62+
cgh.parallel_for<class add_matrix>(
63+
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
64+
const auto global_idx = spmd_item.get_global_id(0);
65+
const auto global_idy = spmd_item.get_global_id(1);
66+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
67+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
68+
69+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
70+
joint_matrix<T, TM, TK> sub_a(sg);
71+
72+
joint_matrix_fill(sg, sub_a, 5.0);
73+
74+
auto wi_slice_a = sub_a.get_wi_data();
75+
for (int i = 0; i < wi_slice_a.length(); i++) {
76+
wi_slice_a[i] = wi_slice_a[i] + 2;
77+
}
78+
joint_matrix_store(sg, sub_a,
79+
accA.get_pointer() + (sg_startx * TM) * N +
80+
sg_starty / SG_SZ * TN,
81+
N, matrix_layout::row_major);
82+
}); // parallel for
8283
}).wait();
8384
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
8485
}
@@ -91,26 +92,27 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
9192
q.submit([&](handler &cgh) {
9293
auto accA = bufA.get_access<access::mode::read_write>(cgh);
9394

94-
cgh.parallel_for<class sub_matrix>(r, [accA](nd_item<2> spmd_item) {
95-
const auto global_idx = spmd_item.get_global_id(0);
96-
const auto global_idy = spmd_item.get_global_id(1);
97-
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
98-
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
99-
100-
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
101-
joint_matrix<T, TM, TK> sub_a(sg);
102-
103-
joint_matrix_fill(sg, sub_a, 5.0);
104-
105-
auto wi_slice_a = sub_a.get_wi_data();
106-
for (int i = 0; i < wi_slice_a.length(); i++) {
107-
wi_slice_a[i] = wi_slice_a[i] - 2;
108-
}
109-
joint_matrix_store(sg, sub_a,
110-
accA.get_pointer() + (sg_startx * TM) * N +
111-
sg_starty / SG_SZ * TN,
112-
N, matrix_layout::row_major);
113-
}); // parallel for
95+
cgh.parallel_for<class sub_matrix>(
96+
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
97+
const auto global_idx = spmd_item.get_global_id(0);
98+
const auto global_idy = spmd_item.get_global_id(1);
99+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
100+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
101+
102+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
103+
joint_matrix<T, TM, TK> sub_a(sg);
104+
105+
joint_matrix_fill(sg, sub_a, 5.0);
106+
107+
auto wi_slice_a = sub_a.get_wi_data();
108+
for (int i = 0; i < wi_slice_a.length(); i++) {
109+
wi_slice_a[i] = wi_slice_a[i] - 2;
110+
}
111+
joint_matrix_store(sg, sub_a,
112+
accA.get_pointer() + (sg_startx * TM) * N +
113+
sg_starty / SG_SZ * TN,
114+
N, matrix_layout::row_major);
115+
}); // parallel for
114116
}).wait();
115117
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
116118
}
@@ -123,26 +125,27 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
123125
q.submit([&](handler &cgh) {
124126
auto accA = bufA.get_access<access::mode::read_write>(cgh);
125127

126-
cgh.parallel_for<class mul_matrix>(r, [accA](nd_item<2> spmd_item) {
127-
const auto global_idx = spmd_item.get_global_id(0);
128-
const auto global_idy = spmd_item.get_global_id(1);
129-
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
130-
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
131-
132-
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
133-
joint_matrix<T, TM, TK> sub_a(sg);
134-
135-
joint_matrix_fill(sg, sub_a, 5.0);
136-
137-
auto wi_slice_a = sub_a.get_wi_data();
138-
for (int i = 0; i < wi_slice_a.length(); i++) {
139-
wi_slice_a[i] = wi_slice_a[i] * 3.0;
140-
}
141-
joint_matrix_store(sg, sub_a,
142-
accA.get_pointer() + (sg_startx * TM) * N +
143-
sg_starty / SG_SZ * TN,
144-
N, matrix_layout::row_major);
145-
}); // parallel for
128+
cgh.parallel_for<class mul_matrix>(
129+
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
130+
const auto global_idx = spmd_item.get_global_id(0);
131+
const auto global_idy = spmd_item.get_global_id(1);
132+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
133+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
134+
135+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
136+
joint_matrix<T, TM, TK> sub_a(sg);
137+
138+
joint_matrix_fill(sg, sub_a, 5.0);
139+
140+
auto wi_slice_a = sub_a.get_wi_data();
141+
for (int i = 0; i < wi_slice_a.length(); i++) {
142+
wi_slice_a[i] = wi_slice_a[i] * 3.0;
143+
}
144+
joint_matrix_store(sg, sub_a,
145+
accA.get_pointer() + (sg_startx * TM) * N +
146+
sg_starty / SG_SZ * TN,
147+
N, matrix_layout::row_major);
148+
}); // parallel for
146149
}).wait();
147150
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
148151
}
@@ -155,26 +158,27 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
155158
q.submit([&](handler &cgh) {
156159
auto accA = bufA.get_access<access::mode::read_write>(cgh);
157160

158-
cgh.parallel_for<class div_matrix>(r, [accA](nd_item<2> spmd_item) {
159-
const auto global_idx = spmd_item.get_global_id(0);
160-
const auto global_idy = spmd_item.get_global_id(1);
161-
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
162-
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
163-
164-
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
165-
joint_matrix<T, TM, TK> sub_a(sg);
166-
167-
joint_matrix_fill(sg, sub_a, 4.0);
168-
169-
auto wi_slice_a = sub_a.get_wi_data();
170-
for (int i = 0; i < wi_slice_a.length(); i++) {
171-
wi_slice_a[i] = wi_slice_a[i] / 2.0;
172-
}
173-
joint_matrix_store(sg, sub_a,
174-
accA.get_pointer() + (sg_startx * TM) * N +
175-
sg_starty / SG_SZ * TN,
176-
N, matrix_layout::row_major);
177-
}); // parallel for
161+
cgh.parallel_for<class div_matrix>(
162+
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
163+
const auto global_idx = spmd_item.get_global_id(0);
164+
const auto global_idy = spmd_item.get_global_id(1);
165+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
166+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
167+
168+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
169+
joint_matrix<T, TM, TK> sub_a(sg);
170+
171+
joint_matrix_fill(sg, sub_a, 4.0);
172+
173+
auto wi_slice_a = sub_a.get_wi_data();
174+
for (int i = 0; i < wi_slice_a.length(); i++) {
175+
wi_slice_a[i] = wi_slice_a[i] / 2.0;
176+
}
177+
joint_matrix_store(sg, sub_a,
178+
accA.get_pointer() + (sg_startx * TM) * N +
179+
sg_starty / SG_SZ * TN,
180+
N, matrix_layout::row_major);
181+
}); // parallel for
178182
}).wait();
179183
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
180184
}
@@ -187,42 +191,43 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
187191
q.submit([&](handler &cgh) {
188192
auto accA = bufA.get_access<access::mode::read_write>(cgh);
189193

190-
cgh.parallel_for<class logic_matrix>(r, [accA](nd_item<2> spmd_item) {
191-
const auto global_idx = spmd_item.get_global_id(0);
192-
const auto global_idy = spmd_item.get_global_id(1);
193-
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
194-
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
195-
196-
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
197-
joint_matrix<T, TM, TK> sub_a(sg);
198-
199-
joint_matrix_fill(sg, sub_a, 5.0);
200-
201-
auto wi_slice_a = sub_a.get_wi_data();
202-
for (int i = 0; i < wi_slice_a.length(); i++) {
203-
if (wi_slice_a[i]) {
204-
if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 ||
205-
wi_slice_a[i] < 2.0 || wi_slice_a[i] <= 2.0) {
206-
T val = (wi_slice_a[i] != 2.0) ? wi_slice_a[i]
207-
: static_cast<half>(2.0);
208-
val--;
209-
val++;
210-
if (wi_slice_a[i] == 2.0) {
211-
val -= 2;
212-
val *= 3.0;
213-
val /= 2.0;
214-
} else {
215-
val += 2;
194+
cgh.parallel_for<class logic_matrix>(
195+
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
196+
const auto global_idx = spmd_item.get_global_id(0);
197+
const auto global_idy = spmd_item.get_global_id(1);
198+
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
199+
const auto sg_starty = global_idy - spmd_item.get_local_id(1);
200+
201+
ext::oneapi::sub_group sg = spmd_item.get_sub_group();
202+
joint_matrix<T, TM, TK> sub_a(sg);
203+
204+
joint_matrix_fill(sg, sub_a, 5.0);
205+
206+
auto wi_slice_a = sub_a.get_wi_data();
207+
for (int i = 0; i < wi_slice_a.length(); i++) {
208+
if (wi_slice_a[i]) {
209+
if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 ||
210+
wi_slice_a[i] < 2.0 || wi_slice_a[i] <= 2.0) {
211+
T val = (wi_slice_a[i] != 2.0) ? wi_slice_a[i]
212+
: static_cast<half>(2.0);
213+
val--;
214+
val++;
215+
if (wi_slice_a[i] == 2.0) {
216+
val -= 2;
217+
val *= 3.0;
218+
val /= 2.0;
219+
} else {
220+
val += 2;
221+
}
222+
wi_slice_a[i] = val;
223+
}
216224
}
217-
wi_slice_a[i] = val;
218225
}
219-
}
220-
}
221-
joint_matrix_store(sg, sub_a,
222-
accA.get_pointer() + (sg_startx * TM) * N +
223-
sg_starty / SG_SZ * TN,
224-
N, matrix_layout::row_major);
225-
}); // parallel for
226+
joint_matrix_store(sg, sub_a,
227+
accA.get_pointer() + (sg_startx * TM) * N +
228+
sg_starty / SG_SZ * TN,
229+
N, matrix_layout::row_major);
230+
}); // parallel for
226231
}).wait();
227232
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
228233
}

SYCL/Matrix/element_wise_ops.cpp

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -59,9 +59,8 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
5959

6060
cgh.parallel_for<class imatrix>(
6161
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
62-
[accA, accB, accC, M, N, K](nd_item<2> spmd_item)
63-
64-
{
62+
[accA, accB, accC, M, N,
63+
K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
6564
// The submatrix API has to be accessed by all the workitems in a
6665
// subgroup these functions will be called once by the subgroup no
6766
// code divergence between the workitems

0 commit comments

Comments
 (0)