@@ -612,7 +612,7 @@ template <typename Group, typename T, size_t NumRows, size_t NumCols, use Use,
612
612
layout Layout, typename T2>
613
613
inline __SYCL_ALWAYS_INLINE void joint_matrix_fill_checked (
614
614
Group, joint_matrix<Group, T, Use, NumRows, NumCols, Layout> &Res,
615
- const T2 &Value, size_t Stride, size_t Height, size_t Width, size_t CoordX,
615
+ const T2 &Value, size_t Height, size_t Width, size_t CoordX,
616
616
size_t CoordY) {
617
617
#if defined(__SYCL_DEVICE_ONLY__)
618
618
using storage_element_type =
@@ -622,12 +622,10 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_fill_checked(
622
622
storage_element_type, T, NumRows, NumCols,
623
623
spv_matrix_use_traits<Use>::value,
624
624
spv_matrix_layout_traits<Layout>::value>(
625
- static_cast <storage_element_type>(Value), Stride, Height, Width, CoordX,
626
- CoordY);
625
+ CoordX, CoordY, Height, Width, static_cast <storage_element_type>(Value));
627
626
#else
628
627
std::ignore = Res;
629
628
std::ignore = Value;
630
- std::ignore = Stride;
631
629
std::ignore = Height;
632
630
std::ignore = Width;
633
631
std::ignore = CoordX;
@@ -654,13 +652,12 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
654
652
std::ignore = sg;
655
653
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
656
654
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
657
- Res.spvm = __spirv_JointMatrixLoadCheckedINTEL <
655
+ Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL <
658
656
DecorT, S, NumRows, NumCols,
659
657
spv_matrix_use_traits<use::accumulator>::value,
660
658
spv_matrix_layout_traits<layout::dynamic>::value>(
661
- Ptr, Stride, Height, Width, CoordX, CoordY,
662
- sycl::detail::joint_matrix_layout_to_spv (Layout),
663
- spv_scope_traits<Group>::value);
659
+ Ptr, CoordX, CoordY, sycl::detail::joint_matrix_layout_to_spv (Layout),
660
+ Height, Width, Stride);
664
661
#else
665
662
std::ignore = sg;
666
663
std::ignore = Res;
@@ -694,11 +691,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
694
691
std::ignore = sg;
695
692
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
696
693
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Src);
697
- Res.spvm = __spirv_JointMatrixLoadCheckedINTEL <
694
+ Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL <
698
695
DecorT, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
699
696
spv_matrix_layout_traits<Layout>::value>(
700
- Ptr, Stride, Height, Width, CoordX, CoordY ,
701
- spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value );
697
+ Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height ,
698
+ Width, Stride );
702
699
#else
703
700
std::ignore = sg;
704
701
std::ignore = Res;
@@ -727,13 +724,12 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
727
724
std::ignore = sg;
728
725
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
729
726
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
730
- __spirv_JointMatrixStoreCheckedINTEL <
727
+ __spirv_CooperativeMatrixStoreCheckedINTEL <
731
728
DecorT, T, NumRows, NumCols,
732
729
spv_matrix_use_traits<use::accumulator>::value,
733
730
spv_matrix_layout_traits<layout::dynamic>::value>(
734
- Ptr, Src.spvm , Stride, Height, Width, CoordX, CoordY,
735
- sycl::detail::joint_matrix_layout_to_spv (Layout),
736
- spv_scope_traits<Group>::value);
731
+ Ptr, CoordX, CoordY, Src.spvm ,
732
+ sycl::detail::joint_matrix_layout_to_spv (Layout), Height, Width, Stride);
737
733
#else
738
734
std::ignore = sg;
739
735
std::ignore = Src;
@@ -763,11 +759,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
763
759
std::ignore = sg;
764
760
using DecorT = typename sycl::detail::DecoratedType<T, Space>::type;
765
761
DecorT *Ptr = sycl::detail::getDecorated<DecorT>(Dst);
766
- __spirv_JointMatrixStoreCheckedINTEL<DecorT, Tp, NumRows, NumCols,
767
- spv_matrix_use_traits<Use>::value,
768
- spv_matrix_layout_traits<Layout>::value>(
769
- Ptr, Src. spvm , Stride, Height, Width, CoordX, CoordY ,
770
- spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value );
762
+ __spirv_CooperativeMatrixStoreCheckedINTEL<
763
+ DecorT, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
764
+ spv_matrix_layout_traits<Layout>::value>(
765
+ Ptr, CoordX, CoordY, Src. spvm , spv_matrix_layout_traits<Layout>::value ,
766
+ Height, Width, Stride );
771
767
#else
772
768
std::ignore = sg;
773
769
std::ignore = Src;
@@ -797,12 +793,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
797
793
#if defined(__SYCL_DEVICE_ONLY__)
798
794
std::ignore = sg;
799
795
T *Ptr = Src.get ();
800
- Res.spvm = __spirv_JointMatrixLoadCheckedINTEL <
796
+ Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL <
801
797
T, S, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
802
798
spv_matrix_layout_traits<layout::dynamic>::value>(
803
- Ptr, Stride, Height, Width, CoordX, CoordY,
804
- sycl::detail::joint_matrix_layout_to_spv (Layout),
805
- spv_scope_traits<Group>::value);
799
+ Ptr, CoordX, CoordY, sycl::detail::joint_matrix_layout_to_spv (Layout),
800
+ Height, Width, Stride);
806
801
#else
807
802
std::ignore = sg;
808
803
std::ignore = Res;
@@ -832,11 +827,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_load_checked(
832
827
#if defined(__SYCL_DEVICE_ONLY__)
833
828
std::ignore = sg;
834
829
T *Ptr = Src.get ();
835
- Res.spvm = __spirv_JointMatrixLoadCheckedINTEL <
830
+ Res.spvm = __spirv_CooperativeMatrixLoadCheckedINTEL <
836
831
T, S, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
837
832
spv_matrix_layout_traits<Layout>::value>(
838
- Ptr, Stride, Height, Width, CoordX, CoordY ,
839
- spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value );
833
+ Ptr, CoordX, CoordY, spv_matrix_layout_traits<Layout>::value, Height ,
834
+ Width, Stride );
840
835
#else
841
836
std::ignore = sg;
842
837
std::ignore = Res;
@@ -863,12 +858,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
863
858
#if defined(__SYCL_DEVICE_ONLY__)
864
859
std::ignore = sg;
865
860
T *Ptr = Dst.get ();
866
- __spirv_JointMatrixStoreCheckedINTEL <
861
+ __spirv_CooperativeMatrixStoreCheckedINTEL <
867
862
T, T, NumRows, NumCols, spv_matrix_use_traits<use::accumulator>::value,
868
863
spv_matrix_layout_traits<layout::dynamic>::value>(
869
- Ptr, Src.spvm , Stride, Height, Width, CoordX, CoordY,
870
- sycl::detail::joint_matrix_layout_to_spv (Layout),
871
- spv_scope_traits<Group>::value);
864
+ Ptr, CoordX, CoordY, Src.spvm ,
865
+ sycl::detail::joint_matrix_layout_to_spv (Layout), Height, Width, Stride);
872
866
#else
873
867
std::ignore = sg;
874
868
std::ignore = Src;
@@ -894,11 +888,11 @@ inline __SYCL_ALWAYS_INLINE void joint_matrix_store_checked(
894
888
#if defined(__SYCL_DEVICE_ONLY__)
895
889
std::ignore = sg;
896
890
T *Ptr = Dst.get ();
897
- __spirv_JointMatrixStoreCheckedINTEL<T, Tp, NumRows, NumCols,
898
- spv_matrix_use_traits<Use>::value,
899
- spv_matrix_layout_traits<Layout>::value>(
900
- Ptr, Src. spvm , Stride, Height, Width, CoordX, CoordY ,
901
- spv_matrix_layout_traits<Layout>::value, spv_scope_traits<Group>::value );
891
+ __spirv_CooperativeMatrixStoreCheckedINTEL<
892
+ T, Tp, NumRows, NumCols, spv_matrix_use_traits<Use>::value,
893
+ spv_matrix_layout_traits<Layout>::value>(
894
+ Ptr, CoordX, CoordY, Src. spvm , spv_matrix_layout_traits<Layout>::value ,
895
+ Height, Width, Stride );
902
896
#else
903
897
std::ignore = sg;
904
898
std::ignore = Src;
0 commit comments