Skip to content

[SYCL][Matrix] Add initial get_coord API #7851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_CompositeConstruct(const T v);

// Declaration of the SPIR-V built-in that, given a joint matrix and the index
// `i` of one of the calling work-item's elements, returns a two-component
// uint32_t vector.
// NOTE(review): the wi_element::get_coord() callers read component 0 as the
// row and component 1 as the column -- confirm against the
// SPV_INTEL_joint_matrix specification.
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
__spirv_JointMatrixGetElementCoordINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
Expand Down
28 changes: 28 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,20 @@ class wi_element {
Group, T, Use, NumRows, NumCols, Layout> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

/// Returns the (row, column) coordinate of the matrix element this
/// wi_element refers to.
///
/// In device code the coordinate is obtained from the
/// __spirv_JointMatrixGetElementCoordINTEL built-in. In host code the call
/// always throws a runtime_error with PI_ERROR_INVALID_DEVICE.
inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
  const __ocl_vec_t<uint32_t, 2> position =
      __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
  // Copy the vector components into plain scalars before building the tuple,
  // so that no reference is ever bound directly to a vector element.
  const uint32_t row_index = position[0];
  const uint32_t col_index = position[1];
  return std::tuple<uint32_t, uint32_t>(row_index, col_index);
#else
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

operator T() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down Expand Up @@ -171,6 +185,20 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
Layout> &Mat,
std::size_t i)
: M(Mat), idx(i) {}

/// Returns the (row, column) coordinate of the bfloat16 matrix element this
/// wi_element refers to.
///
/// In device code the coordinate is obtained from the
/// __spirv_JointMatrixGetElementCoordINTEL built-in. In host code the call
/// always throws a runtime_error with PI_ERROR_INVALID_DEVICE.
inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord() {
#ifdef __SYCL_DEVICE_ONLY__
  const __ocl_vec_t<uint32_t, 2> position =
      __spirv_JointMatrixGetElementCoordINTEL(M.spvm, idx);
  // Copy the vector components into plain scalars before building the tuple,
  // so that no reference is ever bound directly to a vector element.
  const uint32_t row_index = position[0];
  const uint32_t col_index = position[1];
  return std::tuple<uint32_t, uint32_t>(row_index, col_index);
#else
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif // __SYCL_DEVICE_ONLY__
}

operator sycl::ext::oneapi::bfloat16() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_VectorExtractDynamic(M.spvm, idx);
Expand Down
235 changes: 235 additions & 0 deletions sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
// RUN: %clangxx -fsycl -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out

// Kernel B sum by col
#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// Sub-group size the kernel requires (passed to intel::reqd_sub_group_size).
#define SG_SZ 16

// Dimensions of the joint_matrix tile each sub-group loads: TK rows by TN
// columns (32 x 16).
#define TN SG_SZ
#define TK 32

/// Thin, non-owning handle to the storage of a NUM_ROWS x NUM_COLS matrix.
///
/// The wrapper only stores the pointer it is given; it has no destructor and
/// never allocates or frees the underlying memory.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
  // Pointer to the first element of the wrapped storage.
  T *mat;

  /// Wraps the storage starting at \p data.
  big_matrix(T *data) : mat{data} {}

  /// Re-points the wrapper at different storage.
  void set_data(T *data) { this->mat = data; }

  /// Returns the pointer currently wrapped.
  T *get_data() { return this->mat; }
};

/// Host-side reference check for the per-column sums computed on the device.
///
/// For every column j of the M x N matrix \p B, the elements of that column
/// are accumulated into an int and the result is asserted to be equal to
/// \p sum_cols[j].
///
/// \param B        host accessor to the M x N input matrix.
/// \param sum_cols host accessor to the N column sums to validate.
template <typename T, size_t M, size_t N>
void sum_cols_ref(host_accessor<T, 2, access::mode::read_write> B,
                  host_accessor<int, 1, access::mode::read_write> sum_cols) {
  for (size_t j = 0; j < N; j++) {
    int expected = 0;
    for (size_t i = 0; i < M; i++) {
      expected += B[i][j];
    }
    // Both sides are integers, so an exact comparison is the right check; no
    // floating-point tolerance is involved.
    assert(sum_cols[j] == expected);
  }
}

