@@ -172,7 +172,7 @@ SPDX-License-Identifier: MIT
     /* not supported, fallthrough */
 #define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, stride_opt) \
     /* not supported, fallthrough */
-#define IMPLEMENT_BLOCK2D_STORE(element_type, contrib_type, contrib_bitwidth, M, K, vec) \
+#define IMPLEMENT_BLOCK2D_STORE(element_type, contrib_type, contrib_bitwidth, M, K) \
     /* not supported, fallthrough */
 
 #define IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, stride_opt) \
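Note: this variant of IMPLEMENT_BLOCK2D_STORE expands to a comment only, so its trailing vec parameter carried no value; dropping it keeps the signature in sync with the reworked SG16 variant and the call sites below. A minimal sketch of the fallthrough-macro pattern, with illustrative names (STORE_OPTIMIZED and store_row_major are not from the IGC source):

/* Illustrative sketch: an optimization macro either expands to a full
 * implementation that ends in `return;`, or to a bare comment, in which
 * case control falls through to the generic path after it. */
#define STORE_OPTIMIZED(element_type, M) /* not supported, fallthrough */

void store_row_major(int *mem, const int *src, long stride) {
    STORE_OPTIMIZED(int, 8) /* expands to nothing executable */
    /* generic scalar path runs when the macro falls through */
    for (int i = 0; i < 8; ++i)
        mem[i * stride] = src[i];
}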
@@ -216,7 +216,7 @@ SPDX-License-Identifier: MIT
         *(__private OUT_VEC##M(u##contrib_type) *)dst = res; \
         return;
 
-#define IMPLEMENT_BLOCK2D_STORE_SG16(element_type, contrib_type, contrib_bitwidth, M, K, vec) \
+#define IMPLEMENT_BLOCK2D_STORE_SG16(element_type, contrib_type, contrib_bitwidth, M, K) \
     long offset = as_long(mem); \
     long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
     int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
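Note on the address setup above: 2D block messages want a 64-byte-aligned base, so the macro masks off the low six offset bits and encodes the row pitch as bytes minus one. A minimal sketch of that computation under assumed names (block2d_desc_t and make_desc are illustrative, not the IGC intrinsic interface):

#include <stdint.h>

typedef struct {
    long base;  /* 64-byte-aligned base address */
    int  width; /* row width in bytes, minus one */
    int  x_off; /* x offset in elements, re-adding the masked-off bytes */
} block2d_desc_t;

static block2d_desc_t make_desc(const void *mem, long stride, int elem_size) {
    long offset = (long)(uintptr_t)mem;
    block2d_desc_t d;
    d.base  = offset & ~0x3fL;                  /* align base down to 64 B */
    d.width = elem_size * (int)stride - 1;      /* hardware wants size-1   */
    d.x_off = (int)(offset & 0x3f) / elem_size; /* recover dropped bytes   */
    return d;
}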
@@ -470,7 +470,7 @@ DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 2, 16, 2x16, COL_M
     if (WI_rows == M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) \
         && order == _ROW_MAJOR && address_space == AS_GLOBAL && elem_bitwidth > 8 \
         ) { \
-        IMPLEMENT_BLOCK2D_STORE##sg(element_type, contrib_type, contrib_bitwidth, M, K, src) \
+        IMPLEMENT_BLOCK2D_STORE##sg(element_type, contrib_type, contrib_bitwidth, M, K) \
     } \
     if (WI_rows == M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL && stride == stride_opt \
         && (M == 2 || M == 4 || M == 8) && order == _ROW_MAJOR \
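Here the call site drops the trailing src argument to match the new macro signatures; ##sg token pasting picks the subgroup variant. A small sketch of that dispatch with illustrative macro names:

/* Pasting the sg suffix onto the macro name selects the plain or
 * the _SG16 implementation at preprocessing time. */
#define IMPL_STORE(T)      /* generic variant */
#define IMPL_STORE_SG16(T) /* SG16 variant */
#define DISPATCH_STORE(sg, T) IMPL_STORE##sg(T)

/* DISPATCH_STORE(, int)      -> IMPL_STORE(int)      */
/* DISPATCH_STORE(_SG16, int) -> IMPL_STORE_SG16(int) */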
@@ -955,10 +955,17 @@ DEFINE_B_B_16x64(local)
 #define DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, stride_opt, address_space) \
   INLINE void MANGLE_STORE_NAME_##address_space(layout, sg, elem_bitwidth, shape, M) (char *mem, __private char *src, long stride) { \
       int sg_size = get_sub_group_size(); \
-      if (BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) \
-          && order == _ROW_MAJOR && address_space == AS_GLOBAL && elem_bitwidth > 8 \
-          ) { \
-          IMPLEMENT_BLOCK2D_STORE##sg(element_type, contrib_type, contrib_bitwidth, M, K, src) \
+      if (BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && M == 16 \
+          && order == _ROW_MAJOR && address_space == AS_GLOBAL && elem_bitwidth > 8) { \
+          __private char *c0 = src + 0 * 8 * (sizeof (int)); \
+          __private char *c1 = src + 1 * 8 * (sizeof (int)); \
+\
+          char *mem0 = mem; \
+          char *mem1 = mem + 8 * (sizeof (int)) * stride; \
+\
+          __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(mem0, c0, stride); \
+          __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(mem1, c1, stride); \
+          return; \
       } \
       contrib_type *ptr = (contrib_type *)mem; \
       int slid = get_sub_group_local_id(); \
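This is the substantive change: DEFINE_STORE_IMPL_LARGE no longer routes large shapes through the generic block-2D store macro; instead, a 16x16 i32 row-major accumulator store is issued as two 8x16 block stores, reusing the existing 8x16 builtin. A minimal sketch of the decomposition, assuming stride counts elements per row (store_acc_16x16 and store_8x16_fn are illustrative names; the real callee is the __builtin_spriv_...8x16... builtin in the hunk above):

/* Each work item holds 16 int elements in src (8 per half); the second
 * half of the matrix starts 8 rows further down in memory, i.e.
 * 8 * stride elements = 8 * sizeof(int) * stride bytes past the base. */
typedef void (*store_8x16_fn)(char *mem, char *src, long stride);

static void store_acc_16x16(char *mem, char *src, long stride,
                            store_8x16_fn store_8x16) {
    char *c0 = src + 0 * 8 * sizeof(int); /* WI slice, rows 0..7  */
    char *c1 = src + 1 * 8 * sizeof(int); /* WI slice, rows 8..15 */

    char *mem0 = mem;                            /* destination rows 0..7  */
    char *mem1 = mem + 8 * sizeof(int) * stride; /* destination rows 8..15 */

    store_8x16(mem0, c0, stride);
    store_8x16(mem1, c1, stride);
}

Reusing the existing 8x16 builtin avoids introducing a dedicated 16x16 block-store path while leaving the scalar fallback below the if block unchanged.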