-
Notifications
You must be signed in to change notification settings - Fork 787
[SYCL][Matrix] syntax changes as preparation before moving joint matrix from experimental namespace #11215
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SYCL][Matrix] syntax changes as preparation before moving joint matrix from experimental namespace #11215
Changes from all commits
b68aead
5fbb285
bf6cd56
b399041
dae1ec6
4ec8360
a461cbb
5ff715b
8ad7da9
26ea49d
a09a778
821fa89
a3921b5
ef1bc67
f395199
c71fee6
8f2f197
1411376
95df3b1
fb1afdc
11df531
a29e8f3
a821107
3f1b575
1d091de
1e20968
1fe7fcd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -40,7 +40,8 @@ struct joint_matrix { | |
|
||
#if defined(__SYCL_DEVICE_ONLY__) | ||
#if defined(__NVPTX__) | ||
sycl::ext::oneapi::detail::joint_matrix_cuda<T, Use, Rows, Cols, Layout> | ||
mutable sycl::ext::oneapi::detail::joint_matrix_cuda<T, Use, Rows, Cols, | ||
Layout> | ||
cuda_impl; | ||
#elif defined(__SPIR__) | ||
__spv::__spirv_JointMatrixINTEL< | ||
|
@@ -61,19 +62,8 @@ struct joint_matrix { | |
} | ||
#ifdef __SYCL_DEVICE_ONLY__ | ||
#if defined(__SPIR__) | ||
// Generate a non-trivial assignment operator and copy c'tor that prevents | ||
// memcpy from being generated. | ||
// TODO: to remove, when either IGC can handle alloca JointMatrix or | ||
// combination of InstCombine + SROA + mem2reg can remove it | ||
joint_matrix(const joint_matrix &other) { | ||
spvm = other.spvm; | ||
return *this; | ||
} | ||
|
||
joint_matrix &operator=(const joint_matrix &rhs) { | ||
spvm = rhs.spvm; | ||
return *this; | ||
} | ||
joint_matrix(const joint_matrix &other) = delete; | ||
joint_matrix &operator=(const joint_matrix &rhs) = delete; | ||
#endif // defined(__SPIR__) | ||
#endif | ||
}; | ||
|
@@ -97,10 +87,6 @@ class wi_data { | |
size_t length() { | ||
#if defined(__NVPTX__) | ||
return jm.cuda_impl.wi_marray.size(); | ||
#else | ||
throw runtime_error("get_wi_data is available using: " | ||
"ext::intel::experimental::matrix::get_wi_data.", | ||
PI_ERROR_INVALID_DEVICE); | ||
#endif | ||
}; | ||
|
||
|
@@ -109,9 +95,6 @@ class wi_data { | |
return (jm.cuda_impl.wi_marray[i]); | ||
#else | ||
std::ignore = i; | ||
throw runtime_error("get_wi_data is available using: " | ||
"ext::intel::experimental::matrix::get_wi_data.", | ||
PI_ERROR_INVALID_DEVICE); | ||
#endif | ||
}; | ||
}; | ||
YuriPlyakhin marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
@@ -139,9 +122,8 @@ template <typename Group, typename T, use Use, size_t Rows, size_t Cols, | |
__SYCL2020_DEPRECATED("get_wi_data() is deprecated for CUDA backend. Please " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should remove this. This is not really deprecated as joint_matrix is experimental so we can just remove APIs. Deprecated means they still exist and implementations maintain them. In the case of get_wi_data. it is replaced by joint_matrix_apply There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Was this addressed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This will be addressed by @JackAKirk among other CUDA changes in a separate PR. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes I will make this change as soon as this PR is merged. |
||
"use joint_matrix_apply() instead.") | ||
#else | ||
__attribute__((unavailable( | ||
"get_wi_data can't be used on intel device, please use " | ||
"sycl::ext::intel::experimental::matrix::get_wi_data instead!"))) | ||
__attribute__((unavailable("get_wi_data() has been removed from the API and " | ||
"replaced with joint_matrix_apply!"))) | ||
#endif | ||
dkhaldi marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#endif | ||
inline __SYCL_ALWAYS_INLINE decltype(auto) | ||
|
@@ -177,7 +159,7 @@ joint_matrix_apply(Group sg, joint_matrix<Group, T, Use, M, N, Layout> &jm, | |
using storage_element_type = | ||
typename oneapi::detail::jm_type_interpretation_helper_trait< | ||
T>::storage_element_type; | ||
auto wi_data_c = sycl::ext::intel::experimental::matrix::get_wi_data(sg, jm); | ||
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, jm); | ||
for (int i = 0; i < wi_data_c.length(); i++) { | ||
storage_element_type element = wi_data_c[i]; | ||
lambda(element); | ||
|
@@ -260,7 +242,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load( | |
Ptr, stride, __spv::MatrixLayout::ColumnMajor, | ||
spv_scope_traits<Group>::value); | ||
break; | ||
case sycl::ext::intel::experimental::matrix::layout::packed: | ||
case layout::ext_intel_packed: | ||
res.spvm = __spirv_JointMatrixLoadINTEL< | ||
DecorT, S, NumRows, NumCols, | ||
spv_matrix_use_traits<use::accumulator>::value, | ||
|
@@ -322,8 +304,9 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols, | |
access::address_space Space, access::decorated IsDecorated> | ||
inline __SYCL_ALWAYS_INLINE void joint_matrix_store( | ||
Group, | ||
joint_matrix<Group, T, use::accumulator, NumRows, NumCols, | ||
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &src, | ||
const joint_matrix<Group, T, use::accumulator, NumRows, NumCols, | ||
sycl::ext::oneapi::experimental::matrix::layout::dynamic> | ||
&src, | ||
multi_ptr<T, Space, IsDecorated> dst, size_t stride, | ||
sycl::ext::oneapi::experimental::matrix::layout Layout) { | ||
#if defined(__SYCL_DEVICE_ONLY__) | ||
|
@@ -355,7 +338,7 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( | |
Ptr, src.spvm, stride, __spv::MatrixLayout::ColumnMajor, | ||
spv_scope_traits<Group>::value); | ||
break; | ||
case sycl::ext::intel::experimental::matrix::layout::packed: | ||
case layout::ext_intel_packed: | ||
__spirv_JointMatrixStoreINTEL< | ||
DecorT, T, NumRows, NumCols, | ||
spv_matrix_use_traits<use::accumulator>::value, | ||
|
@@ -375,51 +358,77 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store( | |
#endif // defined(__SYCL_DEVICE_ONLY__) | ||
} | ||
|
||
template <typename Group, typename Ta, typename Tb, typename Tc, std::size_t M, | ||
std::size_t K, std::size_t N, layout LayoutA, layout LayoutB> | ||
inline __SYCL_ALWAYS_INLINE | ||
joint_matrix<Group, Tc, use::accumulator, M, N, | ||
sycl::ext::oneapi::experimental::matrix::layout::dynamic> | ||
joint_matrix_mad( | ||
Group, joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A, | ||
joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B, | ||
joint_matrix<Group, Tc, use::accumulator, M, N, | ||
sycl::ext::oneapi::experimental::matrix::layout::dynamic> | ||
&C) { | ||
template <typename Group, typename Ta, typename Tb, typename Tc, typename Td, | ||
std::size_t M, std::size_t K, std::size_t N, layout LayoutA, | ||
layout LayoutB> | ||
inline __SYCL_ALWAYS_INLINE void joint_matrix_mad( | ||
Group, | ||
joint_matrix<Group, Td, use::accumulator, M, N, | ||
sycl::ext::oneapi::experimental::matrix::layout::dynamic> &D, | ||
const joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A, | ||
const joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B, | ||
const joint_matrix<Group, Tc, use::accumulator, M, N, | ||
sycl::ext::oneapi::experimental::matrix::layout::dynamic> | ||
&C) { | ||
#if defined(__SYCL_DEVICE_ONLY__) | ||
#if defined(__NVPTX__) | ||
if constexpr (std::is_same<Ta, Tb>::value) { | ||
joint_matrix<Group, Tc, use::accumulator, M, N, | ||
sycl::ext::oneapi::experimental::matrix::layout::dynamic> | ||
D; | ||
sycl::ext::oneapi::detail::joint_matrix_mad_cuda<Ta, Tc, M, K, N, LayoutA, | ||
LayoutB>( | ||
D.cuda_impl, A.cuda_impl, B.cuda_impl, C.cuda_impl); | ||
return D; | ||
} else { | ||
assert(false && "Ta != Tb : In the CUDA backend joint_matrix_mad " | ||
"requires that joint_matrix data types Ta and Tb match"); | ||
} | ||
#else | ||
joint_matrix<Group, Tc, use::accumulator, M, N, layout::dynamic> res; | ||
if constexpr (std::is_same<Ta, uint16_t>::value && | ||
std::is_same<Tb, uint16_t>::value && | ||
std::is_same<Tc, float>::value) | ||
res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); | ||
D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); | ||
else if constexpr (std::is_unsigned<Ta>::value && std::is_unsigned<Tb>::value) | ||
res.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm); | ||
D.spvm = __spirv_JointMatrixUUMadINTEL(A.spvm, B.spvm, C.spvm); | ||
else if constexpr (std::is_signed<Ta>::value && std::is_unsigned<Tb>::value) | ||
res.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm); | ||
D.spvm = __spirv_JointMatrixSUMadINTEL(A.spvm, B.spvm, C.spvm); | ||
else if constexpr (std::is_unsigned<Ta>::value && std::is_signed<Tb>::value) | ||
res.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm); | ||
D.spvm = __spirv_JointMatrixUSMadINTEL(A.spvm, B.spvm, C.spvm); | ||
else | ||
res.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); | ||
return res; | ||
D.spvm = __spirv_JointMatrixMadINTEL(A.spvm, B.spvm, C.spvm); | ||
#endif // defined(__NVPTX__) | ||
#else | ||
std::ignore = A; | ||
std::ignore = B; | ||
std::ignore = C; | ||
std::ignore = D; | ||
throw runtime_error("joint matrix is not supported on host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
#endif // defined(__SYCL_DEVICE_ONLY__) | ||
} | ||
|
||
template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols, | ||
use Use1, use Use2, layout Layout1, layout Layout2> | ||
void joint_matrix_copy( | ||
Group sg, joint_matrix<Group, T1, Use1, Rows, Cols, Layout1> &src, | ||
joint_matrix<Group, T2, Use2, Rows, Cols, Layout2> &dst) { | ||
#if defined(__SYCL_DEVICE_ONLY__) | ||
#if defined(__NVPTX__) | ||
std::ignore = sg; | ||
for (int i = 0; i < src.cuda_impl.wi_marray.size(); i++) { | ||
dst.cuda_impl.wi_marray[i] = src.cuda_impl.wi_marray[i]; | ||
} | ||
#else | ||
using storage_element_type = | ||
typename oneapi::detail::jm_type_interpretation_helper_trait< | ||
T2>::storage_element_type; | ||
auto wi_data_c = sycl::ext::oneapi::detail::get_wi_data(sg, src); | ||
auto wi_data_dst = sycl::ext::oneapi::detail::get_wi_data(sg, dst); | ||
for (int i = 0; i < wi_data_c.length(); i++) { | ||
wi_data_dst[i] = static_cast<storage_element_type>(wi_data_c[i]); | ||
} | ||
#endif // defined(__NVPTX__) | ||
#else | ||
std::ignore = sg; | ||
std::ignore = dst; | ||
std::ignore = src; | ||
throw runtime_error("joint matrix is not supported on host device.", | ||
PI_ERROR_INVALID_DEVICE); | ||
#endif // defined(__SYCL_DEVICE_ONLY__) | ||
|
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO: we should modify the CUDA code, since we added the "const" qualifier to jm_store's jm. I added mutable here as a temporary fix.
The mutable keyword is used to indicate that a particular member variable of a class is allowed to be modified even when the class object is considered const.
@JackAKirk, will you take care of the CUDA changes, since the "const" qualifier is added to jm_store's jm?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this addition of
mutable
really required? The CUDA implementation of joint_matrix_store
does not modify the joint_matrix
argument. Can you give some more details on why you think this change is required? Thanks.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, since passing a const variable to a non-const parameter is incorrect, we can't pass the const joint_matrix's member joint_matrix_cuda (it is also const) to joint_matrix_store_cuda: https://godbolt.org/z/5jz8z94Ga
If I add "const" to joint_matrix_store_cuda's src parameter, there are other issues too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand but I will check out this branch and see what errors there are.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK I understand the issue now. The fix is simple but requires quite a lot of repetitive code changes. I think it is fine to do as you suggest and use
mutable
temporarily. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can make a follow up PR to fix it correctly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On HIP it works without the need for mutable.
42e0c62
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CUDA requires reinterpret casts between signed/unsigned builtins for load/store; this is the origin of the problem. I don't think it affects AMD, since AMD doesn't have builtins for load/store.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There might be some similarities. Initially it worked only with mutable; a number of changes were needed to make it work without mutable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh yeah sure. I will make follow up changes for cuda after this is merged too.