Skip to content

Commit 4417de5

Browse files
fda0igcbot
authored andcommitted
Add JointMatrix column major accumulator loads&stores
Add JointMatrix accumulator i32 transposed loads and stores. Loads use 2D block read with transpose where possible otherwise we fallback to scalar implementation. Stores use scalar implementation.
1 parent 0107b40 commit 4417de5

File tree

4 files changed

+169
-39
lines changed

4 files changed

+169
-39
lines changed

IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,7 @@ extern __constant int __JointMatrixLoadStoreOpt;
4242
#define ATTRIBUTE_AS_GLOBAL __global
4343

4444
// Index for row major layout is calculated based on that sub group size may be
45-
// bigger than N. Currently we do not have any test case
46-
// of col_major matrix for such case, therefore we keep the implementation of
47-
// COL_MAJOR simple.
45+
// bigger than N.
4846
// Arguments:
4947
// sg_cols: Number of contiguous columns held in the subgroup
5048
// skip_factor: n, where we include elements from every n-th row of the JM
@@ -55,8 +53,8 @@ extern __constant int __JointMatrixLoadStoreOpt;
5553
// 13 14 15 16
5654
// if skip_factor == 2, we will include items <1, 9> (every "2"nd row) in the
5755
// first WI, <2, 10> in the second WI and so on..
58-
#define IND_ROW_MAJOR(slid, stride, skip_factor, i, sg_cols) (slid/sg_cols*stride + slid%sg_cols + i*stride*skip_factor)
59-
#define IND_COL_MAJOR(slid, stride, skip_factor, i, sg_cols) (i + (slid * stride))
56+
#define IND_ROW_MAJOR(slid, stride, skip_factor, i, sg_cols) ((slid/sg_cols + i*skip_factor)*stride + (slid%sg_cols))
57+
#define IND_COL_MAJOR(slid, stride, skip_factor, i, sg_cols) ((slid/sg_cols + i*skip_factor) + (slid%sg_cols)*stride)
6058
#define IND_VNNI_TX(slid, stride, skip_factor, i, sg_cols) (i + (slid * stride))
6159

6260
// no int7, int6, int5 types
@@ -164,6 +162,7 @@ extern __constant int __JointMatrixLoadStoreOpt;
164162
#define DEFINE_BLOCK_RW_NAME1(rw, us) intel_sub_group_block_##rw##us
165163

166164
#define DEFINE_BLOCK2D_RW_NAME(rw, tx, contrib_bitwidth, M, K) __builtin_IB_subgroup_block_##rw##_flat##tx##_u##contrib_bitwidth##_m##M##k##K##v1
165+
#define DEFINE_BLOCK2D_TRANSPOSE_NAME(contrib_bitwidth, K) __builtin_IB_subgroup_block_read_flat_transpose_u##contrib_bitwidth##_k##K
167166
#define DEFINE_BLOCK2D_VNNI_NAME(contrib_bitwidth, K) __builtin_IB_subgroup_block_read_flat_transform_u##contrib_bitwidth##_k##K
168167

