Skip to content

Commit 2c69abc

Browse files
YuriPlyakhin authored and igcbot committed
2d block load/store use simt1, fix JM 16x16 store
The execution mask for the 2D block load/store message should be SIMT1. Also, the block height for a 2D block store is 8; hence, to use 2D block store for a 16x16 matrix, the store needs to be split into two 8x16 2D block store instructions.
1 parent de4fb2a commit 2c69abc

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ SPDX-License-Identifier: MIT
172172
/* not supported, fallthrough */
173173
#define IMPLEMENT_BLOCK2D_LOAD_VNNI_TX(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, stride_opt) \
174174
/* not supported, fallthrough */
175-
#define IMPLEMENT_BLOCK2D_STORE(element_type, contrib_type, contrib_bitwidth, M, K, vec) \
175+
#define IMPLEMENT_BLOCK2D_STORE(element_type, contrib_type, contrib_bitwidth, M, K) \
176176
/* not supported, fallthrough */
177177

178178
#define IMPLEMENT_BLOCK2D_LOAD_SG16_ROW_MAJOR(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, stride_opt) \
@@ -216,7 +216,7 @@ SPDX-License-Identifier: MIT
216216
*(__private OUT_VEC##M(u##contrib_type) *)dst = res; \
217217
return;
218218

219-
#define IMPLEMENT_BLOCK2D_STORE_SG16(element_type, contrib_type, contrib_bitwidth, M, K, vec) \
219+
#define IMPLEMENT_BLOCK2D_STORE_SG16(element_type, contrib_type, contrib_bitwidth, M, K) \
220220
long offset = as_long(mem); \
221221
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
222222
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
@@ -470,7 +470,7 @@ DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 2, 16, 2x16, COL_M
470470
if (WI_rows == M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) \
471471
&& order == _ROW_MAJOR && address_space == AS_GLOBAL && elem_bitwidth > 8 \
472472
) { \
473-
IMPLEMENT_BLOCK2D_STORE##sg(element_type, contrib_type, contrib_bitwidth, M, K, src) \
473+
IMPLEMENT_BLOCK2D_STORE##sg(element_type, contrib_type, contrib_bitwidth, M, K) \
474474
} \
475475
if (WI_rows == M && BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= VECTOR_CONT_IMPL && stride == stride_opt \
476476
&& (M == 2 || M == 4 || M == 8) && order == _ROW_MAJOR \
@@ -955,10 +955,17 @@ DEFINE_B_B_16x64(local)
955955
#define DEFINE_STORE_IMPL_LARGE(layout, sg, element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, shape, order, us, stride_opt, address_space) \
956956
INLINE void MANGLE_STORE_NAME_##address_space(layout, sg, elem_bitwidth, shape, M) (char *mem, __private char *src, long stride) { \
957957
int sg_size = get_sub_group_size(); \
958-
if (BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) \
959-
&& order == _ROW_MAJOR && address_space == AS_GLOBAL && elem_bitwidth > 8 \
960-
) { \
961-
IMPLEMENT_BLOCK2D_STORE##sg(element_type, contrib_type, contrib_bitwidth, M, K, src) \
958+
if (BIF_FLAG_CTRL_GET(JointMatrixLoadStoreOpt) >= BLOCK2D_IMPL && M == 16 \
959+
&& order == _ROW_MAJOR && address_space == AS_GLOBAL && elem_bitwidth > 8) { \
960+
__private char *c0 = src + 0 * 8 * (sizeof (int)); \
961+
__private char *c1 = src + 1 * 8 * (sizeof (int)); \
962+
\
963+
char *mem0 = mem; \
964+
char *mem1 = mem + 8 * (sizeof (int)) * stride; \
965+
\
966+
__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(mem0, c0, stride); \
967+
__builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_SG16_8x16_i32_8_global_pi64_v8i8(mem1, c1, stride); \
968+
return; \
962969
} \
963970
contrib_type *ptr = (contrib_type *)mem; \
964971
int slid = get_sub_group_local_id(); \

IGC/Compiler/CISACodeGen/CISABuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8830,7 +8830,7 @@ namespace IGC
88308830
LSC_CACHE_OPTS cacheOpts)
88318831
{
88328832
VISA_PredOpnd* predOpnd = GetFlagOperand(m_encoderState.m_flag);
8833-
VISA_Exec_Size execSize = visaExecSize(m_program->m_dispatchSize);
8833+
VISA_Exec_Size execSize = EXEC_SIZE_1;
88348834
VISA_EMask_Ctrl mask = ConvertMaskToVisaType(m_encoderState.m_mask, m_encoderState.m_noMask);
88358835
LSC_DATA_SHAPE_BLOCK2D dataShape2D{};
88368836
dataShape2D.size = LSC_GetElementSize(elemSize, true);

0 commit comments

Comments
 (0)