@@ -42,9 +42,7 @@ extern __constant int __JointMatrixLoadStoreOpt;
42
42
#define ATTRIBUTE_AS_GLOBAL __global
43
43
44
44
// Index for row major layout is calculated based on that sub group size may be
45
- // bigger than N. Currently we do not have any test case
46
- // of col_major matrix for such case, therefore we keep the implementation of
47
- // COL_MAJOR simple.
45
+ // bigger than N.
48
46
// Arguments:
49
47
// sg_cols: Number of contiguous columns held in the subgroup
50
48
// skip_factor: n, where we include elements from every n-th row of the JM
@@ -55,8 +53,8 @@ extern __constant int __JointMatrixLoadStoreOpt;
55
53
// 13 14 15 16
56
54
// if skip_factor == 2, we will include items <1, 9> (every "2"nd row) in the
57
55
// first WI, <2, 10> in the second WI and so on..
58
- #define IND_ROW_MAJOR (slid , stride , skip_factor , i , sg_cols ) (slid/sg_cols* stride + slid%sg_cols + i*stride*skip_factor )
59
- #define IND_COL_MAJOR (slid , stride , skip_factor , i , sg_cols ) (i + (slid * stride) )
56
+ #define IND_ROW_MAJOR (slid , stride , skip_factor , i , sg_cols ) (( slid/sg_cols + i*skip_factor)* stride + ( slid%sg_cols) )
57
+ #define IND_COL_MAJOR (slid , stride , skip_factor , i , sg_cols ) ((slid/sg_cols + i*skip_factor) + (slid%sg_cols)* stride)
60
58
#define IND_VNNI_TX (slid , stride , skip_factor , i , sg_cols ) (i + (slid * stride))
61
59
62
60
// no int7, int6, int5 types
@@ -164,6 +162,7 @@ extern __constant int __JointMatrixLoadStoreOpt;
164
162
#define DEFINE_BLOCK_RW_NAME1 (rw , us ) intel_sub_group_block_##rw##us
165
163
166
164
#define DEFINE_BLOCK2D_RW_NAME (rw , tx , contrib_bitwidth , M , K ) __builtin_IB_subgroup_block_##rw##_flat##tx##_u##contrib_bitwidth##_m##M##k##K##v1
165
+ #define DEFINE_BLOCK2D_TRANSPOSE_NAME (contrib_bitwidth , K ) __builtin_IB_subgroup_block_read_flat_transpose_u##contrib_bitwidth##_k##K
167
166
#define DEFINE_BLOCK2D_VNNI_NAME (contrib_bitwidth , K ) __builtin_IB_subgroup_block_read_flat_transform_u##contrib_bitwidth##_k##K
168
167
169
168
/* For platforms without SG16 JointMatrix support block2d is not available. The
@@ -195,12 +194,14 @@ extern __constant int __JointMatrixLoadStoreOpt;
195
194
long baseoffset = offset & (~0x3f); /* align to 64-byte */ \
196
195
int width = (sizeof (element_type )) * stride - 1 ; /* in bytes */ \
197
196
int pitch = width ; /* JointMatrices are expected to be contigunous in memory, without padding at the end of a row */ \
198
- int height = M - 1 ; /* row count */ \
197
+ int height = K - 1 ; /* column count */ \
199
198
long x = (offset - baseoffset ) / (sizeof (contrib_type )); /* in elements */ \
200
199
int2 coords = (int2 )(x , 0 ); \
201
- OUT_VEC ##M (u##contrib_type) DEFINE_BLOCK2D_RW_NAME(read, _transpose, elem_bitwidth, M, K)(long, int, int, int, int2); \
202
- OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_RW_NAME(read, _transpose, elem_bitwidth, M, K)(baseoffset, width, height, pitch, coords); \
203
- *(__private OUT_VEC##M(u##contrib_type) *)dst = res; \
200
+ /* 2D block read transpose builtin requires K value _after_ the transpose operation is done - which is equal to M before the transpose */ \
201
+ /* Right now we only support the __builtin_IB_subgroup_block_read_flat_transpose_u32_k8 configuration */ \
202
+ OUT_VEC8 (u ##contrib_type ) DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, 8)(long, int, int, int, int2); \
203
+ OUT_VEC8(u##contrib_type) res = DEFINE_BLOCK2D_TRANSPOSE_NAME(elem_bitwidth, 8)(baseoffset, width, height, pitch, coords); \
204
+ *(__private OUT_VEC##M(u##contrib_type) *)dst = *(__private OUT_VEC##M(u##contrib_type) *)&res; \
204
205
return;
205
206
206
207
#define IMPLEMENT_BLOCK2D_LOAD_SG16_VNNI_TX (element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , stride_opt ) \
@@ -247,7 +248,8 @@ extern __constant int __JointMatrixLoadStoreOpt;
247
248
INLINE void MANGLE_LOAD_NAME_##address_space(layout, sg, elem_bitwidth, shape, WI_rows) (__private char *dst, char *mem, long stride) { \
248
249
int sg_size = get_sub_group_size(); \
249
250
if (WI_rows == M && __JointMatrixLoadStoreOpt >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) \
250
- && (order == _ROW_MAJOR || order == _VNNI_TX) && address_space == AS_GLOBAL \
251
+ && (order == _ROW_MAJOR || order == _VNNI_TX || (order == _COL_MAJOR && contrib_bitwidth == 32)) \
252
+ && address_space == AS_GLOBAL \
251
253
) { \
252
254
/* It seems __builtin_IB_subgroup_block_rw always needs k=16 \
253
255
Maybe it is number of columns divided by pack factor which always gives 16 on SG16 HW */ \
@@ -392,6 +394,16 @@ DEFINE_LOAD(Accumulator_RowMajor, , int, 32, int, 32, 3, 8, 3x8, ROW_MAJOR, , 3,
392
394
DEFINE_LOAD (Accumulator_RowMajor , , int , 32 , int , 32 , 2 , 8 , 2 x8 , ROW_MAJOR , , 2 , 8 )
393
395
DEFINE_LOAD (Accumulator_RowMajor , , int , 32 , int , 32 , 1 , 8 , 1 x8 , ROW_MAJOR , , 1 , 8 )
394
396
397
+ /* Accumulator load i32 SG8 with transpose */
398
+ DEFINE_LOAD (Accumulator_ColumnMajor , , int , 32 , int , 32 , 8 , 8 , 8 x8 , COL_MAJOR , , 8 , 8 )
399
+ DEFINE_LOAD (Accumulator_ColumnMajor , , int , 32 , int , 32 , 7 , 8 , 7 x8 , COL_MAJOR , , 7 , 8 )
400
+ DEFINE_LOAD (Accumulator_ColumnMajor , , int , 32 , int , 32 , 6 , 8 , 6 x8 , COL_MAJOR , , 6 , 8 )
401
+ DEFINE_LOAD (Accumulator_ColumnMajor , , int , 32 , int , 32 , 5 , 8 , 5 x8 , COL_MAJOR , , 5 , 8 )
402
+ DEFINE_LOAD (Accumulator_ColumnMajor , , int , 32 , int , 32 , 4 , 8 , 4 x8 , COL_MAJOR , , 4 , 8 )
403
+ DEFINE_LOAD (Accumulator_ColumnMajor , , int , 32 , int , 32 , 3 , 8 , 3 x8 , COL_MAJOR , , 3 , 8 )
404
+ DEFINE_LOAD (Accumulator_ColumnMajor , , int , 32 , int , 32 , 2 , 8 , 2 x8 , COL_MAJOR , , 2 , 8 )
405
+ DEFINE_LOAD (Accumulator_ColumnMajor , , int , 32 , int , 32 , 1 , 8 , 1 x8 , COL_MAJOR , , 1 , 8 )
406
+
395
407
/* SG16*/
396
408
DEFINE_LOAD (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 8 , 16 )
397
409
DEFINE_LOAD (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 7 , 16 , 7 x16 , ROW_MAJOR , , 7 , 16 )
@@ -402,6 +414,16 @@ DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, ROW_MAJO
402
414
DEFINE_LOAD (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 2 , 16 , 2 x16 , ROW_MAJOR , , 2 , 16 )
403
415
DEFINE_LOAD (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 1 , 16 , 1 x16 , ROW_MAJOR , , 1 , 16 )
404
416
417
+ /* Accumulator load i32 SG16 with transpose */
418
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 8 , 16 , 8 x16 , COL_MAJOR , , 8 , 16 )
419
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 7 , 16 , 7 x16 , COL_MAJOR , , 7 , 16 )
420
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 6 , 16 , 6 x16 , COL_MAJOR , , 6 , 16 )
421
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 5 , 16 , 5 x16 , COL_MAJOR , , 5 , 16 )
422
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 4 , 16 , 4 x16 , COL_MAJOR , , 4 , 16 )
423
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 3 , 16 , 3 x16 , COL_MAJOR , , 3 , 16 )
424
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 2 , 16 , 2 x16 , COL_MAJOR , , 2 , 16 )
425
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 1 , 16 , 1 x16 , COL_MAJOR , , 1 , 16 )
426
+
405
427
/* SG16 for subgroup 32*/
406
428
DEFINE_LOAD (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 4 , 16 )
407
429
DEFINE_LOAD (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 7 , 16 , 7 x16 , ROW_MAJOR , , 4 , 16 )
@@ -412,6 +434,16 @@ DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, ROW_MAJO
412
434
DEFINE_LOAD (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 2 , 16 , 2 x16 , ROW_MAJOR , , 1 , 16 )
413
435
// DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, ROW_MAJOR, , 1, 16) same as for subgroup 16
414
436
437
+ /* Accumulator load i32 SG16 for subgroup 32 with transpose */
438
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 8 , 16 , 8 x16 , COL_MAJOR , , 4 , 16 )
439
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 7 , 16 , 7 x16 , COL_MAJOR , , 4 , 16 )
440
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 6 , 16 , 6 x16 , COL_MAJOR , , 3 , 16 )
441
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 5 , 16 , 5 x16 , COL_MAJOR , , 3 , 16 )
442
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 4 , 16 , 4 x16 , COL_MAJOR , , 2 , 16 )
443
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 3 , 16 , 3 x16 , COL_MAJOR , , 2 , 16 )
444
+ DEFINE_LOAD (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 2 , 16 , 2 x16 , COL_MAJOR , , 1 , 16 )
445
+ // DEFINE_LOAD(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, COL_MAJOR, , 1, 16) same as for subgroup 16
446
+
415
447
// --------- STORE built-ins --------------------------------------
416
448
417
449
#define MANGLE_STORE_NAME_AS_GENERIC (layout , sg , elem_bitwidth , shape , WI_rows ) \
@@ -580,6 +612,16 @@ DEFINE_STORE(Accumulator_RowMajor, , int, 32, int, 32, 3, 8, 3x8, ROW_MAJOR, , 3
580
612
DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 2 , 8 , 2 x8 , ROW_MAJOR , , 2 , 8 , true)
581
613
DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 1 , 8 , 1 x8 , ROW_MAJOR , , 1 , 8 , true)
582
614
615
+ /* Accumulator store i32 SG8 with transpose */
616
+ DEFINE_STORE (Accumulator_ColumnMajor , , int , 32 , int , 32 , 8 , 8 , 8 x8 , COL_MAJOR , , 8 , 8 , true)
617
+ DEFINE_STORE (Accumulator_ColumnMajor , , int , 32 , int , 32 , 7 , 8 , 7 x8 , COL_MAJOR , , 7 , 8 , true)
618
+ DEFINE_STORE (Accumulator_ColumnMajor , , int , 32 , int , 32 , 6 , 8 , 6 x8 , COL_MAJOR , , 6 , 8 , true)
619
+ DEFINE_STORE (Accumulator_ColumnMajor , , int , 32 , int , 32 , 5 , 8 , 5 x8 , COL_MAJOR , , 5 , 8 , true)
620
+ DEFINE_STORE (Accumulator_ColumnMajor , , int , 32 , int , 32 , 4 , 8 , 4 x8 , COL_MAJOR , , 4 , 8 , true)
621
+ DEFINE_STORE (Accumulator_ColumnMajor , , int , 32 , int , 32 , 3 , 8 , 3 x8 , COL_MAJOR , , 3 , 8 , true)
622
+ DEFINE_STORE (Accumulator_ColumnMajor , , int , 32 , int , 32 , 2 , 8 , 2 x8 , COL_MAJOR , , 2 , 8 , true)
623
+ DEFINE_STORE (Accumulator_ColumnMajor , , int , 32 , int , 32 , 1 , 8 , 1 x8 , COL_MAJOR , , 1 , 8 , true)
624
+
583
625
/* Acc i32 SG16 */
584
626
DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 8 , 16 , true)
585
627
DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 7 , 16 , 7 x16 , ROW_MAJOR , , 7 , 16 , true)
@@ -590,6 +632,16 @@ DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, ROW_MAJ
590
632
DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 2 , 16 , 2 x16 , ROW_MAJOR , , 2 , 16 , true)
591
633
DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 1 , 16 , 1 x16 , ROW_MAJOR , , 1 , 16 , true)
592
634
635
+ /* Accumulator store i32 SG16 with transpose */
636
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 8 , 16 , 8 x16 , COL_MAJOR , , 8 , 16 , true)
637
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 7 , 16 , 7 x16 , COL_MAJOR , , 7 , 16 , true)
638
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 6 , 16 , 6 x16 , COL_MAJOR , , 6 , 16 , true)
639
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 5 , 16 , 5 x16 , COL_MAJOR , , 5 , 16 , true)
640
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 4 , 16 , 4 x16 , COL_MAJOR , , 4 , 16 , true)
641
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 3 , 16 , 3 x16 , COL_MAJOR , , 3 , 16 , true)
642
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 2 , 16 , 2 x16 , COL_MAJOR , , 2 , 16 , true)
643
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 1 , 16 , 1 x16 , COL_MAJOR , , 1 , 16 , true)
644
+
593
645
/* Acc i32 SG16 for subgroup 32*/
594
646
DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 4 , 16 , true)
595
647
DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 7 , 16 , 7 x16 , ROW_MAJOR , , 4 , 16 , true)
@@ -600,6 +652,16 @@ DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, ROW_MAJ
600
652
DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 2 , 16 , 2 x16 , ROW_MAJOR , , 1 , 16 , true)
601
653
// DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, ROW_MAJOR, , 1, 16, true) same as for subgroup 16
602
654
655
+ /* Accumulator store i32 SG16 for subgroup 32 with transpose */
656
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 8 , 16 , 8 x16 , COL_MAJOR , , 4 , 16 , true)
657
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 7 , 16 , 7 x16 , COL_MAJOR , , 4 , 16 , true)
658
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 6 , 16 , 6 x16 , COL_MAJOR , , 3 , 16 , true)
659
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 5 , 16 , 5 x16 , COL_MAJOR , , 3 , 16 , true)
660
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 4 , 16 , 4 x16 , COL_MAJOR , , 2 , 16 , true)
661
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 3 , 16 , 3 x16 , COL_MAJOR , , 2 , 16 , true)
662
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 2 , 16 , 2 x16 , COL_MAJOR , , 1 , 16 , true)
663
+ // DEFINE_STORE(Accumulator_ColumnMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, COL_MAJOR, , 1, 16, true) same as for subgroup 16
664
+
603
665
/* get_coord() support: */
604
666
605
667
#define MANGLE_GETCOORD_NAME (layout , sg , elem_bitwidth , R , C ) \
0 commit comments