Skip to content

Commit adaa145

Browse files
zuban32igcbot
authored andcommitted
Add support for SVM instructions to GenXPromoteArray
1 parent 04a66cb commit adaa145

File tree

1 file changed

+151
-10
lines changed

1 file changed

+151
-10
lines changed

IGC/VectorCompiler/lib/GenXCodeGen/GenXPromoteArray.cpp

Lines changed: 151 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ class TransposeHelper {
116116
void handleAllocaSources(Instruction &Inst, GenericVectorIndex Idx);
117117
void handleGEPInst(GetElementPtrInst *pGEP, GenericVectorIndex Idx);
118118
void handleBCInst(BitCastInst &BC, GenericVectorIndex Idx);
119+
void handlePTIInst(PtrToIntInst &BC, GenericVectorIndex Idx);
119120
void handlePHINode(PHINode *pPhi, GenericVectorIndex pScalarizedIdx,
120121
BasicBlock *pIncomingBB);
121122
virtual void handleLoadInst(llvm::LoadInst *pLoad,
@@ -126,6 +127,10 @@ class TransposeHelper {
126127
llvm::Value *pScalarizedIdx) = 0;
127128
virtual void handlePrivateScatter(llvm::IntrinsicInst *pInst,
128129
llvm::Value *pScalarizedIdx) = 0;
130+
virtual void handleSVMGather(llvm::IntrinsicInst *pInst,
131+
llvm::Value *pScalarizedIdx) = 0;
132+
virtual void handleSVMScatter(llvm::IntrinsicInst *pInst,
133+
llvm::Value *pScalarizedIdx) = 0;
129134
virtual void handleLLVMGather(llvm::IntrinsicInst *pInst,
130135
llvm::Value *pScalarizedIdx) = 0;
131136
virtual void handleLLVMScatter(llvm::IntrinsicInst *pInst,
@@ -221,12 +226,14 @@ namespace {
221226

222227
class TransposeHelperPromote : public TransposeHelper {
223228
public:
224-
void handleLoadInst(LoadInst *pLoad, Value *pScalarizedIdx);
225-
void handleStoreInst(StoreInst *pStore, GenericVectorIndex pScalarizedIdx);
226-
void handlePrivateGather(IntrinsicInst *pInst, Value *pScalarizedIdx);
227-
void handlePrivateScatter(IntrinsicInst *pInst, Value *pScalarizedIdx);
228-
void handleLLVMGather(IntrinsicInst *pInst, Value *pScalarizedIdx);
229-
void handleLLVMScatter(IntrinsicInst *pInst, Value *pScalarizedIdx);
229+
void handleLoadInst(LoadInst *pLoad, Value *pScalarizedIdx) override;
230+
void handleStoreInst(StoreInst *pStore, GenericVectorIndex pScalarizedIdx) override;
231+
void handlePrivateGather(IntrinsicInst *pInst, Value *pScalarizedIdx) override;
232+
void handlePrivateScatter(IntrinsicInst *pInst, Value *pScalarizedIdx) override;
233+
void handleSVMGather(IntrinsicInst *pInst, Value *pScalarizedIdx) override;
234+
void handleSVMScatter(IntrinsicInst *pInst, Value *pScalarizedIdx) override;
235+
void handleLLVMGather(IntrinsicInst *pInst, Value *pScalarizedIdx) override;
236+
void handleLLVMScatter(IntrinsicInst *pInst, Value *pScalarizedIdx) override;
230237

231238
AllocaInst *pVecAlloca;
232239

@@ -406,6 +413,52 @@ static Type *GetBaseType(Type *pType, Type *pBaseType) {
406413
return pType;
407414
}
408415

416+
static bool CheckPtrToIntCandidate(PtrToIntInst *PTI) {
417+
// here we handle only the most common pattern for SVM instructions
418+
// ptrtoint->insertelem->shuffle->arith_op->svm_gather/scatter
419+
// others are possible, but not handled yet
420+
if (!PTI->hasOneUse())
421+
return false;
422+
auto *Insert = dyn_cast<InsertElementInst>(PTI->user_back());
423+
if (!Insert)
424+
return false;
425+
if (!Insert->hasOneUse())
426+
return false;
427+
auto *Shuffle = dyn_cast<ShuffleVectorInst>(Insert->user_back());
428+
if (!Shuffle)
429+
return false;
430+
if (!Shuffle->hasOneUse())
431+
return false;
432+
auto *BinOp = dyn_cast<BinaryOperator>(Shuffle->user_back());
433+
if (!BinOp)
434+
return false;
435+
if (BinOp->user_empty())
436+
return false;
437+
for (auto *MemOp : BinOp->users()) {
438+
if (!isa<CallInst>(MemOp))
439+
return false;
440+
auto IID = GenXIntrinsic::getAnyIntrinsicID(MemOp);
441+
if (IID != GenXIntrinsic::genx_svm_gather &&
442+
IID != GenXIntrinsic::genx_svm_scatter)
443+
return false;
444+
// for now skip insts w/ blockSize > 1
445+
// or weird things like <16 x i32> %res = svm.gather(<8 x i64> offsets, ...)
446+
auto *Pred = MemOp->getOperand(0);
447+
auto *NumBlocks = MemOp->getOperand(1);
448+
auto *Input = MemOp->getOperand(3);
449+
IGC_ASSERT(isa<ConstantInt>(NumBlocks));
450+
if (cast<ConstantInt>(NumBlocks)->getZExtValue() ||
451+
cast<VectorType>(Input->getType())->getNumElements() >
452+
cast<VectorType>(Pred->getType())
453+
->getNumElements() ||
454+
(isa<VectorType>(MemOp->getType()) &&
455+
cast<VectorType>(Pred->getType())->getNumElements() <
456+
cast<VectorType>(MemOp->getType())->getNumElements()))
457+
return false;
458+
}
459+
return true;
460+
}
461+
409462
static bool CheckAllocaUsesInternal(Instruction *I) {
410463
for (Value::user_iterator use_it = I->user_begin(), use_e = I->user_end();
411464
use_it != use_e; ++use_it) {
@@ -440,7 +493,7 @@ static bool CheckAllocaUsesInternal(Instruction *I) {
440493
Type *sourceType = GetBaseType(
441494
pBitCast->getOperand(0)->getType()->getPointerElementType(), nullptr);
442495
IGC_ASSERT(sourceType);
443-
// either the point-to-element-type is the same or
496+
// either the point-to-element-type is the same or
444497
// the point-to-element-type is the byte or a function pointer
445498
if (baseT != nullptr &&
446499
(baseT->getScalarSizeInBits() == 8 ||
@@ -452,6 +505,8 @@ static bool CheckAllocaUsesInternal(Instruction *I) {
452505
}
453506
// Not a candidate.
454507
return false;
508+
} else if (PtrToIntInst *PTI = dyn_cast<PtrToIntInst>(*use_it)) {
509+
return CheckPtrToIntCandidate(PTI);
455510
} else if (IntrinsicInst *intr = dyn_cast<IntrinsicInst>(*use_it)) {
456511
auto IID = GenXIntrinsic::getAnyIntrinsicID(intr);
457512
if (IID == llvm::Intrinsic::lifetime_start ||
@@ -575,9 +630,6 @@ void TransformPrivMem::handleAllocaInst(llvm::AllocaInst *pAlloca) {
575630
llvm::AllocaInst *pVecAlloca = createVectorForAlloca(pAlloca, pBaseType);
576631
if (!pVecAlloca)
577632
return;
578-
// skip processing of allocas that are already fine
579-
if (pVecAlloca->getType() == pAlloca->getType())
580-
return;
581633

582634
IRBuilder<> IRB(pVecAlloca);
583635
GenericVectorIndex StartIdx{
@@ -626,6 +678,24 @@ void TransposeHelper::handleBCInst(BitCastInst &BC, GenericVectorIndex Idx) {
626678
BC, {NewIdx, static_cast<int>(DstDerefTy->getScalarSizeInBits())});
627679
}
628680

681+
void TransposeHelper::handlePTIInst(PtrToIntInst &PTI, GenericVectorIndex Idx) {
682+
IGC_ASSERT(PTI.hasOneUse() && isa<InsertElementInst>(PTI.user_back()));
683+
IRBuilder<> IRB(&PTI);
684+
auto *Insert = PTI.user_back();
685+
auto *CastedIdx = IRB.CreateZExt(Idx.Index, PTI.getType(), PTI.getName());
686+
auto *Mul = IRB.CreateMul(
687+
CastedIdx,
688+
ConstantInt::get(CastedIdx->getType(), Idx.getElementSizeInBytes()), "");
689+
PTI.replaceAllUsesWith(Mul);
690+
PTI.eraseFromParent();
691+
IGC_ASSERT(Insert->hasOneUse() &&
692+
isa<ShuffleVectorInst>(Insert->user_back()));
693+
auto *Shuffle = Insert->user_back();
694+
IGC_ASSERT(Shuffle->hasOneUse() && isa<BinaryOperator>(Shuffle->user_back()));
695+
handleAllocaSources(*(Shuffle->user_back()),
696+
{Shuffle->user_back(), Idx.ElementSizeInBits});
697+
}
698+
629699
void TransposeHelper::handleAllocaSources(Instruction &Inst,
630700
GenericVectorIndex Idx) {
631701
SmallVector<Value *, 10> Users{Inst.user_begin(), Inst.user_end()};
@@ -635,6 +705,8 @@ void TransposeHelper::handleAllocaSources(Instruction &Inst,
635705
handleGEPInst(pGEP, Idx);
636706
} else if (BitCastInst *BC = dyn_cast<BitCastInst>(User)) {
637707
handleBCInst(*BC, Idx);
708+
} else if (PtrToIntInst *PTI = dyn_cast<PtrToIntInst>(User)) {
709+
handlePTIInst(*PTI, Idx);
638710
} else if (StoreInst *pStore = llvm::dyn_cast<StoreInst>(User)) {
639711
handleStoreInst(pStore, Idx);
640712
} else if (LoadInst *pLoad = llvm::dyn_cast<LoadInst>(User)) {
@@ -650,6 +722,10 @@ void TransposeHelper::handleAllocaSources(Instruction &Inst,
650722
handlePrivateGather(IntrInst, Idx.Index);
651723
else if (IID == GenXIntrinsic::genx_scatter_private)
652724
handlePrivateScatter(IntrInst, Idx.Index);
725+
else if (IID == GenXIntrinsic::genx_svm_gather)
726+
handleSVMGather(IntrInst, Idx.Index);
727+
else if (IID == GenXIntrinsic::genx_svm_scatter)
728+
handleSVMScatter(IntrInst, Idx.Index);
653729
else if (IntrInst->getIntrinsicID() == llvm::Intrinsic::masked_gather)
654730
handleLLVMGather(IntrInst, Idx.Index);
655731
else if (IntrInst->getIntrinsicID() == llvm::Intrinsic::masked_scatter)
@@ -1201,4 +1277,69 @@ void TransposeHelperPromote::handleLLVMScatter(llvm::IntrinsicInst *pInst,
12011277
pInst->eraseFromParent();
12021278
}
12031279

1280+
void TransposeHelperPromote::handleSVMGather(IntrinsicInst *pInst,
1281+
Value *pScalarizedIdx) {
1282+
// %v = svm_gather %pred, %ptr + %offset
1283+
// is turned into
1284+
// %v0 = load <32 x float> *%ptr1
1285+
// %v1 = <32 x float> rdregion %v0, %offset, %pred
1286+
1287+
// here we rely on offset being previously generated
1288+
// by e.g. ISPC
1289+
1290+
// part of this is taken from handleLLVMGather above
1291+
IRBuilder<> IRB(pInst);
1292+
llvm::Value *pLoadVecAlloca = IRB.CreateLoad(pVecAlloca);
1293+
Region R(pInst);
1294+
R.Mask = pInst->getArgOperand(0);
1295+
R.Indirect = IRB.CreateTrunc(
1296+
pScalarizedIdx,
1297+
IGCLLVM::FixedVectorType::get(
1298+
IntegerType::getInt16Ty(pInst->getContext()),
1299+
cast<VectorType>(pScalarizedIdx->getType())->getNumElements()),
1300+
"");
1301+
R.Width = 1;
1302+
R.Stride = 0;
1303+
R.VStride = 0;
1304+
Value *Result =
1305+
R.createRdRegion(pLoadVecAlloca, pInst->getName(), pInst /*InsertBefore*/,
1306+
pInst->getDebugLoc(), true /*AllowScalar*/);
1307+
// if old-value is not undefined and predicate is not all-one,
1308+
// create a select
1309+
auto PredVal = pInst->getArgOperand(2);
1310+
bool PredAllOne = false;
1311+
if (auto C = dyn_cast<ConstantVector>(PredVal)) {
1312+
if (auto B = C->getSplatValue())
1313+
PredAllOne = B->isOneValue();
1314+
}
1315+
auto OldVal = pInst->getArgOperand(3);
1316+
if (!PredAllOne && !isa<UndefValue>(OldVal))
1317+
Result = IRB.CreateSelect(PredVal, Result, OldVal);
1318+
1319+
pInst->replaceAllUsesWith(Result);
1320+
pInst->eraseFromParent();
1321+
}
1322+
1323+
void TransposeHelperPromote::handleSVMScatter(IntrinsicInst *pInst,
1324+
Value *pScalarizedIdx) {
1325+
IRBuilder<> IRB(pInst);
1326+
Value *pStoreVal = pInst->getArgOperand(3);
1327+
Value *pLoadVecAlloca = IRB.CreateLoad(pVecAlloca);
1328+
Region R(pStoreVal);
1329+
R.Mask = pInst->getArgOperand(0);
1330+
R.Indirect = IRB.CreateTrunc(
1331+
pScalarizedIdx,
1332+
IGCLLVM::FixedVectorType::get(
1333+
IntegerType::getInt16Ty(pInst->getContext()),
1334+
cast<VectorType>(pScalarizedIdx->getType())->getNumElements()),
1335+
"");
1336+
R.Width = 1;
1337+
R.Stride = 0;
1338+
R.VStride = 0;
1339+
auto NewInst = R.createWrRegion(pLoadVecAlloca, pStoreVal, pInst->getName(),
1340+
pInst, pInst->getDebugLoc());
1341+
IRB.CreateStore(NewInst, pVecAlloca);
1342+
pInst->eraseFromParent();
1343+
}
1344+
12041345
} // namespace

0 commit comments

Comments
 (0)