@@ -352,6 +352,8 @@ DEFINE_LOAD(PackedA_RowMajor, _SG16, char, 8, short, 16, 2, 32, 2x32, ROW_MAJOR,

/* A load tf32 SG16 */
DEFINE_LOAD(PackedA_RowMajor, _SG16, int, 32, int, 32, 8, 8, 8x8, ROW_MAJOR, , 4, 8)
+/* A load tf32 SG16 for sub group size 32*/
+DEFINE_LOAD(PackedA_RowMajor, _SG16, int, 32, int, 32, 8, 8, 8x8, ROW_MAJOR, , 2, 8)

/* PackedB load i16 */
DEFINE_LOAD(PackedB_ColumnMajor, , short, 16, int, 32, 8, 16, 16x8, COL_MAJOR, , 8, -1)
@@ -574,6 +576,8 @@ DEFINE_STORE(PackedA_RowMajor, _SG16, short, 16, short, 16, 8, 16, 8x16, ROW_MAJ

/* A store tf32 SG16 */
DEFINE_STORE(PackedA_RowMajor, _SG16, int, 32, int, 32, 8, 8, 8x8, ROW_MAJOR, , 4, 8, false)
+/* A store tf32 SG16 for sub group size 32*/
+DEFINE_STORE(PackedA_RowMajor, _SG16, int, 32, int, 32, 8, 8, 8x8, ROW_MAJOR, , 2, 8, false)

/* PackedB store i16*/
DEFINE_STORE(PackedB_ColumnMajor, , short, 16, int, 32, 8, 16, 16x8, COL_MAJOR, , 8, -1, false)
@@ -601,6 +605,9 @@ DEFINE_STORE(PackedB_PackedB, _SG16, char, 8, int, 32, 8, 64, 32x16, ROW_MAJ
/* B store tf32 SG16 */
DEFINE_STORE(PackedB_RowMajor, _SG16, int, 32, int, 32, 8, 16, 8x16, ROW_MAJOR, , 8, 16, true)

+/* B store tf32 SG16 for sub group size 32 */
+DEFINE_STORE(PackedB_RowMajor, _SG16, int, 32, int, 32, 8, 16, 8x16, ROW_MAJOR, , 4, 16, true)
+
/* Acc i32 */
DEFINE_STORE(Accumulator_RowMajor, , int, 32, int, 32, 8, 8, 8x8, ROW_MAJOR, , 8, 8, true)
DEFINE_STORE(Accumulator_RowMajor, , int, 32, int, 32, 7, 8, 7x8, ROW_MAJOR, , 7, 8, true)
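
The new sub-group-size-32 entries above keep the same 8x8 and 8x16 tf32 shapes and only halve what appears to be the per-work-item element count passed to the macros: 64 scalars spread over 32 lanes instead of 16 gives 2 instead of 4 for the A fragment, and 128 scalars give 4 instead of 8 for the B fragment. A minimal standalone sketch of that arithmetic (the helper name elems_per_wi is hypothetical and not part of these builtins):

#include <assert.h>

/* Hypothetical helper: how many scalar elements of an R x C fragment
 * each work item owns when the fragment is spread across sg_size lanes. */
static unsigned elems_per_wi(unsigned rows, unsigned cols, unsigned sg_size) {
    return rows * cols / sg_size;
}

int main(void) {
    /* tf32 A fragment, 8x8: 4 elements per work item at SG16, 2 at SG32 */
    assert(elems_per_wi(8, 8, 16) == 4);
    assert(elems_per_wi(8, 8, 32) == 2);
    /* tf32 B fragment, 8x16: 8 elements per work item at SG16, 4 at SG32 */
    assert(elems_per_wi(8, 16, 16) == 8);
    assert(elems_per_wi(8, 16, 32) == 4);
    return 0;
}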
@@ -874,7 +881,7 @@ DEFINE_LOAD_LARGE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 16, 16, 16x16,
DEFINE_LOAD_LARGE(PackedA_RowMajor, _SG16, short, 16, short, 16, 16, 16, 16x16, ROW_MAJOR, , 16)

#define DEFINE_ACC_ROW_MAJOR_32x64(address_space) \
-  INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_32_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \
+  INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \
      __private char *c0 = dst + 0 * 16 * (sizeof (int)); \
      __private char *c1 = dst + 1 * 16 * (sizeof (int)); \
      __private char *c2 = dst + 2 * 16 * (sizeof (int)); \
@@ -924,7 +931,7 @@ DEFINE_A_ROW_MAJOR_32x16(global)
DEFINE_A_ROW_MAJOR_32x16(local)

#define DEFINE_B_B_16x64(address_space) \
-  INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x64_i16_8_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \
+  INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x64_i16_32_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \
      __private char *b0 = dst; \
      __private char *b1 = dst + 1 * 16 * (sizeof (short)); \
      __private char *b2 = dst + 2 * 16 * (sizeof (short)); \
@@ -973,7 +980,7 @@ DEFINE_B_B_16x64(local)
DEFINE_STORE_LARGE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 16, 16, 16x16, ROW_MAJOR, , 16)

#define DEFINE_STORE_ACC_ROW_MAJOR_32x64(address_space) \
-  INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_32x64_i32_32_##address_space##_pi64_v8i8(char *mem, __private char *src, long stride) { \
+  INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_##address_space##_pi64_v8i8(char *mem, __private char *src, long stride) { \
      __private char *c0 = src + 0 * 16 * (sizeof (int)); \
      __private char *c1 = src + 1 * 16 * (sizeof (int)); \
      __private char *c2 = src + 2 * 16 * (sizeof (int)); \
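
The renames in the last three hunks look like the same per-work-item accounting applied to the large shapes: a 32x64 i32 accumulator is 2048 dwords, i.e. 128 per work item in a sub-group of 16, which is presumably what the _i32_128_ suffix now encodes, and the 16x64 i16 packed-B fragment gives 32 if its i16 pairs are counted as packed 32-bit slots. A small sketch of that accounting, with a hypothetical helper name and under the stated packing assumption (an inference from the shapes, not taken from the IGC sources):

#include <assert.h>

/* Hypothetical accounting for the element-count field in the mangled builtin
 * names: total 32-bit contributions in the fragment divided by sub-group size. */
static unsigned name_elem_count(unsigned rows, unsigned cols32, unsigned sg_size) {
    return rows * cols32 / sg_size;
}

int main(void) {
    /* 32x64 i32 accumulator: 2048 dwords / 16 lanes = 128 -> _i32_128_ */
    assert(name_elem_count(32, 64, 16) == 128);
    /* 16x64 i16 packed B: two i16 per dword, so effectively 8x64 dwords,
     * 512 / 16 lanes = 32 -> _i16_32_ (packing is an assumption) */
    assert(name_elem_count(8, 64, 16) == 32);
    return 0;
}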