Skip to content

[SYCL][Matrix] Add initial get_coord API. #7037

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 13 commits into from
Closed
8 changes: 8 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,14 @@ template <typename T, std::size_t R, std::size_t C,
extern SYCL_EXTERNAL size_t __spirv_JointMatrixWorkItemLengthINTEL(
JOINT_MATRIX_INTEL(T, R, C, L, S, U) *);

// Declaration of the SPIR-V joint-matrix "get element coordinate" built-in.
// Takes a pointer to the joint matrix object and an element index i, and
// returns a two-component uint32_t vector. wi_element::get_coord() in
// matrix-jit-use.hpp reads component 0 as the row and component 1 as the
// column.
// NOTE(review): presumably i indexes the calling work-item's slice of the
// matrix (same index space as __spirv_JointMatrixWorkItemLengthINTEL) --
// confirm against the SPIR-V extension spec.
template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
__spirv_JointMatrixGetElementCoordINTEL(JOINT_MATRIX_INTEL(T, R, C, L, S, U) *,
size_t i);

template <typename T, std::size_t R, std::size_t C,
__spv::MatrixUse U = __spv::MatrixUse::Unnecessary,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
Expand Down
44 changes: 44 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-jit-use.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <CL/__spirv/spirv_ops.hpp>
#include <sycl/detail/defines_elementary.hpp>
#include <sycl/feature_test.hpp>
#include <tuple>

