Skip to content

[Matrix][SYCL] Rename wi_slice with wi_data #5728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -548,19 +548,19 @@ While this provides fast element indexing on the GPU compared to the non-restric
However using the `mma` ptx instructions as opposed to the `wmma` ptx instructions the mapping is known. Knowing this mapping is important for the user to implement new operations like sum of rows of a matrix for quantized algorithms.

#### proposal: Explicit conversion in the interface from SIMD to SPMD
We introduce a new function `get_wi_slice` that provides any portion of the matrix that the user wants but in a SPMD array object:.
We introduce a new function `get_wi_data` that provides any portion of the matrix that the user wants but in a SPMD array object:.

```c++
namespace sycl::ext::oneapi::experimental::matrix {
template <typename Group, typename T, size_t NumRows, size_t NumCols, matrix_layout L>
marray<T, n_rows * n_cols> get_wi_slice(joint_matrix<T, NumRows, NumCols, L, Group> &m, size_t row_index,
marray<T, n_rows * n_cols> get_wi_data(joint_matrix<T, NumRows, NumCols, L, Group> &m, size_t row_index,
size_t col_index, size_t n_rows, size_t n_cols);
}
```

Example where each WI gets 1 column:
```c++
marray<T,msize> wi_C = get_wi_slice(C, 0, wi_idx, msize, 1, matrix_layout::row_major);
marray<T,msize> wi_C = get_wi_data(C, 0, wi_idx, msize, 1, matrix_layout::row_major);
for (int i = 0; i < msize; i++)
row_sum += wi_C[i];
```
Expand All @@ -582,7 +582,7 @@ We did not utilize this extension for this matrix API version because sub-group
-- Yes, this will be addressed in the next revision where `use` argument will be introduced to distinguish between right (B) , left (A), and accumulator matrix.
- Ronan Keryell: "It would be interesting to investigate whether providing also member functions would simplify the API. Provide both so it is possible to use the best one for each use case, while waiting for https://en.wikipedia.org/wiki/Uniform_Function_Call_Syntax to land into C++?"

- In the future looking APIs, `get_wi_slice` (that is currently under design) returns an owned object. Should this return a view object to make sure the original matrix C is changed after its slices are modified.
- In the future looking APIs, `get_wi_data` (that is currently under design) returns an owned object. Should this return a view object to make sure the original matrix C is changed after its slices are modified.

## TODO List
- Add support for fill matrix and element-wise operations features
Expand Down
10 changes: 5 additions & 5 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-jit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ template <int D> struct spv_scope_traits<sycl::group<D>> {
template <typename T, size_t NumRows, size_t NumCols,
matrix_layout Layout = matrix_layout::row_major,
typename Group = sycl::sub_group>
class wi_slice;
class wi_data;

template <typename T, size_t NumRows, size_t NumCols,
matrix_layout Layout = matrix_layout::row_major,
Expand All @@ -64,9 +64,9 @@ struct joint_matrix {
#endif // __SYCL_DEVICE_ONLY__
}

inline __SYCL_ALWAYS_INLINE wi_slice<T, NumRows, NumCols, Layout, Group>
inline __SYCL_ALWAYS_INLINE wi_data<T, NumRows, NumCols, Layout, Group>
get_wi_data() {
return wi_slice<T, NumRows, NumCols, Layout, Group>(*this);
return wi_data<T, NumRows, NumCols, Layout, Group>(*this);
}
};

Expand Down Expand Up @@ -455,11 +455,11 @@ class wi_element<uint16_t, NumRows, NumCols, Layout, Group> {

template <typename T, size_t NumRows, size_t NumCols, matrix_layout Layout,
typename Group>
class wi_slice {
class wi_data {
joint_matrix<T, NumRows, NumCols, Layout, Group> &M;

public:
wi_slice(joint_matrix<T, NumRows, NumCols, Layout, Group> &Mat) : M(Mat) {}
wi_data(joint_matrix<T, NumRows, NumCols, Layout, Group> &Mat) : M(Mat) {}
size_t length() {
#ifdef __SYCL_DEVICE_ONLY__
return __spirv_JointMatrixWorkItemLengthINTEL(M.spvm);
Expand Down
6 changes: 3 additions & 3 deletions sycl/test/matrix/matrix-elemwise-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
N * 4, matrix_layout::packed_b);
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
}
auto wi_slice_c = sub_c.get_wi_data();
for (int i = 0; i < wi_slice_c.length(); i++) {
wi_slice_c[i] *= 2;
auto wi_data_c = sub_c.get_wi_data();
for (int i = 0; i < wi_data_c.length(); i++) {
wi_data_c[i] *= 2;
}
joint_matrix_store(sg, sub_c,
accC.get_pointer() + (sg_startx * TM) * N +
Expand Down