Skip to content

Commit cc3d89a

Browse files
arnamoy10pszymich
authored andcommitted
Enable store to PackedA RowMajor to support element wise
This change adds support for store to a PackedA Rewmajor matrix by adding builtins. It also fixes a few things in the JointMatrixResolution pass to make sure slice indexing occur properly and no cast issue happens.
1 parent 830d8dd commit cc3d89a

File tree

2 files changed

+51
-12
lines changed

2 files changed

+51
-12
lines changed

IGC/BiFModule/Languages/OpenCL/PreRelease/IBiF_matrix.cl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,3 +258,19 @@ INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_RowMajor_8x16_i3
258258
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_Accumulator_ColumnMajor_8x8_i32_pi64_v8i8(char *mem, int8 row, int stride) {
259259
STORE_ACC_COL_MAJOR(mem, stride, 8)
260260
}
261+
262+
#define STORE_PACKED_A_ROW_MAJOR(mem, stride, element_type, contrib_type, M) \
263+
contrib_type *ptr = (contrib_type *)mem; \
264+
int slid = get_sub_group_local_id(); \
265+
int pack_factor = sizeof (contrib_type) / sizeof (element_type); \
266+
stride = stride / pack_factor; \
267+
for (int i = 0; i < M; i++) \
268+
ptr[slid + i * stride] = row[i]; \
269+
270+
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_PackedA_RowMajor_SG16_8x16_i16_pi64_v8i8(char *mem, short8 row, int stride) {
271+
STORE_PACKED_A_ROW_MAJOR(mem, stride, short, short, 8)
272+
}
273+
274+
INLINE void __builtin_spriv_OpJointMatrixStoreINTEL_PackedA_RowMajor_SG16_8x32_i8_pi64_v8i8(char *mem, short8 row, int stride) {
275+
STORE_PACKED_A_ROW_MAJOR(mem, stride, char, short, 8)
276+
}

IGC/Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass.cpp

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -446,16 +446,23 @@ static Type *getResolvedVectorElementType(Type *matrixType) {
446446
return ty->getElementType();
447447
}
448448

