Skip to content

Commit 646db9c

Browse files
authored
[SYCL][Matrix] Fix checked matrix instructions (#13287)
There were incorrectly named and had incorrect operands. See #12497 --------- Signed-off-by: Sidorov, Dmitry <[email protected]>
1 parent 83adbdf commit 646db9c

File tree

2 files changed

+44
-50
lines changed

2 files changed

+44
-50
lines changed

sycl/include/CL/__spirv/spirv_ops.hpp

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,32 +52,32 @@ template <typename T, typename Tp, std::size_t R, std::size_t C,
5252
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
5353
extern __DPCPP_SYCL_EXTERNAL
5454
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
55-
__spirv_CooperativeMatrixConstructCheckedINTEL(
56-
const T Value, uint32_t Height, size_t Stride, uint32_t Width,
57-
int32_t CoordX, int32_t CoordY);
55+
__spirv_CooperativeMatrixConstructCheckedINTEL(int32_t CoordX,
56+
int32_t CoordY,
57+
uint32_t Height,
58+
uint32_t Width,
59+
const T Value);
5860

5961
template <typename T, typename Tp, std::size_t R, std::size_t C,
6062
__spv::MatrixUse U,
6163
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
6264
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
6365
extern __DPCPP_SYCL_EXTERNAL
6466
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *
65-
__spirv_JointMatrixLoadCheckedINTEL(T *Ptr, std::size_t Stride,
66-
uint32_t Height, uint32_t Width,
67-
int32_t CoordX, int32_t CoordY,
68-
__spv::MatrixLayout Layout = L,
69-
__spv::Scope::Flag Sc = S,
70-
int MemOperand = 0);
67+
__spirv_CooperativeMatrixLoadCheckedINTEL(
68+
T *Ptr, int32_t CoordX, int32_t CoordY, __spv::MatrixLayout Layout = L,
69+
uint32_t Height = 0, uint32_t Width = 0, std::size_t Stride = 0,
70+
int MemOperand = 0);
7171

7272
template <typename T, typename Tp, std::size_t R, std::size_t C,
7373
__spv::MatrixUse U,
7474
__spv::MatrixLayout L = __spv::MatrixLayout::RowMajor,
7575
__spv::Scope::Flag S = __spv::Scope::Flag::Subgroup>
76-
extern __DPCPP_SYCL_EXTERNAL void __spirv_JointMatrixStoreCheckedINTEL(
77-
T *Ptr, __spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
78-
std::size_t Stride, uint32_t Height, uint32_t Width, int32_t CoordX,
79-
int32_t CoordY, __spv::MatrixLayout Layout = L, __spv::Scope::Flag Sc = S,
80-
int MemOperand = 0);
76+
extern __DPCPP_SYCL_EXTERNAL void __spirv_CooperativeMatrixStoreCheckedINTEL(
77+
T *Ptr, int32_t CoordX, int32_t CoordY,
78+
__spv::__spirv_JointMatrixINTEL<Tp, R, C, L, S, U> *Object,
79+
__spv::MatrixLayout Layout = L, uint32_t Height = 0, uint32_t Width = 0,
80+
std::size_t Stride = 0, int MemOperand = 0);
8181

8282
template <typename TA, typename TB, typename TC, std::size_t M, std::size_t K,
8383
std::size_t N, __spv::MatrixUse UA, __spv::MatrixUse UB,

sycl/include/sycl/ext/oneapi/matrix/matrix-intel.hpp

Lines changed: 30 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
612612
layout Layout, typename T2>
613613
inline __SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(
614614
Group, joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Res,
615-
const T2 &Value, size_t Stride, size_t Height, size_t Width, size_t CoordX,
615+
const T2 &Value, size_t Height, size_t Width, size_t CoordX,
616616
size_t CoordY) {
617617
#if defined(__SYCL_DEVICE_ONLY__)
618618
using storage_element_type =
@@ -622,12 +622,10 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(
622622
storage_element_type, T, NumRows, NumCols,
623623
spv_matrix_use_traits<Use>::value,
624624
spv_matrix_layout_traits<Layout>::value>(
625-
static_cast<storage_element_type>(Value), Stride, Height, Width, CoordX,
626-
CoordY);
625+
CoordX, CoordY, Height, Width, static_cast<storage_element_type>(Value));
627626
#else
628627
std::ignore = Res;
629628
std::ignore = Value;
630-
std::ignore = Stride;
631629
std::ignore = Height;
632630
std::ignore = Width;
633631
std::ignore = CoordX;
@@ -654,13 +652,12 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
654652
std::ignore = sg;
655653
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
656654
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
657-
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
655+
Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
658656
DecorT, S, NumRows, NumCols,
659657
spv_matrix_use_traits<use::accumulator>::value,
660658
spv_matrix_layout_traits<layout::dynamic>::value>(
661-
Ptr, Stride, Height, Width, CoordX, CoordY,
662-
sycl::detail::joint_matrix_layout_to_spv(Layout),
663-
spv_scope_traits<Group>::value);
659+
Ptr, CoordX, CoordY, sycl::detail::joint_matrix_layout_to_spv(Layout),
660+
Height, Width, Stride);
664661
#else
665662
std::ignore = sg;
666663
std::ignore = Res;
@@ -694,11 +691,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
694691
std::ignore = sg;
695692
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
696693
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
697-
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
694+
Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
698695
DecorT, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
699696
spv_matrix_layout_traits<Layout>::value>(
700-
Ptr, Stride, Height, Width, CoordX, CoordY,
701-
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
697+
Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
698+
Width, Stride);
702699
#else
703700
std::ignore = sg;
704701
std::ignore = Res;
@@ -727,13 +724,12 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
727724
std::ignore = sg;
728725
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
729726
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
730-
__spirv_JointMatrixStoreCheckedINTEL<
727+
__spirv_CooperativeMatrixStoreCheckedINTEL<
731728
DecorT, T, NumRows, NumCols,
732729
spv_matrix_use_traits<use::accumulator>::value,
733730
spv_matrix_layout_traits<layout::dynamic>::value>(
734-
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
735-
sycl::detail::joint_matrix_layout_to_spv(Layout),
736-
spv_scope_traits<Group>::value);
731+
Ptr, CoordX, CoordY, Src.spvm,
732+
sycl::detail::joint_matrix_layout_to_spv(Layout), Height, Width, Stride);
737733
#else
738734
std::ignore = sg;
739735
std::ignore = Src;
@@ -763,11 +759,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
763759
std::ignore = sg;
764760
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
765761
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
766-
__spirv_JointMatrixStoreCheckedINTEL<DecorT, Tp, NumRows, NumCols,
767-
spv_matrix_use_traits<Use>::value,
768-
spv_matrix_layout_traits<Layout>::value>(
769-
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
770-
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
762+
__spirv_CooperativeMatrixStoreCheckedINTEL<
763+
DecorT, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
764+
spv_matrix_layout_traits<Layout>::value>(
765+
Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
766+
Height, Width, Stride);
771767
#else
772768
std::ignore = sg;
773769
std::ignore = Src;
@@ -797,12 +793,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
797793
#if defined(__SYCL_DEVICE_ONLY__)
798794
std::ignore = sg;
799795
T *Ptr = Src.get();
800-
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
796+
Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
801797
T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
802798
spv_matrix_layout_traits<layout::dynamic>::value>(
803-
Ptr, Stride, Height, Width, CoordX, CoordY,
804-
sycl::detail::joint_matrix_layout_to_spv(Layout),
805-
spv_scope_traits<Group>::value);
799+
Ptr, CoordX, CoordY, sycl::detail::joint_matrix_layout_to_spv(Layout),
800+
Height, Width, Stride);
806801
#else
807802
std::ignore = sg;
808803
std::ignore = Res;
@@ -832,11 +827,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
832827
#if defined(__SYCL_DEVICE_ONLY__)
833828
std::ignore = sg;
834829
T *Ptr = Src.get();
835-
Res.spvm = __spirv_JointMatrixLoadCheckedINTEL<
830+
Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL<
836831
T, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
837832
spv_matrix_layout_traits<Layout>::value>(
838-
Ptr, Stride, Height, Width, CoordX, CoordY,
839-
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
833+
Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height,
834+
Width, Stride);
840835
#else
841836
std::ignore = sg;
842837
std::ignore = Res;
@@ -863,12 +858,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
863858
#if defined(__SYCL_DEVICE_ONLY__)
864859
std::ignore = sg;
865860
T *Ptr = Dst.get();
866-
__spirv_JointMatrixStoreCheckedINTEL<
861+
__spirv_CooperativeMatrixStoreCheckedINTEL<
867862
T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
868863
spv_matrix_layout_traits<layout::dynamic>::value>(
869-
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
870-
sycl::detail::joint_matrix_layout_to_spv(Layout),
871-
spv_scope_traits<Group>::value);
864+
Ptr, CoordX, CoordY, Src.spvm,
865+
sycl::detail::joint_matrix_layout_to_spv(Layout), Height, Width, Stride);
872866
#else
873867
std::ignore = sg;
874868
std::ignore = Src;
@@ -894,11 +888,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
894888
#if defined(__SYCL_DEVICE_ONLY__)
895889
std::ignore = sg;
896890
T *Ptr = Dst.get();
897-
__spirv_JointMatrixStoreCheckedINTEL<T, Tp, NumRows, NumCols,
898-
spv_matrix_use_traits<Use>::value,
899-
spv_matrix_layout_traits<Layout>::value>(
900-
Ptr, Src.spvm, Stride, Height, Width, CoordX, CoordY,
901-
spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value);
891+
__spirv_CooperativeMatrixStoreCheckedINTEL<
892+
T, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
893+
spv_matrix_layout_traits<Layout>::value>(
894+
Ptr, CoordX, CoordY, Src.spvm, spv_matrix_layout_traits<Layout>::value,
895+
Height, Width, Stride);
902896
#else
903897
std::ignore = sg;
904898
std::ignore = Src;

0 commit comments

Comments
 (0)