Skip to content

[SYCL][Matrix] Add initial get_coord API #7851

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions sycl/include/CL/__spirv/spirv_ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,13 @@ template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
extern __DPCPP_SYCL_EXTERNAL __spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *
__spirv_CompositeConstruct(const T v);

// Declaration of the SPIR-V builtin that reports the coordinate of element
// `i` of the given joint matrix. The result is a two-component uint32_t
// vector.
// NOTE(review): the get_coord wrapper in matrix-intel.hpp reads component 0
// as the row and component 1 as the column -- confirm against the SPIR-V
// joint matrix extension specification.
template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
extern __DPCPP_SYCL_EXTERNAL __ocl_vec_t<uint32_t, 2>
__spirv_JointMatrixGetElementCoordINTEL(
__spv::__spirv_JointMatrixINTEL<T, R, C, L, S, U> *, size_t i);

template <typename T, std::size_t R, std::size_t C, __spv::MatrixUse U,
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
Expand Down
43 changes: 43 additions & 0 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ class wi_element {
NumCols, Layout> &M;
std::size_t idx;

// Befriend the free function get_coord so that it can read the private
// members M and idx of this wi_element.
template <typename T1, size_t NRows, size_t NCols,
sycl::ext::oneapi::experimental::matrix::use Use1,
sycl::ext::oneapi::experimental::matrix::layout Layout1,
typename Grp>
friend std::tuple<uint32_t, uint32_t>
get_coord(wi_element<T1, NRows, NCols, Use1, Layout1, Grp> &);

public:
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix<
Group, T, Use, NumRows, NumCols, Layout> &Mat,
Expand Down Expand Up @@ -165,6 +172,13 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout> &M;
std::size_t idx;

// Befriend the free function get_coord so that it can read the private
// members M and idx of this bfloat16 specialization of wi_element.
template <typename T1, size_t NRows, size_t NCols,
sycl::ext::oneapi::experimental::matrix::use Use1,
sycl::ext::oneapi::experimental::matrix::layout Layout1,
typename Grp>
friend std::tuple<uint32_t, uint32_t>
get_coord(wi_element<T1, NRows, NCols, Use1, Layout1, Grp> &);

public:
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix<
Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
Expand Down Expand Up @@ -308,6 +322,35 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,

// End wi_element definition

/// Returns the (row, column) coordinate of the matrix element referred to by
/// \p we.
///
/// In device compilation the coordinate is obtained from
/// __spirv_JointMatrixGetElementCoordINTEL using the wi_element's matrix
/// handle and element index. In host compilation the function always throws
/// a runtime_error with PI_ERROR_INVALID_DEVICE.
template <typename T, size_t NumRows, size_t NumCols,
          sycl::ext::oneapi::experimental::matrix::use Use,
          sycl::ext::oneapi::experimental::matrix::layout Layout,
          typename Group>
inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t>
get_coord(wi_element<T, NumRows, NumCols, Use, Layout, Group> &we) {
#ifndef __SYCL_DEVICE_ONLY__
  // Host path: there is no joint matrix to query.
  (void)we;
  throw runtime_error(
      "get_coord is only supported on Intel XMX and AMX devices.",
      PI_ERROR_INVALID_DEVICE);
#else
  __ocl_vec_t<uint32_t, 2> rowCol =
      __spirv_JointMatrixGetElementCoordINTEL(we.M.spvm, we.idx);
  // Copy the vector components into plain scalars before building the tuple.
  const uint32_t rowIdx = rowCol[0];
  const uint32_t colIdx = rowCol[1];
  return std::make_tuple(rowIdx, colIdx);
#endif // __SYCL_DEVICE_ONLY__
}

// Fallback overload that keeps host compilation working: it accepts any
// argument type (e.g. a plain float or int8 element rather than a
// wi_element) and unconditionally throws.
template <typename T>
inline __SYCL_ALWAYS_INLINE std::tuple<uint32_t, uint32_t> get_coord(T &we) {
  (void)we; // the argument is intentionally unused
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
}

// Begin wi_data definition

template <typename Group, typename T,
Expand Down
238 changes: 238 additions & 0 deletions sycl/test/matrix/matrix-bfloat16-test-coord-basicB.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
// RUN: %clangxx -fsycl -O2 -DSYCL_EXT_ONEAPI_MATRIX_VERSION=4 %s -o %t.out

// Kernel B sum by col
#include <iostream>
#include <sycl/sycl.hpp>

using namespace sycl;
using namespace sycl::ext::oneapi::experimental::matrix;

// Sub-group size requested for the kernel through
// intel::reqd_sub_group_size.
#define SG_SZ 16

// Dimensions of the joint_matrix B tile declared in the kernel (TK x TN).
#define TN SG_SZ
#define TK 32

/// Thin non-owning wrapper around a pointer to matrix storage. The row and
/// column counts are carried only as template parameters; the wrapper itself
/// stores nothing but the data pointer.
template <typename T, size_t NUM_ROWS, size_t NUM_COLS> struct big_matrix {
  // Pointer to the wrapped storage (not owned).
  T *mat;

  big_matrix(T *data) : mat(data) {}

  // Replaces the wrapped pointer.
  void set_data(T *data) { mat = data; }

  // Returns the wrapped pointer.
  T *get_data() { return mat; }
};

/// Host-side reference check for the column sums.
///
/// For every column j of the M x N matrix \p B, accumulates all M entries
/// into an int and asserts that the value stored in \p sum_cols[j] is exactly
/// equal to it. The sums are integers, so the comparison is an exact
/// equality rather than a floating-point tolerance check.
///
/// \param B        host accessor over the matrix whose columns are summed.
/// \param sum_cols host accessor over the per-column sums to verify.
template <typename T, size_t M, size_t N>
void sum_cols_ref(host_accessor<T, 2, access::mode::read_write> B,
                  host_accessor<int, 1, access::mode::read_write> sum_cols) {
  for (size_t j = 0; j < N; j++) {
    int expected = 0;
    for (size_t i = 0; i < M; i++) {
      expected += B[i][j];
    }
    assert(sum_cols[j] == expected);
  }
}

// clang-format off
/*
Here is a demonstration of how matrix B will be divided across
work items for this test case.
< --------------- 128 ---------------------------------->
x x x x x x x x x x x x x x x x .......... x x x x x x ^
x x x x x x x x x x x x x x x x .......... x x x x x x 16
x x x x x x x x x x x x x x x x .......... x x x x x x |
..... |
x x x x x x x x x x x x x x x x .......... x x x x x x |
x x x x x x x x x x x x x x x x .......... x x x x x x v


--------------- 64 ---------------->
x x x x x x .......... x x x x x x ^
x x x x x x .......... x x x x x x 8
x x x x x x .......... x x x x x x | <-- part of (VNNI-ed)
..... | original matrix each SG
x x x x x x .......... x x x x x x | holds
x x x x x x .......... x x x x x x v
< WI0 > < WI15 >


<-------- 16 ------------->
x x x .......... x x x ^
x x x .......... x x x |
x x x .......... x x x | <-- part of (non-VNNI-ed) original matrix
..... | each SG holds
x x x .......... x x x |
x x x .......... x x x |
x x x .......... x x x 32
x x x .......... x x x |
x x x .......... x x x |
x x x .......... x x x |
x x x .......... x x x |
x x x .......... x x x |
x x x .......... x x x v

If we divide the above matrix across 16 (SG_SZ) work items,
each WI will hold 32 elements. And these 32 elements will be
8x4 chunks as shown in the VNNI-ed matrix figure.
*/

// The total distribution among the WIs in ALL the sub-groups is as follows:
// This is useful to figure out how the global index is to be calculated

/*
W0 --> 0 0 0 0 1 1 1 1 ... 7 7 7 7 --> total 32 elements
wi [0,0] --> i=0, [0, 0] wi [0,1] --> i=0, [0, 4] wi [0,15] --> i=0, [0, 60] | wi [0,16] --> i=0, [0, 64]
i=1, [0, 1] i=1, [0, 5] i=1, [0, 61] | i=1, [0, 65]
i=2, [0, 2] i=2, [0, 6] i=2, [0, 62] | i=2, [0, 66]
i=3, [0, 3] i=3, [0, 7] i=3, [0, 63] | i=3, [0, 67]

i=4, [1, 0] i=4, [1, 4] i=4, [1, 60] | ....
i=5, [1, 1] i=5, [1, 5] i=5, [1, 61] |
i=6, [1, 2] i=6, [1, 6] i=6, [1, 62] |
i=7, [1, 3] i=7, [1, 7] i=7, [1, 63] |
... ... .... |
i=28,[7, 0] i=28,[7, 4] i=28,[7, 60] | i=28, [7, 124]
i=29,[7, 1] i=29,[7, 5] i=29,[7, 61] | i=29, [7, 125]
i=30,[7, 2] i=30,[7, 6] i=30,[7, 62] | i=30, [7, 126]
i=31,[7, 3] i=31,[7, 7] i=31,[7, 63] | i=31, [7, 127]
---------------------------------------------------------------------------------------- ---------------------------
wi [1,0] --> i=0, [8, 0]
i=1, [8, 1]
i=2, [8, 2]
i=3, [8, 3]
...
i=28, [15, 0]
i=29, [15, 1]
i=30, [15, 2]
i=31, [15, 3]
*/

// The following is the distribution among WIs in a SINGLE SG.
/*
W0 --> 0 0 0 0 1 1 1 1 ... 7 7 7 7 --> total 32 elements

wi [0,0] -> i=0, [0, 0] wi [0,1] --> i=0, [0, 4] wi [0,15] --> i=0, [0, 60] |
i=1, [0, 1] i=1, [0, 5] i=1, [0, 61] |
i=2, [0, 2] i=2, [0, 6] i=2, [0, 62] |
i=3, [0, 3] i=3, [0, 7] i=3, [0, 63] |

i=4, [1, 0] i=4, [1, 4] i=4, [1, 60] |
i=5, [1, 1] i=5, [1, 5] i=5, [1, 61] |
i=6, [1, 2] i=6, [1, 6] i=6, [1, 62] |
i=7, [1, 3] i=7, [1, 7] i=7, [1, 63] |
... ... .... |
i=28,[7, 0] i=28,[7, 4] i=28,[7, 60] |
i=29,[7, 1] i=29,[7, 5] i=29,[7, 61] |
i=30,[7, 2] i=30,[7, 6] i=30,[7, 62] |
i=31,[7, 3] i=31,[7, 7] i=31,[7, 63] |

*/
// clang-format on

/// Computes the per-column sums of \p B on the device and verifies them on
/// the host with sum_cols_ref.
///
/// Each sub-group loads one TK x TN packed joint_matrix tile of B. Every
/// work-item walks the elements it owns, asks get_coord for the element's
/// coordinate, and accumulates the element into a private per-column partial
/// sum. The partial sums are then added to the shared result with atomic
/// adds.
///
/// No group collective is used for the final accumulation: the set of
/// columns a work-item touched differs between work-items, so a
/// reduce_over_group guarded by "did I handle this column" would be called
/// from divergent control flow. Adding each work-item's own partial sum
/// atomically yields the full column sum without that requirement.
///
/// \param q queue the kernel is submitted to.
/// \param B wrapper around the M x N host matrix.
/// \param r nd-range used to launch the kernel.
template <typename T, size_t M, size_t N>
void matrix_sum_cols(queue q, big_matrix<T, M, N> &B, nd_range<2> &r) {
  buffer<int8_t, 2> bufB(B.get_data(), range<2>(M, N));
  // One accumulator per column of B, zero-initialised on the host.
  int sum_cols[N] = {0};
  buffer<int> sum_cols_v(sum_cols, N);
  q.submit([&](handler &cgh) {
     auto accB = bufB.get_access<access::mode::read_write>(cgh);

     auto v = sum_cols_v.get_access<access::mode::atomic>(cgh);

     cgh.parallel_for<class add_matrix>(
         r, [=](nd_item<2> spmd_item) [[intel::reqd_sub_group_size(SG_SZ)]] {
           const auto global_idx = spmd_item.get_global_id(0);
           const auto global_idy = spmd_item.get_global_id(1);
           const auto sg_starty = global_idy - spmd_item.get_local_id(1);

           ext::oneapi::sub_group sg = spmd_item.get_sub_group();

           // TK = 32, TN = 16
           joint_matrix<sub_group, int8_t, use::b, TK, TN,
                        ext::intel::experimental::matrix::layout::packed>
               sub_b;

           joint_matrix_load(sg, sub_b,
                             accB.get_pointer() + (global_idx * (TK / 4) * N) +
                                 sg_starty / SG_SZ * TN * 4,
                             N);

           // Private partial sums of this work-item, one slot per column.
           int32_t sum_local_cols[N] = {0};
           // Marks the columns this work-item contributed to.
           bool handled_cols[N] = {false};

           auto wiData =
               sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_b);

           // Column offset of this sub-group inside the result array; it
           // does not depend on the element, so compute it once.
           const size_t sg_col_offset =
               (global_idy / SG_SZ) * 4 /*VNNI_FACTOR*/ * SG_SZ;

           // Each WI calculates the local sum of the columns it owns.
           for (int i = 0; i < wiData.length(); ++i) {
             // Get the coordinate of the element in the submatrix.
             auto dataItem = wiData[i];
             auto [row, col] =
                 sycl::ext::intel::experimental::matrix::get_coord(dataItem);

             // Index into the result array that holds the sums.
             const size_t global_index = col + sg_col_offset;
             sum_local_cols[global_index] += wiData[i];
             handled_cols[global_index] = true;
           }

           // Publish this work-item's partial sums. Every work-item adds
           // only its own contribution, so the atomic adds from all
           // work-items together produce the complete column sums.
           for (size_t j = 0; j < N; j++) {
             if (handled_cols[j]) {
               atomic_fetch_add(v[j], sum_local_cols[j]);
             }
           }
         }); // parallel for
   }).wait();
  sum_cols_ref<T, M, N>(bufB.get_host_access(), sum_cols_v.get_host_access());
}

