@@ -108,6 +108,9 @@ SPDX-License-Identifier: MIT
108
108
/* Single-element case: expands to arr[0]; the `type` argument is not used in the expansion. */
#define ARR_TO_VEC1 (type , arr ) \
109
109
arr[0]
110
110
111
+ typedef ushort __attribute__ ((ext_vector_type (32 ))) ushort32 ; // 32-element ushort vector; `ushort32` is the name that type##32 pastes to for type = ushort. NOTE(review): presumably declared by hand because no built-in 32-wide vector type exists - confirm.
112
+
113
// OUT_VEC<N>(type): pastes a width suffix onto the element type name (e.g. width 16 -> type##16).
// OUT_VEC7 pastes 8, not 7. NOTE(review): presumably because there is no 7-wide vector type - confirm.
+ #define OUT_VEC32 (type ) type##32
111
114
#define OUT_VEC16 (type ) type##16
112
115
#define OUT_VEC8 (type ) type##8
113
116
#define OUT_VEC7 (type ) type##8
@@ -118,6 +121,7 @@ SPDX-License-Identifier: MIT
118
121
// Narrowest OUT_VEC cases: width 2 pastes the `2` suffix; width 1 is the bare element type with no suffix.
#define OUT_VEC2 (type ) type##2
119
122
#define OUT_VEC1 (type ) type
120
123
124
// OUT_STORE_VEC<N>(type): type name used for the value returned by the DEFINE_BLOCK_RW_NAME##M read call.
// For the widths shown here it pastes the same suffix as OUT_VEC<N> (32, 16, 8, and 7 -> 8).
+ #define OUT_STORE_VEC32 (type ) type##32
121
125
#define OUT_STORE_VEC16 (type ) type##16
122
126
#define OUT_STORE_VEC8 (type ) type##8
123
127
#define OUT_STORE_VEC7 (type ) type##8
@@ -245,7 +249,8 @@ SPDX-License-Identifier: MIT
245
249
wi_contrib[i] = as_int((uchar4)(row0, row1, row2, row3)); \
246
250
}
247
251
248
- // variants for 7,6,5,3 and 1 are only used to make the code compilable
252
+ // Variants for 32, 16, 7, 6, 5, 3 and 1 exist only so that DEFINE_BLOCK_RW_NAME##M compiles for every M the load macro is instantiated with; the 7 variant reuses the ##8 name.
253
+ #define DEFINE_BLOCK_RW_NAME32 (rw , us ) intel_sub_group_block_##rw##us##32
249
254
#define DEFINE_BLOCK_RW_NAME16 (rw , us ) intel_sub_group_block_##rw##us##16
250
255
#define DEFINE_BLOCK_RW_NAME8 (rw , us ) intel_sub_group_block_##rw##us##8
251
256
#define DEFINE_BLOCK_RW_NAME7 (rw , us ) intel_sub_group_block_##rw##us##8
@@ -364,7 +369,8 @@ SPDX-License-Identifier: MIT
364
369
int sg_size = get_sub_group_size(); \
365
370
/* When M != WI_rows, only scenarios limited to i32 row major are supported */ \
366
371
bool is32bitElemHalfMRowMajor = elem_bitwidth == 32 && WI_rows == M / 2 && order == _ROW_MAJOR ; \
367
- if ((WI_rows == M || is32bitElemHalfMRowMajor ) && BIF_FLAG_CTRL_GET (JointMatrixLoadStoreOpt ) >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8 ) \
372
+ if ((WI_rows == M || is32bitElemHalfMRowMajor ) && BIF_FLAG_CTRL_GET (JointMatrixLoadStoreOpt ) >= BLOCK2D_IMPL \
373
+ && (M == 2 || M == 4 || M == 8 || M == 16 || M == 32 ) \
368
374
&& (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && contrib_bitwidth == 32 )) \
369
375
&& address_space == AS_GLOBAL \
370
376
) { \
@@ -374,6 +380,7 @@ SPDX-License-Identifier: MIT
374
380
&& stride == K && (M == 2 || M == 4 || M == 8 ) && order == _ROW_MAJOR \
375
381
&& (address_space == AS_GLOBAL || address_space == AS_LOCAL ) \
376
382
) { \
383
+ OUT_STORE_VEC ##M (u##contrib_type) OVERLOADABLE DEFINE_BLOCK_RW_NAME##M(read, us)(ATTRIBUTE_##address_space u##contrib_type *); \
377
384
OUT_STORE_VEC##M(u##contrib_type) res = DEFINE_BLOCK_RW_NAME##M(read, us)((ATTRIBUTE_##address_space u##contrib_type *)mem); \
378
385
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
379
386
return; \
@@ -979,41 +986,8 @@ INLINE void __builtin_spriv_OpJointMatrixMadINTEL_32x64x16_bf16_bf16_fp32(__priv
979
986
__builtin_spriv_OpJointMatrixMadINTEL_16x16x16_bf16_bf16_fp32 (a1 , b3 , c7 , d7 );
980
987
}
981
988
982
- #define DEFINE_LOAD_IMPL_LARGE (layout , sg , element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , shape , order , us , address_space ) \
983
- INLINE void MANGLE_LOAD_NAME_##address_space(layout, sg, elem_bitwidth, shape, M) (__private char *dst, char *mem, long stride) { \
984
- int sg_size = get_sub_group_size(); \
985
- if ( BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8 || M == 16) \
986
- && (order == _ROW_MAJOR || order == _VNNI_TX) && address_space == AS_GLOBAL \
987
- ) { \
988
- IMPLEMENT_BLOCK2D_LOAD(sg, order##_, element_type, contrib_type, M, K, M) \
989
- } \
990
- contrib_type *ptr = (contrib_type *)mem; \
991
- int slid = get_sub_group_local_id(); \
992
- int pack_factor = sizeof (contrib_type) / sizeof (element_type); \
993
- stride = stride / pack_factor; \
994
- int sg_cols = K / pack_factor; \
995
- int skip_factor = sg_size / sg_cols; \
996
- __private contrib_type *wi_contrib = (__private contrib_type *)dst; \
997
- for (int i = 0; i < M; i++) { \
998
- if ( (i*skip_factor + slid/sg_cols) < M ) \
999
- wi_contrib[i] = ptr[IND##order(slid, stride, skip_factor, i, sg_cols)]; \
1000
- else \
1001
- wi_contrib[i] = 0; /*last even row for matrix with odd number of rows doesn't exist*/ \
1002
- } \
1003
- }
1004
-
1005
- #define DEFINE_LOAD_LARGE__ (layout , sg , element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , shape , order , us ) \
1006
- DEFINE_LOAD_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, AS_GENERIC) \
1007
- DEFINE_LOAD_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, AS_LOCAL) \
1008
- DEFINE_LOAD_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, _##order, us, AS_GLOBAL)
1009
-
1010
- #define DEFINE_LOAD_LARGE (layout , sg , element_type , contrib_type , M , K , order , us ) \
1011
- DEFINE_LOAD_LARGE__(layout, sg, element_type, BITWIDTH(element_type), contrib_type, BITWIDTH(contrib_type), \
1012
- M, K, SHAPE(layout, M, K, element_type, contrib_type), \
1013
- order, us)
1014
-
1015
- DEFINE_LOAD_LARGE (Accumulator_RowMajor , _SG16 , int , int , 16 , 16 , ROW_MAJOR , )
1016
- DEFINE_LOAD_LARGE (PackedA_RowMajor , _SG16 , short , short , 16 , 16 , ROW_MAJOR , )
989
// Instantiate the load functions for Accumulator_RowMajor (int) and PackedA_RowMajor (short), 16 x 16, on _SG16.
// NOTE(review): the trailing 16 is presumably the WI_rows argument of DEFINE_LOAD - confirm against the macro's parameter list.
+ DEFINE_LOAD (Accumulator_RowMajor , _SG16 , int , int , 16 , 16 , ROW_MAJOR , , 16 )
990
+ DEFINE_LOAD (PackedA_RowMajor , _SG16 , short , short , 16 , 16 , ROW_MAJOR , , 16 )
1017
991
1018
992
#define DEFINE_ACC_ROW_MAJOR_32x64 (address_space ) \
1019
993
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_##address_space##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \
@@ -1049,21 +1023,7 @@ DEFINE_ACC_ROW_MAJOR_32x64(generic)
1049
1023
// Instantiate DEFINE_ACC_ROW_MAJOR_32x64 for the `global` and `local` address-space name suffixes.
DEFINE_ACC_ROW_MAJOR_32x64 (global )
1050
1024
DEFINE_ACC_ROW_MAJOR_32x64 (local )
1051
1025
1052
- #define DEFINE_A_ROW_MAJOR_32x16 (address_space ) \
1053
- INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_32x16_i16_32_ ##address_space ##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \
1054
- __private char *dst0 = dst; \
1055
- __private char *dst1 = dst + 16 * (sizeof (short)); \
1056
- \
1057
- char *mem0 = mem; \
1058
- char *mem1 = mem + 16 * (sizeof (short)) * stride; \
1059
- \
1060
- __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_16_##address_space##_v8i8_pi32_i32(dst0, mem0, stride); \
1061
- __builtin_spriv_OpJointMatrixLoadINTEL_PackedA_RowMajor_SG16_16x16_i16_16_##address_space##_v8i8_pi32_i32(dst1, mem1, stride); \
1062
- }
1063
-
1064
- DEFINE_A_ROW_MAJOR_32x16 (generic )
1065
- DEFINE_A_ROW_MAJOR_32x16 (global )
1066
- DEFINE_A_ROW_MAJOR_32x16 (local )
1026
+ DEFINE_LOAD (PackedA_RowMajor , _SG16 , short , short , 32 , 16 , ROW_MAJOR , , 32 ) // 32 x 16 PackedA_RowMajor (short) load on _SG16. NOTE(review): the trailing 32 is presumably WI_rows - confirm against DEFINE_LOAD's parameter list.
1067
1027
1068
1028
#define DEFINE_B_B_16x64 (address_space ) \
1069
1029
INLINE void __builtin_spriv_OpJointMatrixLoadINTEL_PackedB_PackedB_SG16_16x64_i16_32_ ##address_space ##_v8i8_pi32_i32(__private char *dst, char *mem, long stride) { \
0 commit comments