namespace sycl {
__SYCL_INLINE_VER_NAMESPACE(_V1) {
Expand Down Expand Up @@ -256,6 +257,21 @@ class wi_element {
// Constructs an element proxy: M is initialized from Mat and idx from the
// element index i.
wi_element(joint_matrix<T, NumRows, NumCols, Use, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

// Returns the (row, column) pair reported for this element by
// __spirv_JointMatrixGetElementCoordINTEL. Only available in device code:
// the host path throws runtime_error with PI_ERROR_INVALID_DEVICE.
std::tuple<uint32_t, uint32_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
  // Component 0 of the returned vector is the row, component 1 the column.
  __ocl_vec_t<uint32_t, 2> rc =
      __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
  // Copy the components into scalars before building the tuple.
  const uint32_t r = rc[0];
  const uint32_t c = rc[1];
  return std::tuple<uint32_t, uint32_t>(r, c);
#else
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

// Various Operations
operator T() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down Expand Up @@ -339,6 +355,20 @@ class wi_element<uint16_t, NumRows, NumCols, Use, Layout, Group> {
// Constructs an element proxy for the uint16_t specialization: M is
// initialized from Mat and idx from the element index i.
wi_element(joint_matrix<uint16_t, NumRows, NumCols, Use, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

// Returns the (row, column) pair reported for this element by
// __spirv_JointMatrixGetElementCoordINTEL. Only available in device code:
// the host path throws runtime_error with PI_ERROR_INVALID_DEVICE.
std::tuple<uint32_t, uint32_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
  // Component 0 of the returned vector is the row, component 1 the column.
  __ocl_vec_t<uint32_t, 2> rc =
      __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
  // Copy the components into scalars before building the tuple.
  const uint32_t r = rc[0];
  const uint32_t c = rc[1];
  return std::tuple<uint32_t, uint32_t>(r, c);
#else
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

operator uint16_t() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down Expand Up @@ -489,6 +519,20 @@ class wi_element<sycl::ext::oneapi::experimental::bfloat16, NumRows, NumCols,
NumCols, Use, Layout, Group> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

// Returns the (row, column) pair reported for this element by
// __spirv_JointMatrixGetElementCoordINTEL. Only available in device code:
// the host path throws runtime_error with PI_ERROR_INVALID_DEVICE.
std::tuple<uint32_t, uint32_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
  // Component 0 of the returned vector is the row, component 1 the column.
  __ocl_vec_t<uint32_t, 2> rc =
      __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
  // Copy the components into scalars before building the tuple.
  const uint32_t r = rc[0];
  const uint32_t c = rc[1];
  return std::tuple<uint32_t, uint32_t>(r, c);
#else
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

operator sycl::ext::oneapi::experimental::bfloat16() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down
117 changes: 117 additions & 0 deletions sycl/test/matrix/matrix-bfloat16-test-coord-basic.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
// RUN: %clangxx -fsycl -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=2 %s -o %t.out
// RUN: %t.out
// XFAIL: *

// This test computes the sum of each row of a matrix into a global array that
// has one element per row. First, a partial reduction is computed inside each
// sub-group (SG); then atomic adds are used to reduce across the SG leaders.
// The get_coord() API is used to retrieve the row index of each element.

#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// Sub-group size; also passed to [[intel::reqd_sub_group_size]] on the kernel.
#define SG_SZ 16

// Tile shape of the joint_matrix declared in the kernel: TK rows by TN
// columns, with TN tied to the sub-group size.
#define TN SG_SZ
#define TK 32

// Thin non-owning wrapper around a raw pointer to matrix storage. The
// NUM_ROWS / NUM_COLS template parameters only tag the type with its logical
// dimensions; they are not used by any member.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
  // Pointer to the first element of the wrapped storage (never freed here).
  T *mat;

  // Wraps the storage pointed to by data.
  big_matrix(T *data) : mat{data} {}

  // Re-points the wrapper at different storage.
  void set_data(T *data) { this->mat = data; }

  // Returns the currently wrapped storage pointer.
  T *get_data() { return this->mat; }
};

// Host-side reference check: recomputes the sum of every row of B and asserts
// that it matches the corresponding entry of sum_rows.
//
// Template parameters:
//   T - element type of B
//   M - number of rows of B (and number of entries in sum_rows)
//   N - number of columns of B
// Parameters:
//   B        - host read accessor over the M x N input matrix
//   sum_rows - host read accessor over the M row sums produced by the kernel
//
// The comparison is an exact integer equality. The sums are ints, so the
// floating-point fabs/epsilon idiom is not needed (epsilon() is 0 for int).
template <typename T, size_t M, size_t N>
void sum_rows_ref(
    accessor<T, 2, access::mode::read, access::target::host_buffer> B,
    accessor<int, 1, access::mode::read, access::target::host_buffer>
        sum_rows) {
  for (size_t i = 0; i < M; i++) {
    // Accumulate in int, matching the element type of sum_rows.
    int expected = 0;
    for (size_t j = 0; j < N; j++) {
      expected += B[i][j];
    }
    assert(sum_rows[i] == expected);
  }
}

// Computes per-row sums of the M x N matrix B on the device and validates
// them against a host reference (sum_rows_ref).
//
// Template parameters:
//   T - element type of B
//   M - number of rows of B
//   N - number of columns of B
// Parameters:
//   q - queue the kernel is submitted to
//   B - wrapper around the host matrix storage
//   r - nd_range used to launch the kernel
//
// Each sub-group loads a TK x TN joint_matrix tile, every work-item walks its
// own elements and uses get_coord() to find the row each element belongs to,
// and the sub-group leader adds the reduced value into the global result with
// an atomic add.
template <typename T, size_t M, size_t N>
void matrix_sum_rows(queue q, big_matrix<T, M, N> &B, nd_range<2> &r) {
  // The buffer element type follows T so that the template is usable for
  // element types other than int8_t (B.get_data() returns T *).
  buffer<T, 2> bufB(B.get_data(), range<2>(M, N));
  // The size of the vector is known because the SG size is set by the user in
  // this case.
  int sum_rows[M] = {0};
  buffer<int> sum_rows_v(sum_rows, M); // there are a total of TK/4 * 2 = 16 rows
  q.submit([&](handler &cgh) {
     // bufB has a dependent type, hence the explicit `template` keyword.
     auto accB = bufB.template get_access<access::mode::read_write>(cgh);

     auto v = sum_rows_v.get_access<access::mode::atomic>(cgh);

     cgh.parallel_for<class add_matrix>(
         r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           ext::oneapi::sub_group sg = spmd_item.get_sub_group();

           joint_matrix<T, TK, TN, use::b> sub_b(sg);

           // Tile origin: global_idx selects a band of TK / 4 rows, and
           // sg_starty / SG_SZ selects a group of TN * 4 columns.
           joint_matrix_load(sg, sub_b,
                             accB.get_pointer() + (global_idx * (TK / 4) * N) +
                                 sg_starty / SG_SZ * TN * 4,
                             N, layout::packed_b);

           int32_t sum_local_rows[M] = {0};
           auto tBData = sub_b.get_wi_data();

           // Each WI calculates the local sum of rows.
           // NOTE(review): the group reduction and the atomic add run on
           // every loop iteration and index by a per-work-item `row`; this
           // looks like it can re-add already reduced partial sums -- confirm
           // the intended algorithm (the test is currently marked XFAIL).
           for (int i = 0; i < tBData.length(); ++i) {
             // row and col hold the global coordinates of the matrix
             auto [row, col] = tBData[i].get_coord();
             sum_local_rows[row] += tBData[i];

             sum_local_rows[row] =
                 reduce_over_group(sg, sum_local_rows[row], sycl::plus<>());
             // Only the group leader performs the global reduction.
             if (global_idy % SG_SZ == 0) {
               atomic_fetch_add(v[row], sum_local_rows[row]);
             }
           }
         }); // parallel for
   }).wait();
  sum_rows_ref<T, M, N>(bufB.template get_access<access::mode::read>(),
                        sum_rows_v.get_access<access::mode::read>());
}

// Host matrix dimensions. With the launch configuration computed in main()
// they correspond to two tiles along each kernel dimension.
static constexpr size_t MATRIX_K = (TK / 4) * 2;
static constexpr size_t MATRIX_N = (TN * 4) * 2;
// Host-side storage for the matrix under test; filled in by main().
int8_t B[MATRIX_K][MATRIX_N];

// Test driver: fills B so that every element of row i holds the value i,
// then launches the row-sum kernel over the whole matrix.
int main() {
  // Initialize the input: B[i][j] = i for every column j.
  for (size_t row = 0; row < MATRIX_K; ++row) {
    for (size_t col = 0; col < MATRIX_N; ++col) {
      B[row][col] = static_cast<int8_t>(row);
    }
  }

  // Wrap the global storage; the wrapper does not take ownership.
  big_matrix<int8_t, MATRIX_K, MATRIX_N> MB(&B[0][0]);

  // One work-group (a single sub-group of SG_SZ work-items) per tile.
  const size_t NDRangeK = MATRIX_K / (TK / 4);
  const size_t NDRangeN = (MATRIX_N / 4) / TN;
  queue q;
  nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});

  matrix_sum_rows<int8_t, MATRIX_K, MATRIX_N>(q, MB, r);

  return 0;
}
Loading