169168
/* For platforms without SG16 JointMatrix support block2d is not available. The
@@ -195,12 +194,14 @@ extern __constant int __JointMatrixLoadStoreOpt;
195194
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
196195
int width = (sizeof (element_type)) * stride - 1; /* in bytes */ \
197196
int pitch = width; /* JointMatrices are expected to be contigunous in memory, without padding at the end of a row */ \
198-
int height = M - 1; /* row count */ \
197+
int height = K - 1; /* column count */ \
199198
long x = (offset - baseoffset) / (sizeof (contrib_type)); /* in elements */ \
200199
int2 coords = (int2)(x, 0); \
201-
OUT_VEC##M(u##contrib_type) DEFINE_BLOCK2D_RW_NAME(read, _transpose, elem_bitwidth, M, K)(long, int, int, int, int2); \
202-
OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_RW_NAME(read, _transpose, elem_bitwidth, M, K)(baseoffset, width, height, pitch, coords); \
203-
*(__private OUT_VEC##M(u##contrib_type) *)dst = res; \
200+
/* 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose */ \
201+
/* Right now we only support the __builtin_IB_subgroup_block_read_flat_transpose_u32_k8 configuration */ \
202+
OUT_VEC8(u##contrib_type) DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, 8)(long, int, int, int, int2); \
203+
OUT_VEC8(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, 8)(baseoffset, width, height, pitch, coords); \
204+
*(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
204205
return;
205206

206207
#define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX(element_type, elem_bitwidth, contrib_type, contrib_bitwidth, M, K, stride_opt) \
@@ -247,7 +248,8 @@ extern __constant int __JointMatrixLoadStoreOpt;
247248
INLINE void MANGLE_LOAD_NAME_##address_space(layout, sg, elem_bitwidth, shape, WI_rows) (__private char *dst, char *mem, long stride) { \
248249
int sg_size = get_sub_group_size(); \
249250
if (WI_rows == M && __JointMatrixLoadStoreOpt >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) \
250-
&& (order == _ROW_MAJOR || order == _VNNI_TX) && address_space == AS_GLOBAL \
251+
&& (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && contrib_bitwidth == 32)) \
252+
&& address_space == AS_GLOBAL \
251253
) { \
252254
/* It seems __builtin_IB_subgroup_block_rw always needs k=16 \
253255
Maybe it is number of columns divided by pack factor which always gives 16 on SG16 HW */ \
@@ -392,6 +394,16 @@ DEFINE_LOAD(Accumulator_RowMajor, , int, 32, int, 32, 3, 8, 3x8, ROW_MAJOR, , 3,
392394
DEFINE_LOAD(Accumulator_RowMajor, , int, 32, int, 32, 2, 8, 2x8, ROW_MAJOR, , 2, 8)
393395
DEFINE_LOAD(Accumulator_RowMajor, , int, 32, int, 32, 1, 8, 1x8, ROW_MAJOR, , 1, 8)
394396

397+
/* Accumulator load i32 SG8 with transpose */
398+
DEFINE_LOAD(Accumulator_ColumnMajor, , int, 32, int, 32, 8, 8, 8x8, COL_MAJOR, , 8, 8)
399+
DEFINE_LOAD(Accumulator_ColumnMajor, , int, 32, int, 32, 7, 8, 7x8, COL_MAJOR, , 7, 8)
400+
DEFINE_LOAD(Accumulator_ColumnMajor, , int, 32, int, 32, 6, 8, 6x8, COL_MAJOR, , 6, 8)
401+
DEFINE_LOAD(Accumulator_ColumnMajor, , int, 32, int, 32, 5, 8, 5x8, COL_MAJOR, , 5, 8)
402+
DEFINE_LOAD(Accumulator_ColumnMajor, , int, 32, int, 32, 4, 8, 4x8, COL_MAJOR, , 4, 8)
403+
DEFINE_LOAD(Accumulator_ColumnMajor, , int, 32, int, 32, 3, 8, 3x8, COL_MAJOR, , 3, 8)
404+
DEFINE_LOAD(Accumulator_ColumnMajor, , int, 32, int, 32, 2, 8, 2x8, COL_MAJOR, , 2, 8)
405+
DEFINE_LOAD(Accumulator_ColumnMajor, , int, 32, int, 32, 1, 8, 1x8, COL_MAJOR, , 1, 8)
406+
395407
/* SG16*/
396408
DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 8, 16, 8x16, ROW_MAJOR, , 8, 16)
397409
DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 7, 16, 7x16, ROW_MAJOR, , 7, 16)
@@ -402,6 +414,16 @@ DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, ROW_MAJO
402414
DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 2, 16, 2x16, ROW_MAJOR, , 2, 16)
403415
DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, ROW_MAJOR, , 1, 16)
404416

417+
/* Accumulator load i32 SG16 with transpose */
418+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 8, 16, 8x16, COL_MAJOR, , 8, 16)
419+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 7, 16, 7x16, COL_MAJOR, , 7, 16)
420+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 6, 16, 6x16, COL_MAJOR, , 6, 16)
421+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 5, 16, 5x16, COL_MAJOR, , 5, 16)
422+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 4, 16, 4x16, COL_MAJOR, , 4, 16)
423+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, COL_MAJOR, , 3, 16)
424+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 2, 16, 2x16, COL_MAJOR, , 2, 16)
425+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, COL_MAJOR, , 1, 16)
426+
405427
/* SG16 for subgroup 32*/
406428
DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 8, 16, 8x16, ROW_MAJOR, , 4, 16)
407429
DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 7, 16, 7x16, ROW_MAJOR, , 4, 16)
@@ -412,6 +434,16 @@ DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, ROW_MAJO
412434
DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 2, 16, 2x16, ROW_MAJOR, , 1, 16)
413435
// DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, ROW_MAJOR, , 1, 16) same as for subgroup 16
414436

437+
/* Accumulator load i32 SG16 for subgroup 32 with transpose */
438+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 8, 16, 8x16, COL_MAJOR, , 4, 16)
439+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 7, 16, 7x16, COL_MAJOR, , 4, 16)
440+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 6, 16, 6x16, COL_MAJOR, , 3, 16)
441+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 5, 16, 5x16, COL_MAJOR, , 3, 16)
442+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 4, 16, 4x16, COL_MAJOR, , 2, 16)
443+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, COL_MAJOR, , 2, 16)
444+
DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 2, 16, 2x16, COL_MAJOR, , 1, 16)
445+
// DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, COL_MAJOR, , 1, 16) same as for subgroup 16
446+
415447
// --------- STORE built-ins --------------------------------------
416448

