Skip to content

[SYCL][Matrix] Move elementwise operation under intel namespace and add joint_matrix_apply. #8417

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 75 additions & 13 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,27 @@ template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
layout Layout>
struct joint_matrix;

template <typename T, size_t NumRows, size_t NumCols, use Use,
layout Layout = layout::dynamic, typename Group = sycl::sub_group>
} // namespace matrix
} // namespace experimental
} // namespace oneapi

namespace intel::experimental::matrix {

// Begin wi_element definition

template <typename T, size_t NumRows, size_t NumCols,
sycl::ext::oneapi::experimental::matrix::use Use,
sycl::ext::oneapi::experimental::matrix::layout Layout =
sycl::ext::oneapi::experimental::matrix::layout::dynamic,
typename Group = sycl::sub_group>
class wi_element {
joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &M;
sycl::ext::oneapi::experimental::matrix::joint_matrix<Group, T, Use, NumRows,
NumCols, Layout> &M;
std::size_t idx;

public:
wi_element(joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Mat,
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix<
Group, T, Use, NumRows, NumCols, Layout> &Mat,
std::size_t i)
: M(Mat), idx(i) {}
operator T() {
Expand Down Expand Up @@ -142,17 +155,20 @@ class wi_element {
#undef OP
};

template <size_t NumRows, size_t NumCols, use Use, layout Layout,
template <size_t NumRows, size_t NumCols,
sycl::ext::oneapi::experimental::matrix::use Use,
sycl::ext::oneapi::experimental::matrix::layout Layout,
typename Group>
class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
Group> {
joint_matrix<Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
Layout> &M;
sycl::ext::oneapi::experimental::matrix::joint_matrix<
Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols, Layout> &M;
std::size_t idx;

public:
wi_element(joint_matrix<Group, sycl::ext::oneapi::bfloat16, Use, NumRows,
NumCols, Layout> &Mat,
wi_element(sycl::ext::oneapi::experimental::matrix::joint_matrix<
Group, sycl::ext::oneapi::bfloat16, Use, NumRows, NumCols,
Layout> &Mat,
std::size_t i)
: M(Mat), idx(i) {}
operator sycl::ext::oneapi::bfloat16() {
Expand Down Expand Up @@ -290,11 +306,57 @@ class wi_element<sycl::ext::oneapi::bfloat16, NumRows, NumCols, Use, Layout,
#endif // __SYCL_DEVICE_ONLY__
};

} // namespace matrix
} // namespace experimental
} // namespace oneapi
// End wi_element definition

// Begin wi_data definition

/// Per-work-item view over a joint_matrix.
///
/// Holds a reference to the joint_matrix it was created from. Instances can
/// only be created through get_wi_data(), which is declared a friend below;
/// the constructor itself is private.
///
/// \tparam Group  group type the joint_matrix is declared with
/// \tparam T      element type of the joint_matrix
/// \tparam Use    matrix use of the joint_matrix
/// \tparam Rows   number of rows of the joint_matrix
/// \tparam Cols   number of columns of the joint_matrix
/// \tparam Layout layout of the joint_matrix
template <typename Group, typename T,
          sycl::ext::oneapi::experimental::matrix::use Use, size_t Rows,
          size_t Cols, sycl::ext::oneapi::experimental::matrix::layout Layout>
class wi_data {

  // Referenced (not copied) matrix; element proxies returned by operator[]
  // are constructed from this same reference.
  sycl::ext::oneapi::experimental::matrix::joint_matrix<Group, T, Use, Rows,
                                                        Cols, Layout> &jm;

  // Private on purpose: get_wi_data() is the only way to obtain a wi_data.
  wi_data(sycl::ext::oneapi::experimental::matrix::joint_matrix<
          Group, T, Use, Rows, Cols, Layout> &Mat)
      : jm(Mat) {}

  template <typename Grp, typename Type,
            sycl::ext::oneapi::experimental::matrix::use UseJm, size_t NumRows,
            size_t NumCols,
            sycl::ext::oneapi::experimental::matrix::layout LayoutJm>
  friend decltype(auto)
  get_wi_data(Grp, sycl::ext::oneapi::experimental::matrix::joint_matrix<
                       Grp, Type, UseJm, NumRows, NumCols, LayoutJm> &);

public:
  /// Returns the value reported by __spirv_JointMatrixWorkItemLengthINTEL for
  /// the wrapped matrix when compiled for device.
  ///
  /// \throws runtime_error (PI_ERROR_INVALID_DEVICE) when compiled for host.
  size_t length() {
    // Use defined() so the host pass, where the macro is not defined at all,
    // does not evaluate an undefined identifier in the #if expression.
#if defined(__SYCL_DEVICE_ONLY__)
    return __spirv_JointMatrixWorkItemLengthINTEL(jm.spvm);
#else
    throw runtime_error("joint matrix is not supported on host device.",
                        PI_ERROR_INVALID_DEVICE);
#endif
  }

