@@ -152,7 +152,7 @@ extern __constant int __JointMatrixLoadStoreOpt;
152
152
int width = (sizeof (element_type )) * stride - 1 ; /* in bytes */ \
153
153
int pitch = width ; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
154
154
int height = M - 1 ; /* row count */ \
155
- long x = (offset - baseoffset ) / (sizeof (element_type )); /* in elements */ \
155
+ long x = (offset - baseoffset ) / (sizeof (contrib_type )); /* in elements */ \
156
156
int2 coords = (int2 )(x , 0 ); \
157
157
OUT_VEC ##M (u##contrib_type) DEFINE_BLOCK2D_RW_NAME(read, contrib_bitwidth, M, K)(long, int, int, int, int2); \
158
158
OUT_VEC##M(u##contrib_type) res = DEFINE_BLOCK2D_RW_NAME(read, contrib_bitwidth, M, K)(baseoffset, width, height, pitch, coords); \
@@ -164,7 +164,7 @@ extern __constant int __JointMatrixLoadStoreOpt;
164
164
int width = (sizeof (element_type )) * stride - 1 ; /* in bytes */ \
165
165
int pitch = width ; /* JointMatrices are expected to be contiguous in memory, without padding at the end of a row */ \
166
166
int height = M - 1 ; /* row count */ \
167
- long x = (offset - baseoffset ) / (sizeof (element_type )); /* in elements */ \
167
+ long x = (offset - baseoffset ) / (sizeof (contrib_type )); /* in elements */ \
168
168
int2 coords = (int2 )(x , 0 ); \
169
169
void DEFINE_BLOCK2D_RW_NAME (write , contrib_bitwidth , M , K )(long , int , int , int , int2 , OUT_VEC ##M (u##contrib_type)); \
170
170
OUT_VEC##M(u##contrib_type) val = VEC_TO_VEC##M(u##contrib_type, vec); \
@@ -304,7 +304,7 @@ DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, ROW_MAJO
304
304
// set block_opt to false to disable the blocked non-continuous optimization for an individual built-in as a workaround
305
305
#define DEFINE_STORE (layout , sg , element_type , elem_bitwidth , contrib_type , contrib_bitwidth , M , K , shape , order , us , stride_opt , block_opt ) \
306
306
INLINE void MANGLE_STORE_NAME(layout, sg, elem_bitwidth, shape) (char *mem, OUT_VEC##M(contrib_type) vec, int stride) { \
307
- if (__JointMatrixLoadStoreOpt >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) && order == ROW_MAJOR) { \
307
+ if (__JointMatrixLoadStoreOpt >= BLOCK2D_IMPL && (M == 2 || M == 4 || M == 8) && order == ROW_MAJOR && elem_bitwidth > 8 ) { \
308
308
IMPLEMENT_BLOCK2D_STORE##sg(element_type, contrib_type, contrib_bitwidth, M, K, vec) \
309
309
} \
310
310
if (__JointMatrixLoadStoreOpt >= VECTOR_CONT_IMPL && stride == stride_opt \
@@ -330,22 +330,19 @@ DEFINE_LOAD(Accumulator_RowMajor, _SG16, int, 32, int, 32, 1, 16, 1x16, ROW_MAJO
330
330
}
331
331
332
332
// TODO: investigate why intel_sub_group_block_write causes an assertion and enable blocked non-continuous optimization
333
- DEFINE_STORE (PackedA_RowMajor , , char , 8 , int , 32 , 8 , 32 , 8 x32 , ROW_MAJOR , , 32 , false)
334
-
335
- // TODO: investigate why intel_sub_group_block_write causes an assertion and enable blocked non-continuous optimization
336
- DEFINE_STORE (PackedA_RowMajor , , short , 16 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 16 , false)
337
-
338
- // TODO: investigate why intel_sub_group_block_write_us causes an assertion and enable blocked non-continuous optimization
339
- DEFINE_STORE (PackedA_RowMajor , _SG16 , char , 8 , short , 16 , 8 , 32 , 8 x32 , ROW_MAJOR , _us , 32 , false)
340
-
341
- // TODO: investigate why intel_sub_group_block_write_us causes an assertion and enable blocked non-continuous optimization
333
+ DEFINE_STORE (PackedA_RowMajor , , char , 8 , int , 32 , 8 , 32 , 8 x32 , ROW_MAJOR , , 32 , false)
334
+ DEFINE_STORE (PackedA_RowMajor , , short , 16 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 16 , false)
335
+ DEFINE_STORE (PackedA_RowMajor , _SG16 , char , 8 , short , 16 , 8 , 32 , 8 x32 , ROW_MAJOR , _us , 32 , false)
342
336
DEFINE_STORE (PackedA_RowMajor , _SG16 , short , 16 , short , 16 , 8 , 16 , 8 x16 , ROW_MAJOR , _us , 16 , false)
343
337
344
- DEFINE_STORE (PackedB_PackedB , , short , 16 , int , 32 , 8 , 16 , 16 x8 , ROW_MAJOR , , 16 , true)
345
- DEFINE_STORE (PackedB_PackedB , , short , 16 , int , 32 , 8 , 16 , 16 x16 , ROW_MAJOR , , 32 , true)
338
+ DEFINE_STORE (PackedB_PackedB , , short , 16 , int , 32 , 8 , 16 , 16 x8 , ROW_MAJOR , , 16 , true)
339
+ DEFINE_STORE (PackedB_PackedB , , short , 16 , int , 32 , 8 , 16 , 16 x16 , ROW_MAJOR , , 32 , true)
340
+ DEFINE_STORE (PackedB_PackedB , _SG16 , short , 16 , int , 32 , 8 , 16 , 16 x8 , ROW_MAJOR , , 16 , true)
341
+ DEFINE_STORE (PackedB_PackedB , _SG16 , short , 16 , int , 32 , 8 , 16 , 16 x16 , ROW_MAJOR , , 32 , true)
346
342
347
343
// TODO: investigate why intel_sub_group_block_write causes an assertion and enable blocked non-continuous optimization
348
- DEFINE_STORE (PackedB_PackedB , , char , 8 , int , 32 , 8 , 32 , 32 x8 , ROW_MAJOR , , 16 , false)
344
+ DEFINE_STORE (PackedB_PackedB , , char , 8 , int , 32 , 8 , 32 , 32 x8 , ROW_MAJOR , , 16 , false)
345
+ DEFINE_STORE (PackedB_PackedB , _SG16 , char , 8 , int , 32 , 8 , 32 , 32 x8 , ROW_MAJOR , , 16 , false)
349
346
350
347
DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 8 , 8 , 8 x8 , ROW_MAJOR , , 8 , true)
351
348
DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 7 , 8 , 7 x8 , ROW_MAJOR , , 8 , true)
@@ -356,7 +353,7 @@ DEFINE_STORE(Accumulator_RowMajor, , int, 32, int, 32, 3, 8, 3x8, ROW_MAJOR, , 8
356
353
DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 2 , 8 , 2 x8 , ROW_MAJOR , , 8 , true)
357
354
DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 1 , 8 , 1 x8 , ROW_MAJOR , , 8 , true)
358
355
359
- DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 16 , true)
356
+ DEFINE_STORE (Accumulator_RowMajor , , int , 32 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 16 , true)
360
357
361
358
DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 8 , 16 , 8 x16 , ROW_MAJOR , , 16 , true)
362
359
DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 7 , 16 , 7 x16 , ROW_MAJOR , , 16 , true)
@@ -367,4 +364,5 @@ DEFINE_STORE(Accumulator_RowMajor, _SG16, int, 32, int, 32, 3, 16, 3x16, ROW_MAJ
367
364
DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 2 , 16 , 2 x16 , ROW_MAJOR , , 16 , true)
368
365
DEFINE_STORE (Accumulator_RowMajor , _SG16 , int , 32 , int , 32 , 1 , 16 , 1 x16 , ROW_MAJOR , , 16 , true)
369
366
370
- DEFINE_STORE (Accumulator_ColumnMajor , , int , 32 , int , 32 , 8 , 8 , 8 x8 , COL_MAJOR , , -1 , false)
367
+ DEFINE_STORE (Accumulator_ColumnMajor , , int , 32 , int , 32 , 8 , 8 , 8 x8 , COL_MAJOR , , -1 , false)
368
+ DEFINE_STORE (Accumulator_ColumnMajor , _SG16 , int , 32 , int , 32 , 8 , 8 , 8 x8 , COL_MAJOR , , -1 , false)
0 commit comments