449-
static int getSliceSize(const JointMatrixTypeDescription *desc) {
449+
static int getSliceSize(const JointMatrixTypeDescription *desc, Type *matTy) {
450+
IGCLLVM::FixedVectorType *ty = dyn_cast<IGCLLVM::FixedVectorType>(matTy);
451+
IGC_ASSERT_MESSAGE(ty, "Expecting vector type in calculating slice size");
452+
453+
IntegerType *vecElemType = dyn_cast<IntegerType>(ty->getElementType());
454+
IGC_ASSERT_MESSAGE(vecElemType, "Expecting integer type for vector element.");
455+
456+
unsigned contribTypeWidth = vecElemType->getBitWidth();
450457
if (desc->layout == LayoutRowMajor) {
451458
return desc->rows;
452459
}
453460
if (desc->bitWidth != 0) {
454461
if (desc->layout == LayoutPackedA) {
455-
return desc->rows * (32 / desc->bitWidth);
462+
return desc->rows * (contribTypeWidth / desc->bitWidth);
456463
}
457464
if (desc->layout == LayoutPackedB) {
458-
return 8 * (32 / desc->bitWidth);
465+
return 8 * (contribTypeWidth / desc->bitWidth);
459466
}
460467
}
461468
IGC_ASSERT_MESSAGE(true, "Unexpected matrix layout.");
@@ -511,7 +518,7 @@ Value *JointMatrixFuncsResolutionPass::ResolveFill(CallInst *CI) {
511518
Type *matTy = ResolveType(CI->getType(), &desc);
512519

513520
IRBuilder builder(CI);
514-
const int sliceSize = getSliceSize(&desc);
521+
const int sliceSize = getSliceSize(&desc, matTy);
515522
const int vectorSize = getResolvedVectorSize(matTy);
516523
/* Case with packing: */
517524
if (sliceSize > vectorSize) {
@@ -523,6 +530,14 @@ Value *JointMatrixFuncsResolutionPass::ResolveFill(CallInst *CI) {
523530
IGC_ASSERT_MESSAGE(false, "Malformed matrix slice.");
524531
}
525532

533+
if (fillValue->getType()->isPointerTy())
534+
{
535+
IntegerType *vectorElementType = dyn_cast<IntegerType>(getResolvedVectorElementType(matTy));
536+
PointerType *PT = dyn_cast<PointerType>(fillValue->getType());
537+
fillValue = builder.CreateBitCast(fillValue, PointerType::get(vectorElementType, PT->getAddressSpace()));
538+
fillValue = builder.CreateLoad(vectorElementType, fillValue);
539+
}
540+
526541
Value *slice = UndefValue::get(matTy);
527542
for (int i = 0; i < vectorSize; i++) {
528543
slice = builder.CreateInsertElement(slice, fillValue, i);
@@ -534,9 +549,9 @@ Value *JointMatrixFuncsResolutionPass::ResolveFill(CallInst *CI) {
534549

535550
Value *JointMatrixFuncsResolutionPass::ResolveWILength(CallInst *CI) {
536551
JointMatrixTypeDescription desc;
537-
ResolveType(CI->getArgOperand(0)->getType(), &desc);
552+
Type *matTy = ResolveType(CI->getArgOperand(0)->getType(), &desc);
538553

539-
const int sliceSize = getSliceSize(&desc);
554+
const int sliceSize = getSliceSize(&desc, matTy);
540555
Value *lenght = ConstantInt::get(CI->getType(), sliceSize, "matrix.slice.size");
541556

542557
CI->replaceAllUsesWith(lenght);
@@ -546,8 +561,8 @@ Value *JointMatrixFuncsResolutionPass::ResolveWILength(CallInst *CI) {
546561

547562
template <class BuilderT>
548563
static Value *createSliceExtract
549-
(BuilderT *builder, Value *matrix, Value *index, const JointMatrixTypeDescription *desc) {
550-
const int sliceSize = getSliceSize(desc);
564+
(BuilderT *builder, Value *matrix, Value *index, const JointMatrixTypeDescription *desc, Type *matTy) {
565+
const int sliceSize = getSliceSize(desc, matTy);
551566
const int vectorSize = getResolvedVectorSize(matrix->getType());
552567
/* Unpacking: */
553568
if (sliceSize > vectorSize) {
@@ -568,12 +583,12 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceInsert(CallInst *CI) {
568583
IGCLLVM::FixedVectorType *matTy = dyn_cast<IGCLLVM::FixedVectorType>(rawMatTy);
569584

570585
IRBuilder builder(CI);
571-
const int sliceSize = getSliceSize(&desc);
586+
const int sliceSize = getSliceSize(&desc, rawMatTy);
572587
const int vectorSize = getResolvedVectorSize(matTy);
573588

574589
Value *slice = nullptr;
575590
if (sliceSize > vectorSize) {
576-
Value *element = createSliceExtract(&builder, matrix, index, &desc);
591+
Value *element = createSliceExtract(&builder, matrix, index, &desc, rawMatTy);
577592
if (!isa<IntegerType>(element->getType())) {
578593
unsigned vecElemSize = matTy->getElementType()->getScalarSizeInBits();
579594
element = builder.CreateBitCast(element, Type::getIntNTy(builder.getContext(), vecElemSize));
@@ -603,6 +618,10 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceInsert(CallInst *CI) {
603618
component = builder.CreateShl(component, offset);
604619
component = builder.CreateOr(element, component);
605620
}
621+
622+
IntegerType *vectorElementType = dyn_cast<IntegerType>(getResolvedVectorElementType(rawMatTy));
623+
component = builder.CreateBitCast(component, vectorElementType);
624+
606625
slice = builder.CreateInsertElement(matrix, component, index);
607626

608627
InstsToErase.insert(CI);
@@ -617,9 +636,9 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceExtract(CallInst *CI) {
617636
Type *matTy = ResolveType(CI->getArgOperand(0)->getType(), &desc);
618637

619638
IRBuilder builder(CI);
620-
Value *element = createSliceExtract(&builder, matrix, index, &desc);
639+
Value *element = createSliceExtract(&builder, matrix, index, &desc, matTy);
621640
/* Unpacking: */
622-
const int sliceSize = getSliceSize(&desc);
641+
const int sliceSize = getSliceSize(&desc, matTy);
623642
const int vectorSize = getResolvedVectorSize(matTy);
624643
if (sliceSize > vectorSize) {
625644
index = builder.CreateTruncOrBitCast(index, element->getType());
@@ -634,6 +653,10 @@ Value *JointMatrixFuncsResolutionPass::ResolveSliceExtract(CallInst *CI) {
634653
element = builder.CreateBitCast(element, CI->getType());
635654
}
636655

656+
// We need the bitcast, especially for half, as the function call that is
657+
// being replaces has a half return type and the vectorElementType is i16
658+
element = builder.CreateBitCast(element, CI->getType());
659+
637660
CI->replaceAllUsesWith(element);
638661
InstsToErase.insert(CI);
639662
return element;

0 commit comments

Comments
 (0)