This repository was archived by the owner on Mar 28, 2023. It is now read-only.

[SYCL][Matrix] Add missing explicit SG size statement #764

Merged · 3 commits · Jan 25, 2022
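
The change is the same in every kernel in the diff below: the element-wise joint_matrix tests launched their parallel_for kernels without pinning the sub-group size, so this PR attaches [[intel::reqd_sub_group_size(SG_SZ)]] to each kernel lambda. A minimal standalone sketch of that pattern follows; it is not taken from the PR, and the SG_SZ value, kernel name, header spelling, and range are illustrative assumptions only.

#include <sycl/sycl.hpp>  // older DPC++ releases spell this <CL/sycl.hpp>
using namespace sycl;

// Assumed value; the tests define SG_SZ to match the target hardware.
constexpr unsigned SG_SZ = 16;

int main() {
  queue q;
  // 2D range whose second dimension is a multiple of the sub-group size.
  nd_range<2> r({4, 2 * SG_SZ}, {1, 1 * SG_SZ});

  q.parallel_for<class sg_size_demo>(
       r, [=](nd_item<2> it) [[intel::reqd_sub_group_size(SG_SZ)]] {
         // The attribute forces this kernel to run with exactly SG_SZ
         // work-items per sub-group, which the joint_matrix API in the
         // tests relies on for its per-work-item slices.
         auto sg = it.get_sub_group();
         (void)sg;
       })
      .wait();
  return 0;
}

Without the attribute the compiler is free to pick any supported sub-group size, and the SG_SZ-based index arithmetic in the tests (for example, sg_starty / SG_SZ * TN) would no longer match the actual launch configuration.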
233 changes: 119 additions & 114 deletions SYCL/Matrix/element_wise_all_ops_half.cpp
@@ -59,26 +59,27 @@ void matrix_verify_add(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class add_matrix>(r, [accA](nd_item<2> spmd_item) {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();
joint_matrix<T, TM, TK> sub_a(sg);

joint_matrix_fill(sg, sub_a, 5.0);

auto wi_slice_a = sub_a.get_wi_data();
for (int i = 0; i < wi_slice_a.length(); i++) {
wi_slice_a[i] = wi_slice_a[i] + 2;
}
joint_matrix_store(sg, sub_a,
accA.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, matrix_layout::row_major);
}); // parallel for
cgh.parallel_for<class add_matrix>(
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();
joint_matrix<T, TM, TK> sub_a(sg);

joint_matrix_fill(sg, sub_a, 5.0);

auto wi_slice_a = sub_a.get_wi_data();
for (int i = 0; i < wi_slice_a.length(); i++) {
wi_slice_a[i] = wi_slice_a[i] + 2;
}
joint_matrix_store(sg, sub_a,
accA.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, matrix_layout::row_major);
}); // parallel for
}).wait();
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
}
@@ -91,26 +92,27 @@ void matrix_verify_sub(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class sub_matrix>(r, [accA](nd_item<2> spmd_item) {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();
joint_matrix<T, TM, TK> sub_a(sg);

joint_matrix_fill(sg, sub_a, 5.0);

auto wi_slice_a = sub_a.get_wi_data();
for (int i = 0; i < wi_slice_a.length(); i++) {
wi_slice_a[i] = wi_slice_a[i] - 2;
}
joint_matrix_store(sg, sub_a,
accA.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, matrix_layout::row_major);
}); // parallel for
cgh.parallel_for<class sub_matrix>(
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();
joint_matrix<T, TM, TK> sub_a(sg);

joint_matrix_fill(sg, sub_a, 5.0);

auto wi_slice_a = sub_a.get_wi_data();
for (int i = 0; i < wi_slice_a.length(); i++) {
wi_slice_a[i] = wi_slice_a[i] - 2;
}
joint_matrix_store(sg, sub_a,
accA.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, matrix_layout::row_major);
}); // parallel for
}).wait();
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
}
@@ -123,26 +125,27 @@ void matrix_verify_mul(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class mul_matrix>(r, [accA](nd_item<2> spmd_item) {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();
joint_matrix<T, TM, TK> sub_a(sg);

joint_matrix_fill(sg, sub_a, 5.0);

auto wi_slice_a = sub_a.get_wi_data();
for (int i = 0; i < wi_slice_a.length(); i++) {
wi_slice_a[i] = wi_slice_a[i] * 3.0;
}
joint_matrix_store(sg, sub_a,
accA.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, matrix_layout::row_major);
}); // parallel for
cgh.parallel_for<class mul_matrix>(
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();
joint_matrix<T, TM, TK> sub_a(sg);

joint_matrix_fill(sg, sub_a, 5.0);

auto wi_slice_a = sub_a.get_wi_data();
for (int i = 0; i < wi_slice_a.length(); i++) {
wi_slice_a[i] = wi_slice_a[i] * 3.0;
}
joint_matrix_store(sg, sub_a,
accA.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, matrix_layout::row_major);
}); // parallel for
}).wait();
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
}
@@ -155,26 +158,27 @@ void matrix_verify_div(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class div_matrix>(r, [accA](nd_item<2> spmd_item) {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();
joint_matrix<T, TM, TK> sub_a(sg);

joint_matrix_fill(sg, sub_a, 4.0);

auto wi_slice_a = sub_a.get_wi_data();
for (int i = 0; i < wi_slice_a.length(); i++) {
wi_slice_a[i] = wi_slice_a[i] / 2.0;
}
joint_matrix_store(sg, sub_a,
accA.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, matrix_layout::row_major);
}); // parallel for
cgh.parallel_for<class div_matrix>(
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();
joint_matrix<T, TM, TK> sub_a(sg);

joint_matrix_fill(sg, sub_a, 4.0);

auto wi_slice_a = sub_a.get_wi_data();
for (int i = 0; i < wi_slice_a.length(); i++) {
wi_slice_a[i] = wi_slice_a[i] / 2.0;
}
joint_matrix_store(sg, sub_a,
accA.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, matrix_layout::row_major);
}); // parallel for
}).wait();
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
}
@@ -187,42 +191,43 @@ void matrix_verify_logic(queue q, big_matrix<T, M, N> &A, nd_range<2> &r,
q.submit([&](handler &cgh) {
auto accA = bufA.get_access<access::mode::read_write>(cgh);

cgh.parallel_for<class logic_matrix>(r, [accA](nd_item<2> spmd_item) {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();
joint_matrix<T, TM, TK> sub_a(sg);

joint_matrix_fill(sg, sub_a, 5.0);

auto wi_slice_a = sub_a.get_wi_data();
for (int i = 0; i < wi_slice_a.length(); i++) {
if (wi_slice_a[i]) {
if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 ||
wi_slice_a[i] < 2.0 || wi_slice_a[i] <= 2.0) {
T val = (wi_slice_a[i] != 2.0) ? wi_slice_a[i]
: static_cast<half>(2.0);
val--;
val++;
if (wi_slice_a[i] == 2.0) {
val -= 2;
val *= 3.0;
val /= 2.0;
} else {
val += 2;
cgh.parallel_for<class logic_matrix>(
r, [accA](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
const auto global_idx = spmd_item.get_global_id(0);
const auto global_idy = spmd_item.get_global_id(1);
const auto sg_startx = global_idx - spmd_item.get_local_id(0);
const auto sg_starty = global_idy - spmd_item.get_local_id(1);

ext::oneapi::sub_group sg = spmd_item.get_sub_group();
joint_matrix<T, TM, TK> sub_a(sg);

joint_matrix_fill(sg, sub_a, 5.0);

auto wi_slice_a = sub_a.get_wi_data();
for (int i = 0; i < wi_slice_a.length(); i++) {
if (wi_slice_a[i]) {
if (wi_slice_a[i] > 2.0 || wi_slice_a[i] >= 2.0 ||
wi_slice_a[i] < 2.0 || wi_slice_a[i] <= 2.0) {
T val = (wi_slice_a[i] != 2.0) ? wi_slice_a[i]
: static_cast<half>(2.0);
val--;
val++;
if (wi_slice_a[i] == 2.0) {
val -= 2;
val *= 3.0;
val /= 2.0;
} else {
val += 2;
}
wi_slice_a[i] = val;
}
}
wi_slice_a[i] = val;
}
}
}
joint_matrix_store(sg, sub_a,
accA.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, matrix_layout::row_major);
}); // parallel for
joint_matrix_store(sg, sub_a,
accA.get_pointer() + (sg_startx * TM) * N +
sg_starty / SG_SZ * TN,
N, matrix_layout::row_major);
}); // parallel for
}).wait();
assert_ops_ref<T, M, N>(bufA.get_access<access::mode::read>(), ref);
}
5 changes: 2 additions & 3 deletions SYCL/Matrix/element_wise_ops.cpp
@@ -59,9 +59,8 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,

cgh.parallel_for<class imatrix>(
nd_range<2>({NDRangeM, NDRangeN * SG_SZ}, {1, 1 * SG_SZ}),
[accA, accB, accC, M, N, K](nd_item<2> spmd_item)

{
[accA, accB, accC, M, N,
K](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
// The submatrix API has to be accessed by all the workitems in a
// subgroup these functions will be called once by the subgroup no
// code divergence between the workitems
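
Because the attribute turns the sub-group size into a hard requirement, a device that does not offer SG_SZ is expected to reject the launch of these kernels. The sketch below is not part of the PR; it shows, under the same assumed SG_SZ value, how a host-side support check could look using the standard SYCL 2020 device query.

#include <sycl/sycl.hpp>
#include <algorithm>
#include <iostream>

int main() {
  constexpr size_t SG_SZ = 16;  // assumed; the tests choose it per target
  sycl::queue q;
  // info::device::sub_group_sizes lists the sizes the device supports.
  auto sizes = q.get_device().get_info<sycl::info::device::sub_group_sizes>();
  bool ok = std::find(sizes.begin(), sizes.end(), SG_SZ) != sizes.end();
  std::cout << "sub-group size " << SG_SZ << (ok ? " is" : " is not")
            << " supported on this device\n";
  return ok ? 0 : 1;
}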