Skip to content

Commit 06040e4

Browse files
dmitryryintel and igcbot
authored and committed
Support vector of pointers bitcasts in GenXPromoteArray
1 parent 14647fd commit 06040e4

File tree

1 file changed

+72
-55
lines changed

1 file changed

+72
-55
lines changed

IGC/VectorCompiler/lib/GenXCodeGen/GenXPromoteArray.cpp

Lines changed: 72 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -485,10 +485,14 @@ static bool CheckAllocaUsesInternal(Instruction *I) {
485485
llvm::dyn_cast<llvm::BitCastInst>(*use_it)) {
486486
if (pBitCast->use_empty())
487487
continue;
488-
Type *baseT =
489-
GetBaseType(pBitCast->getType()->getPointerElementType(), nullptr);
490-
Type *sourceType = GetBaseType(
491-
pBitCast->getOperand(0)->getType()->getPointerElementType(), nullptr);
488+
Type *baseT = GetBaseType(
489+
pBitCast->getType()->getScalarType()->getPointerElementType(),
490+
nullptr);
491+
Type *sourceType = GetBaseType(pBitCast->getOperand(0)
492+
->getType()
493+
->getScalarType()
494+
->getPointerElementType(),
495+
nullptr);
492496
IGC_ASSERT(sourceType);
493497
// either the point-to-element-type is the same or
494498
// the point-to-element-type is the byte or a function pointer
@@ -648,10 +652,11 @@ void TransposeHelper::EraseDeadCode() {
648652

649653
void TransposeHelper::handleBCInst(BitCastInst &BC, GenericVectorIndex Idx) {
650654
m_toBeRemoved.push_back(&BC);
651-
Type *DstDerefTy =
652-
GetBaseType(BC.getType()->getPointerElementType(), nullptr);
655+
Type *DstDerefTy = GetBaseType(
656+
BC.getType()->getScalarType()->getPointerElementType(), nullptr);
653657
Type *SrcDerefTy = GetBaseType(
654-
BC.getOperand(0)->getType()->getPointerElementType(), nullptr);
658+
BC.getOperand(0)->getType()->getScalarType()->getPointerElementType(),
659+
nullptr);
655660
IGC_ASSERT(DstDerefTy);
656661
IGC_ASSERT(SrcDerefTy);
657662
// either the point-to-element-type is the same or
@@ -855,27 +860,52 @@ void TransposeHelper::handlePHINode(PHINode *pPhi, GenericVectorIndex Idx,
855860
handleAllocaSources(*pPhi, {NewPhi, Idx.ElementSizeInBits});
856861
}
857862

863+
// Loads vector and casts it if necessary.
864+
// \p CastTo describes vector element type to cast to.
865+
template <typename FolderT = ConstantFolder>
866+
Instruction *loadAndCastVector(AllocaInst &VecAlloca, Type &CastTo,
867+
IRBuilder<FolderT> &IRB) {
868+
auto *LoadVecAlloca = IRB.CreateLoad(&VecAlloca);
869+
auto *AllocatedElemTy = LoadVecAlloca->getType()->getScalarType();
870+
bool IsFuncPointer =
871+
CastTo.isPointerTy() && CastTo.getPointerElementType()->isFunctionTy();
872+
if (AllocatedElemTy == &CastTo || IsFuncPointer)
873+
return LoadVecAlloca;
874+
auto AllocatedWidth = cast<IGCLLVM::FixedVectorType>(LoadVecAlloca->getType())
875+
->getNumElements();
876+
IGC_ASSERT(AllocatedElemTy->getScalarSizeInBits() >=
877+
CastTo.getScalarSizeInBits());
878+
IGC_ASSERT(CastTo.getScalarSizeInBits());
879+
IGC_ASSERT((AllocatedElemTy->getScalarSizeInBits() %
880+
CastTo.getScalarSizeInBits()) == 0);
881+
auto CastedWidth = AllocatedWidth * (AllocatedElemTy->getScalarSizeInBits() /
882+
CastTo.getScalarSizeInBits());
883+
return cast<Instruction>(IRB.CreateBitCast(
884+
LoadVecAlloca, IGCLLVM::FixedVectorType::get(&CastTo, CastedWidth),
885+
"post.load.bc"));
886+
}
887+
888+
// Casts \p NewValue if its type doesn't correspond to allocated vector type,
889+
// then stores the value.
890+
template <typename FolderT = ConstantFolder>
891+
Instruction *castAndStoreVector(AllocaInst &VecAlloca, Value &NewValue,
892+
IRBuilder<FolderT> &IRB) {
893+
auto *CastedValue = &NewValue;
894+
if (VecAlloca.getAllocatedType() != NewValue.getType())
895+
CastedValue = IRB.CreateBitCast(&NewValue, VecAlloca.getAllocatedType(),
896+
NewValue.getName() + ".pre.store.bc");
897+
return IRB.CreateStore(CastedValue, &VecAlloca);
898+
}
899+
858900
void TransposeHelperPromote::handleLoadInst(LoadInst *pLoad,
859901
Value *pScalarizedIdx) {
860902
IGC_ASSERT(pLoad->isSimple());
861903
IRBuilder<> IRB(pLoad);
862-
Value *pLoadVecAlloca = IRB.CreateLoad(pVecAlloca);
863904
auto LdTy = pLoad->getType()->getScalarType();
864-
auto VETy = pLoadVecAlloca->getType()->getScalarType();
865-
auto ReadIn = pLoadVecAlloca;
905+
auto *ReadIn = loadAndCastVector(*pVecAlloca, *LdTy, IRB);
866906
bool IsFuncPointer = pLoad->getPointerOperandType()->isPointerTy() &&
867907
pLoad->getPointerOperandType()->getPointerElementType()->isPointerTy() &&
868908
pLoad->getPointerOperandType()->getPointerElementType()->getPointerElementType()->isFunctionTy();
869-
// do the type-casting if necessary
870-
if (VETy != LdTy && !IsFuncPointer) {
871-
auto VLen = cast<VectorType>(pLoadVecAlloca->getType())->getNumElements();
872-
IGC_ASSERT(VETy->getScalarSizeInBits() >= LdTy->getScalarSizeInBits());
873-
IGC_ASSERT(LdTy->getScalarSizeInBits());
874-
IGC_ASSERT((VETy->getScalarSizeInBits() % LdTy->getScalarSizeInBits()) == 0);
875-
VLen = VLen * (VETy->getScalarSizeInBits() / LdTy->getScalarSizeInBits());
876-
ReadIn =
877-
IRB.CreateBitCast(ReadIn, IGCLLVM::FixedVectorType::get(LdTy, VLen));
878-
}
879909
if (IsFuncPointer) {
880910
Region R(
881911
IGCLLVM::FixedVectorType::get(
@@ -891,7 +921,7 @@ void TransposeHelperPromote::handleLoadInst(LoadInst *pLoad,
891921
pScalarizedIdx = IRB.CreateZExtOrTrunc(pScalarizedIdx, Type::getInt16Ty(pLoad->getContext()));
892922
}
893923
R.Indirect = pScalarizedIdx;
894-
auto *Result = R.createRdRegion(pLoadVecAlloca, pLoad->getName(), pLoad,
924+
auto *Result = R.createRdRegion(ReadIn, pLoad->getName(), pLoad,
895925
pLoad->getDebugLoc(), true);
896926
if (!Result->getType()->isPointerTy()) {
897927
auto *BC =
@@ -931,25 +961,13 @@ void TransposeHelperPromote::handleStoreInst(StoreInst *pStore,
931961
IGC_ASSERT(pStore->isSimple());
932962
IRBuilder<> IRB(pStore);
933963
llvm::Value *pStoreVal = pStore->getValueOperand();
934-
llvm::Value *pLoadVecAlloca = IRB.CreateLoad(pVecAlloca);
935-
llvm::Value *WriteOut = pLoadVecAlloca;
936964
auto *StTy = pStoreVal->getType()->getScalarType();
937-
auto *VETy = pLoadVecAlloca->getType()->getScalarType();
938-
// do the type-casting if necessary
965+
Value *WriteOut = loadAndCastVector(*pVecAlloca, *StTy, IRB);
939966

940967
bool IsFuncPointerStore =
941968
(isFuncPointerVec(pStoreVal) ||
942969
(pStoreVal->getType()->isPointerTy() &&
943970
pStoreVal->getType()->getPointerElementType()->isFunctionTy()));
944-
if (VETy != StTy && !IsFuncPointerStore) {
945-
auto VLen = cast<VectorType>(pLoadVecAlloca->getType())->getNumElements();
946-
IGC_ASSERT(VETy->getScalarSizeInBits() >= StTy->getScalarSizeInBits());
947-
IGC_ASSERT(StTy->getScalarSizeInBits());
948-
IGC_ASSERT((VETy->getScalarSizeInBits() % StTy->getScalarSizeInBits()) == 0);
949-
VLen = VLen * (VETy->getScalarSizeInBits() / StTy->getScalarSizeInBits());
950-
WriteOut =
951-
IRB.CreateBitCast(WriteOut, IGCLLVM::FixedVectorType::get(StTy, VLen));
952-
}
953971
if (IsFuncPointerStore) {
954972
auto *NewStoreVal = pStoreVal;
955973
IGC_ASSERT(cast<VectorType>(pVecAlloca->getType()->getPointerElementType())->getElementType()->isIntegerTy(64));
@@ -1000,10 +1018,7 @@ void TransposeHelperPromote::handleStoreInst(StoreInst *pStore,
10001018
WriteOut =
10011019
IRB.CreateInsertElement(WriteOut, pStoreVal, ScalarizedIdx.Index);
10021020
}
1003-
// cast the vector type back if necessary
1004-
if (VETy != StTy)
1005-
WriteOut = IRB.CreateBitCast(WriteOut, pLoadVecAlloca->getType());
1006-
IRB.CreateStore(WriteOut, pVecAlloca);
1021+
castAndStoreVector(*pVecAlloca, *WriteOut, IRB);
10071022
pStore->eraseFromParent();
10081023
}
10091024

@@ -1152,9 +1167,9 @@ void TransposeHelperPromote::handleLLVMGather(IntrinsicInst *pInst,
11521167
Value *pScalarizedIdx) {
11531168
IRBuilder<> IRB(pInst);
11541169
IGC_ASSERT(pInst->getType()->isVectorTy());
1155-
Value *pLoadVecAlloca = IRB.CreateLoad(pVecAlloca);
11561170
auto N = cast<VectorType>(pInst->getType())->getNumElements();
11571171
auto ElemType = cast<VectorType>(pInst->getType())->getElementType();
1172+
Value *LoadVecAlloca = loadAndCastVector(*pVecAlloca, *ElemType, IRB);
11581173

11591174
// A vector load
11601175
// %v = <2 x float> gather %pred, %vector_of_ptr, %old_value
@@ -1192,8 +1207,8 @@ void TransposeHelperPromote::handleLLVMGather(IntrinsicInst *pInst,
11921207
R.VStride = 0;
11931208
}
11941209
Value *Result =
1195-
R.createRdRegion(pLoadVecAlloca, pInst->getName(), pInst /*InsertBefore*/,
1196-
pInst->getDebugLoc(), true /*AllowScalar*/);
1210+
R.createRdRegion(LoadVecAlloca, pInst->getName(), pInst /*InsertBefore*/,
1211+
pInst->getDebugLoc(), true /*AllowScalar*/);
11971212

11981213
// if old-value is not undefined and predicate is not all-one,
11991214
// create a select auto OldVal = pInst->getArgOperand(3);
@@ -1216,20 +1231,22 @@ void TransposeHelperPromote::handleLLVMScatter(llvm::IntrinsicInst *pInst,
12161231
llvm::Value *pScalarizedIdx) {
12171232
// Add Store instruction to remove list
12181233
IRBuilder<> IRB(pInst);
1219-
llvm::Value *pStoreVal = pInst->getArgOperand(3);
1220-
llvm::Value *pLoadVecAlloca = IRB.CreateLoad(pVecAlloca);
1221-
IGC_ASSERT(pStoreVal->getType()->isVectorTy());
1222-
auto N = cast<VectorType>(pStoreVal->getType())->getNumElements();
1223-
auto ElemType = cast<VectorType>(pStoreVal->getType())->getElementType();
1234+
Value *StoredValue = pInst->getArgOperand(0);
1235+
IGC_ASSERT(StoredValue->getType()->isVectorTy());
1236+
auto N =
1237+
cast<IGCLLVM::FixedVectorType>(StoredValue->getType())->getNumElements();
1238+
auto *ElemType =
1239+
cast<IGCLLVM::FixedVectorType>(StoredValue->getType())->getElementType();
1240+
Value *LoadVecAlloca = loadAndCastVector(*pVecAlloca, *ElemType, IRB);
12241241
// A vector scatter
1225-
// scatter %pred, %ptr, %offset, %newvalue
1242+
// scatter %newvalue, %ptr, i32 align, %pred
12261243
// becomes
1227-
// %w = load <32 x float> *%ptr1
1228-
// %w1 = <32 x float> wrregion %w, newvalue, %offset, %pred
1229-
// store <32 x float> %w1, <32 x float>* %ptr1
1244+
// %w = load %vec.alloca
1245+
// %w1 = wrregion %w, %newvalue, %indexed.ptr, %pred
1246+
// store %w1, %vec.alloca
12301247

12311248
// Create the new wrregion
1232-
Region R(pStoreVal);
1249+
Region R{StoredValue};
12331250
int64_t v0 = 0;
12341251
int64_t diff = 0;
12351252
// pScalarizedIdx is an indice of element, so
@@ -1259,12 +1276,12 @@ void TransposeHelperPromote::handleLLVMScatter(llvm::IntrinsicInst *pInst,
12591276
R.Stride = 0;
12601277
R.VStride = 0;
12611278
}
1262-
R.Mask = pInst->getArgOperand(0);
1279+
R.Mask = pInst->getArgOperand(3);
12631280
auto NewInst = cast<Instruction>(
1264-
R.createWrRegion(pLoadVecAlloca, pStoreVal, pInst->getName(),
1265-
pInst /*InsertBefore*/, pInst->getDebugLoc()));
1281+
R.createWrRegion(LoadVecAlloca, StoredValue, pInst->getName(),
1282+
pInst /*InsertBefore*/, pInst->getDebugLoc()));
12661283

1267-
IRB.CreateStore(NewInst, pVecAlloca);
1284+
castAndStoreVector(*pVecAlloca, *NewInst, IRB);
12681285
pInst->eraseFromParent();
12691286
}
12701287

0 commit comments

Comments
 (0)