// TK = 32, TN = 16
// Dimensions of the host array B. main() divides MATRIX_K by (TK / 4) and
// (MATRIX_N / 4) by TN to size the launch, giving two sub-group tiles in
// each dimension.
static constexpr size_t MATRIX_K = TK / 4 * 2; // 16
static constexpr size_t MATRIX_N = TN * 4 * 2; // 128
// Host storage for B; main() fills it before the kernel is submitted.
int8_t B[MATRIX_K][MATRIX_N];

/* < --------------- 128 ---------------------------------->
x x x x x x x x x x x x x x x x .......... x x x x x x ^
x x x x x x x x x x x x x x x x .......... x x x x x x 16
x x x x x x x x x x x x x x x x .......... x x x x x x |
..... |
x x x x x x x x x x x x x x x x .......... x x x x x x |
x x x x x x x x x x x x x x x x .......... x x x x x x v
*/
/// Test driver: fills B so that every element of row i holds the value i,
/// launches the column-sum kernel over it, and prints "Passed" once
/// matrix_sum_cols returns.
int main() {
  for (size_t row = 0; row < MATRIX_K; ++row) {
    for (size_t col = 0; col < MATRIX_N; ++col) {
      B[row][col] = row;
    }
  }

  big_matrix<int8_t, MATRIX_K, MATRIX_N> MB(&B[0][0]);

  // One sub-group per (TK / 4) x (TN * 4) region of B.
  size_t NDRangeK = MATRIX_K / (TK / 4);
  size_t NDRangeN = (MATRIX_N / 4) / TN;
  nd_range<2> r({NDRangeK, NDRangeN * SG_SZ}, {1, 1 * SG_SZ});

  queue q;
  matrix_sum_cols<int8_t, MATRIX_K, MATRIX_N>(q, MB, r);

  std::cout << "Passed\n";

  return 0;
}