  /// Returns a wi_element proxy bound to the wrapped matrix and index \p i.
  decltype(auto) operator[](size_t i) {
    return wi_element<T, Rows, Cols, Use, Layout, Group>(jm, i);
  }
};

// Factory for wi_data: wraps the given joint_matrix in a wi_data object.
// The wi_data template arguments are deduced from `jm` at the return
// statement.
//
// sg: group argument; not used by the body (explicitly discarded below).
// jm: matrix to wrap; passed by non-const reference to the wi_data
//     constructor.
template <typename Group, typename T,
sycl::ext::oneapi::experimental::matrix::use Use, size_t Rows,
size_t Cols, sycl::ext::oneapi::experimental::matrix::layout Layout>
inline __SYCL_ALWAYS_INLINE decltype(auto)
get_wi_data(Group sg, sycl::ext::oneapi::experimental::matrix::joint_matrix<
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this also should only work for NVPTX path

Group, T, Use, Rows, Cols, Layout> &jm) {
// Discard the group argument so it does not trigger an unused-parameter
// warning.
std::ignore = sg;
return wi_data(jm);
}

// End wi_data definition

namespace intel::experimental::matrix {
template <
typename Group, typename T,
sycl::ext::oneapi::experimental::matrix::use Use, size_t NumRows,
Expand Down
46 changes: 43 additions & 3 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-unified.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,19 @@ class wi_data {
#if defined(__NVPTX__)
return jm.cuda_impl.wi_marray.size();
#else
return __spirv_JointMatrixWorkItemLengthINTEL(jm.spvm);
throw runtime_error("get_wi_data is available using: "
"ext::intel::experimental::matrix::get_wi_data.",
PI_ERROR_INVALID_DEVICE);
#endif
};

decltype(auto) operator[](size_t i) {
#if defined(__NVPTX__)
return (jm.cuda_impl.wi_marray[i]);
#else
return wi_element<T, Rows, Cols, Use, Layout, Group>(jm, i);
throw runtime_error("get_wi_data is available using: "
"ext::intel::experimental::matrix::get_wi_data.",
PI_ERROR_INVALID_DEVICE);
#endif
};
};
Expand All @@ -94,8 +98,18 @@ template <typename type, size_t size> class wi_data {

template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
layout Layout>
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
__SYCL2020_DEPRECATED("get_wi_data() is deprecated for CUDA backend. Please "
"use joint_matrix_apply() instead.")
#else
__attribute__((unavailable(
"get_wi_data can't be used on intel device, please use "
"sycl::ext::intel::experimental::matrix::get_wi_data instead!")))
#endif
#endif
inline __SYCL_ALWAYS_INLINE decltype(auto)
get_wi_data(Group sg, joint_matrix<Group, T, Use, Rows, Cols, Layout> &jm) {
get_wi_data(Group sg, joint_matrix<Group, T, Use, Rows, Cols, Layout> &jm) {
#if defined(__SYCL_DEVICE_ONLY__)
std::ignore = sg;
return wi_data(jm);
Expand All @@ -112,6 +126,32 @@ get_wi_data(Group sg, joint_matrix<Group, T, Use, Rows, Cols, Layout> &jm) {
#endif // defined(__SYCL_DEVICE_ONLY__)
}

/// Invokes \p lambda on every element reachable through the work-item view of
/// \p jm, writing any modification made by the callable back into the matrix.
///
/// - NVPTX device path: \p lambda is called directly on each entry of
///   jm.cuda_impl.wi_marray.
/// - Other device paths: each element is read through
///   sycl::ext::intel::experimental::matrix::get_wi_data into a local copy of
///   type T, passed to \p lambda, and then stored back.
/// - Host: not supported.
///
/// \param sg     group object; only forwarded to get_wi_data on the non-NVPTX
///               device path.
/// \param jm     matrix whose elements are visited and updated in place.
/// \param lambda callable invoked once per element with an lvalue of the
///               element.
/// \throws runtime_error (PI_ERROR_INVALID_DEVICE) when compiled for host.
template <typename Group, typename T, use Use, size_t M, size_t N,
          layout Layout, typename F>
inline __SYCL_ALWAYS_INLINE void
joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jm,
                   F &&lambda) {
#if defined(__SYCL_DEVICE_ONLY__)
#if defined(__NVPTX__)
  std::ignore = sg;
  // size() is unsigned; index with size_t to avoid a signed/unsigned
  // comparison and read the bound once.
  const size_t num_elements = jm.cuda_impl.wi_marray.size();
  for (size_t i = 0; i < num_elements; i++) {
    lambda(jm.cuda_impl.wi_marray[i]);
  }
#else // NVPTX
  auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, jm);
  // length() returns size_t; query it once instead of on every iteration.
  const size_t num_elements = wi_data_c.length();
  for (size_t i = 0; i < num_elements; i++) {
    // operator[] yields a proxy object, so go through a local T: read,
    // let the callable mutate it, then store it back.
    T element = wi_data_c[i];
    lambda(element);
    wi_data_c[i] = element;
  }
#endif
#else
  // None of the parameters are used on host; discard them explicitly so the
  // host compilation does not warn about unused parameters.
  std::ignore = sg;
  std::ignore = jm;
  std::ignore = lambda;
  throw runtime_error("joint matrix is not supported on host device.",
                      PI_ERROR_INVALID_DEVICE);
#endif
}

template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
layout Layout, typename T2>
inline __SYCL_ALWAYS_INLINE void
Expand Down
3 changes: 2 additions & 1 deletion sycl/test/matrix/matrix-elemwise-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ void matrix_multiply(big_matrix<T1, NUM_ROWS_C, NUM_COLS_C> &C,
N * 4);
sub_c = joint_matrix_mad(sg, sub_a, sub_b, sub_c);
}
auto wi_data_c = get_wi_data(sg, sub_c);
auto wi_data_c =
sycl::ext::intel::experimental::matrix::get_wi_data(sg, sub_c);
for (int i = 0; i < wi_data_c.length(); i++) {
wi_data_c[i] *= 2;
}
Expand Down