// clang-format off
/*
Here is a demonstration of how matrix B will be divided across
work items for this test case.
< --------------- 128 ---------------------------------->
x x x x x x x x x x x x x x x x .......... x x x x x x ^
x x x x x x x x x x x x x x x x .......... x x x x x x 16
x x x x x x x x x x x x x x x x .......... x x x x x x |
..... |
x x x x x x x x x x x x x x x x .......... x x x x x x |
x x x x x x x x x x x x x x x x .......... x x x x x x v


--------------- 64 ---------------->
x x x x x x .......... x x x x x x ^
x x x x x x .......... x x x x x x 8
x x x x x x .......... x x x x x x | <-- part of (VNNI-ed)
..... | original matrix each SG
x x x x x x .......... x x x x x x | holds
x x x x x x .......... x x x x x x v
< WI0 > < WI15 >


<-------- 16 ------------->
x x x .......... x x x ^
x x x .......... x x x |
x x x .......... x x x | <-- part of (non-VNNI-ed) original matrix
..... | each SG holds
x x x .......... x x x |
x x x .......... x x x |
x x x .......... x x x 32
x x x .......... x x x |
x x x .......... x x x |
x x x .......... x x x |
x x x .......... x x x |
x x x .......... x x x |
x x x .......... x x x v

If we divide the above matrix across 16 (SG_SZ) work items,
each WI will hold 32 elements. And these 32 elements will be
8x4 chunks as shown in the VNNI-ed matrix figure.
*/

// The total distribution among the WIs in ALL the sub-groups is as follows:
// This is useful for figuring out how the global index is to be calculated

/*
W0 --> 0 0 0 0 1 1 1 1 ... 7 7 7 7 --> total 32 elements
wi [0,0] --> i=0, [0, 0] wi [0,1] --> i=0, [0, 4] wi [0,15] --> i=0, [0, 60] | wi [0,16] --> i=0, [0, 64]
i=1, [0, 1] i=1, [0, 5] i=1, [0, 61] | i=1, [0, 65]
i=2, [0, 2] i=2, [0, 6] i=2, [0, 62] | i=2, [0, 66]
i=3, [0, 3] i=3, [0, 7] i=3, [0, 63] | i=3, [0, 67]

i=4, [1, 0] i=4, [1, 4] i=4, [1, 60] | ....
i=5, [1, 1] i=5, [1, 5] i=5, [1, 61] |
i=6, [1, 2] i=6, [1, 6] i=6, [1, 62] |
i=7, [1, 3] i=7, [1, 7] i=7, [1, 63] |
... ... .... |
i=28,[7, 0] i=28,[7, 4] i=28,[7, 60] | i=28, [7, 124]
i=29,[7, 1] i=29,[7, 5] i=29,[7, 61] | i=29, [7, 125]
i=30,[7, 2] i=30,[7, 6] i=30,[7, 62] | i=30, [7, 126]
i=31,[7, 3] i=31,[7, 7] i=31,[7, 63] | i=31, [7, 127]
---------------------------------------------------------------------------------------- ---------------------------
wi [1,0] --> i=0, [8, 0]
i=1, [8, 1]
i=2, [8, 2]
i=3, [8, 3]
...
i=28, [15, 0]
i=29, [15, 1]
i=30, [15, 2]
i=31, [15, 3]
*/

// The following is the distribution among WIs in a SINGLE SG.
/*
W0 --> 0 0 0 0 1 1 1 1 ... 7 7 7 7 --> total 32 elements

wi [0,0] -> i=0, [0, 0] wi [0,1] --> i=0, [0, 4] wi [0,15] --> i=0, [0, 60] |
i=1, [0, 1] i=1, [0, 5] i=1, [0, 61] |
i=2, [0, 2] i=2, [0, 6] i=2, [0, 62] |
i=3, [0, 3] i=3, [0, 7] i=3, [0, 63] |

i=4, [1, 0] i=4, [1, 4] i=4, [1, 60] |
i=5, [1, 1] i=5, [1, 5] i=5, [1, 61] |
i=6, [1, 2] i=6, [1, 6] i=6, [1, 62] |
i=7, [1, 3] i=7, [1, 7] i=7, [1, 63] |
... ... .... |
i=28,[7, 0] i=28,[7, 4] i=28,[7, 60] |
i=29,[7, 1] i=29,[7, 5] i=29,[7, 61] |
i=30,[7, 2] i=30,[7, 6] i=30,[7, 62] |
i=31,[7, 3] i=31,[7, 7] i=31,[7, 63] |

*/
// clang-format on