417449
#define MANGLE_STORE_NAME_AS_GENERIC(layout, sg, elem_bitwidth, shape, WI_rows) \
@@ -580,6 +612,16 @@ DEFINE_STORE(Accumulator_RowMajor, , int, 32, int, 32, 3, 8, 3x8, ROW_MAJOR, , 3
580612
DEFINE_STORE(Accumulator_RowMajor, , int, 32, int, 32, 2, 8, 2x8, ROW_MAJOR, , 2, 8, true)
581613
DEFINE_STORE(Accumulator_RowMajor, , int, 32, int, 32, 1, 8, 1x8, ROW_MAJOR, , 1, 8, true)
582614

615+
/* Accumulator store i32 SG8 with transpose */
616+
DEFINE_STORE(Accumulator_ColumnMajor, , int, 32, int, 32, 8, 8, 8x8, COL_MAJOR, , 8, 8, true)
617+
DEFINE_STORE(Accumulator_ColumnMajor, , int, 32, int, 32, 7, 8, 7x8, COL_MAJOR, , 7, 8, true)
618+
DEFINE_STORE(Accumulator_ColumnMajor, , int, 32, int, 32, 6, 8, 6x8, COL_MAJOR, , 6, 8, true)
619+
DEFINE_STORE(Accumulator_ColumnMajor, , int, 32, int, 32, 5, 8, 5x8, COL_MAJOR, , 5, 8, true)
620+
DEFINE_STORE(Accumulator_ColumnMajor, , int, 32, int, 32, 4, 8, 4x8, COL_MAJOR, , 4, 8, true)
621+
DEFINE_STORE(Accumulator_ColumnMajor, , int, 32, int, 32, 3, 8, 3x8, COL_MAJOR, , 3, 8, true)
622+
DEFINE_STORE(Accumulator_ColumnMajor, , int, 32, int, 32, 2, 8, 2x8, COL_MAJOR, , 2, 8, true)
623+
DEFINE_STORE(Accumulator_ColumnMajor, , int, 32, int, 32, 1, 8, 1x8, COL_MAJOR, , 1, 8, true)
624+
583625
/* Acc i32 SG16 */
584626
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 8, 16, 8x16, ROW_MAJOR, , 8, 16, true)
585627
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 7, 16, 7x16, ROW_MAJOR, , 7, 16, true)
@@ -590,6 +632,16 @@ DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, ROW_MAJ
590632
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 2, 16, 2x16, ROW_MAJOR, , 2, 16, true)
591633
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, ROW_MAJOR, , 1, 16, true)
592634

635+
/* Accumulator store i32 SG16 with transpose */
636+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 8, 16, 8x16, COL_MAJOR, , 8, 16, true)
637+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 7, 16, 7x16, COL_MAJOR, , 7, 16, true)
638+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 6, 16, 6x16, COL_MAJOR, , 6, 16, true)
639+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 5, 16, 5x16, COL_MAJOR, , 5, 16, true)
640+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 4, 16, 4x16, COL_MAJOR, , 4, 16, true)
641+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, COL_MAJOR, , 3, 16, true)
642+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 2, 16, 2x16, COL_MAJOR, , 2, 16, true)
643+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, COL_MAJOR, , 1, 16, true)
644+
593645
/* Acc i32 SG16 for subgroup 32*/
594646
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 8, 16, 8x16, ROW_MAJOR, , 4, 16, true)
595647
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 7, 16, 7x16, ROW_MAJOR, , 4, 16, true)
@@ -600,6 +652,16 @@ DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, ROW_MAJ
600652
DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 2, 16, 2x16, ROW_MAJOR, , 1, 16, true)
601653
// DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, ROW_MAJOR, , 1, 16, true) same as for subgroup 16
602654

655+
/* Accumulator store i32 SG16 for subgroup 32 with transpose */
656+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 8, 16, 8x16, COL_MAJOR, , 4, 16, true)
657+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 7, 16, 7x16, COL_MAJOR, , 4, 16, true)
658+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 6, 16, 6x16, COL_MAJOR, , 3, 16, true)
659+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 5, 16, 5x16, COL_MAJOR, , 3, 16, true)
660+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 4, 16, 4x16, COL_MAJOR, , 2, 16, true)
661+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, COL_MAJOR, , 2, 16, true)
662+
DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 2, 16, 2x16, COL_MAJOR, , 1, 16, true)
663+
// DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, COL_MAJOR, , 1, 16, true) same as for subgroup 16
664+
603665
/* get_coord() support: */
604666

605667
#define MANGLE_GETCOORD_NAME(layout, sg, elem_bitwidth, R, C) \

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass/JointMatrixFuncsResolutionPass.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,8 @@ static SupportedParams getSupportedParams(const JointMatrixTypeDescription *desc
367367
params.maxRows = maxSliceBitWidth / desc->bitWidth;
368368
params.columns = useSG16 ? 16 : 8;
369369
params.bitWidth = 8 | 32;
370-
params.layouts = 1 << LayoutRowMajor;
370+
params.layouts |= 1 << LayoutRowMajor;
371+
params.layouts |= 1 << LayoutColumnMajor;
371372
}
372373
return params;
373374
}

0 commit comments

Comments
 (0)