/// Computes the sum of every column of the packed matrix \p B on the device
/// using joint_matrix and wi_element::get_coord(), then validates the result
/// on the host with sum_cols_ref().
///
/// Each sub-group loads one TK x TN packed int8 tile of B. Every work-item
/// walks the tile elements it owns, asks get_coord() for each element's
/// column, accumulates a private per-column partial sum and finally adds its
/// partial sums to the shared result with atomics.
///
/// \param q queue the kernel is submitted to.
/// \param B wrapper around the M x N input data. The device buffer and the
///          joint_matrix are declared with int8_t elements.
/// \param r nd_range to launch; the kernel requires a sub-group size of SG_SZ.
template <typename T, size_t M, size_t N>
void matrix_sum_cols(queue q, big_matrix<T, M, N> &B, nd_range<2> &r) {
  buffer<int8_t, 2> bufB(B.get_data(), range<2>(M, N));
  // One accumulator per column of B, zero-initialised on the host.
  int sum_cols[N] = {0};
  buffer<int> sum_cols_v(sum_cols, N);
  q.submit([&](handler &cgh) {
     auto accB = bufB.get_access<access::mode::read_write>(cgh);
     auto v = sum_cols_v.get_access<access::mode::atomic>(cgh);

     cgh.parallel_for<class add_matrix>(
         r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           ext::oneapi::sub_group sg = spmd_item.get_sub_group();

           // TK = 32, TN = 16
           joint_matrix<sub_group, int8_t, use::b, TK, TN,
                        ext::intel::experimental::matrix::layout::packed>
               sub_b;

           joint_matrix_load(sg, sub_b,
                             accB.get_pointer() + (global_idx * (TK / 4) * N) +
                                 sg_starty / SG_SZ * TN * 4,
                             N);

           // Private partial sums, indexed by global column.
           int32_t sum_local_cols[N] = {0};
           // Marks the columns this work-item contributed to.
           bool handled_cols[N] = {};

           // sub_b has 32x16 elements, 32 elements per WI, 4 per WI per row
           auto wiData =
               sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);

           // Index of the sub-group along dimension 1; it does not depend on
           // the element being visited, so compute it once.
           const int sg_idx = static_cast<int>(global_idy) / SG_SZ;

           // Each WI accumulates the elements it owns into its partial sums.
           for (int i = 0; i < wiData.length(); ++i) {
             // Coordinate of the element inside the sub-matrix.
             auto dataItem = wiData[i];
             auto [row, col] = dataItem.get_coord();

             // Translate the tile-local column into a column of B.
             const size_t global_index =
                 col + sg_idx * 4 /*VNNI_FACTOR*/ * SG_SZ;
             sum_local_cols[global_index] += wiData[i];
             handled_cols[global_index] = true;
           }

           // Publish the partial sums. A sub-group collective such as
           // reduce_over_group must not be used here: this branch is taken by
           // a different subset of work-items for every column, and SYCL
           // group functions require converged control flow. Adding each
           // work-item's own partial sum atomically gives the column total
           // regardless of how elements are distributed among work-items.
           for (size_t j = 0; j < N; j++) {
             if (handled_cols[j]) {
               atomic_fetch_add(v[j], sum_local_cols[j]);
             }
           }
         }); // parallel for
   }).wait();
  sum_cols_ref<T, M, N>(bufB.get_host_access(), sum_cols_v.get_host_access());
}

// TK = 32, TN = 16
// Dimensions of the host-side B array. main() derives the launch range from
// them as MATRIX_K / (TK / 4) by (MATRIX_N / 4) / TN, i.e. 2 x 2 tiles.
static constexpr size_t MATRIX_K = TK / 4 * 2; // 16
static constexpr size_t MATRIX_N = TN * 4 * 2; // 128
// Input matrix; filled by main() before the kernel is launched.
int8_t B[MATRIX_K][MATRIX_N];

/* < --------------- 128 ---------------------------------->
x x x x x x x x x x x x x x x x .......... x x x x x x ^
x x x x x x x x x x x x x x x x .......... x x x x x x 16
x x x x x x x x x x x x x x x x .......... x x x x x x |
..... |
x x x x x x x x x x x x x x x x .......... x x x x x x |
x x x x x x x x x x x x x x x x .......... x x x x x x v
*/
/// Test driver: fills B so that every element of row i holds the value i,
/// runs the column-sum kernel over the whole matrix and prints "Passed" once
/// matrix_sum_cols() (which asserts on mismatch) returns.
int main() {
  // Point at the first element explicitly rather than casting the address of
  // the whole 2-D array.
  big_matrix<int8_t, MATRIX_K, MATRIX_N> MB(&B[0][0]);

  // Number of tiles along each dimension of the packed matrix.
  size_t NDRangeK = MATRIX_K / (TK / 4);
  size_t NDRangeN = (MATRIX_N / 4) / TN;
  queue q;
  // One sub-group of SG_SZ work-items per tile.
  nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});

  for (size_t i = 0; i < MATRIX_K; i++) {
    for (size_t j = 0; j < MATRIX_N; j++) {
      // MATRIX_K is 16, so the row index always fits in int8_t.
      B[i][j] = static_cast<int8_t>(i);
    }
  }

  matrix_sum_cols<int8_t, MATRIX_K, MATRIX_N>(q, MB, r);

  std::cout << "Passed\n";

  return 0;
}