Skip to content

Commit 04a66cb

Browse files
zuban32igcbot
authored andcommitted
Switch TPM to SVM entirely
1 parent 1fe1f14 commit 04a66cb

File tree

3 files changed

+69
-42
lines changed

3 files changed

+69
-42
lines changed

IGC/VectorCompiler/include/vc/GenXCodeGen/GenXInternalMetadata.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ namespace FunctionMD {
3838
inline constexpr const char GenXKernelInternal[] = "genx.kernel.internal";
3939
}
4040

41+
namespace InstMD {
42+
inline constexpr const char SVMBlockType[] = "SVMBlockType";
43+
}
44+
45+
namespace ModuleMD {
46+
inline constexpr const char UseSVMStack[] = "genx.useGlobalMem";
47+
}
48+
4149
namespace internal {
4250

4351
namespace KernelMDOp {

IGC/VectorCompiler/lib/GenXCodeGen/GenXCisaBuilder.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3633,7 +3633,10 @@ void GenXKernelBuilder::buildIntrinsic(CallInst *CI, unsigned IntrinID,
36333633
Value *V = CI;
36343634
if (!AI.isRet())
36353635
V = CI->getArgOperand(AI.getArgIdx());
3636-
unsigned ElBytes = getResultedTypeSize(V->getType()->getScalarType(), DL);
3636+
auto *EltType = V->getType()->getScalarType();
3637+
if (auto *MDType = CI->getMetadata(InstMD::SVMBlockType))
3638+
EltType = cast<ValueAsMetadata>(MDType->getOperand(0).get())->getType();
3639+
unsigned ElBytes = getResultedTypeSize(EltType, DL);
36373640
switch (ElBytes) {
36383641
// For N = 2 byte data type, use block size 1 and block count 2.
36393642
// Otherwise, use block size N and block count 1.

IGC/VectorCompiler/lib/GenXCodeGen/GenXThreadPrivateMemory.cpp

Lines changed: 57 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ IN THE SOFTWARE.
3535
#include "GenXTargetMachine.h"
3636
#include "GenXUtil.h"
3737
#include "GenXVisa.h"
38+
#include "vc/GenXCodeGen/GenXInternalMetadata.h"
3839

3940
#include "Probe/Assertion.h"
4041
#include "llvmWrapper/IR/DerivedTypes.h"
@@ -208,6 +209,7 @@ std::pair<Value *, unsigned>
208209
GenXThreadPrivateMemory::NormalizeVector(Value *From, Type *To,
209210
Instruction *Inst) {
210211
Type *I32Ty = Type::getInt32Ty(Inst->getContext());
212+
Type *I64Ty = Type::getInt64Ty(Inst->getContext());
211213
Value *Res = From;
212214
Type *FromTy = From->getType();
213215
IGC_ASSERT(isa<VectorType>(FromTy));
@@ -234,22 +236,22 @@ GenXThreadPrivateMemory::NormalizeVector(Value *From, Type *To,
234236
To = IGCLLVM::FixedVectorType::get(I32Ty, NumElts);
235237
EltSz = I32Ty->getPrimitiveSizeInBits() / genx::ByteBits;
236238
Res = CastInst::Create(Instruction::BitCast, Res, To, "", Inst);
237-
} else if (cast<VectorType>(To)->getElementType()->getPrimitiveSizeInBits() <
238-
genx::DWordBits
239-
// this is required for correct generation of svm.gather/scatter
240-
// of data of type which size is < i32 because these intrinsics
241-
// infer their block size from the type of the data they handle
242-
&& !m_useGlobalMem) {
239+
} else if (m_DL->getTypeSizeInBits(cast<VectorType>(To)->getElementType()) <
240+
genx::DWordBits) {
243241
To = IGCLLVM::FixedVectorType::get(I32Ty, NumElts);
244-
245-
Res = CastInst::Create(Instruction::ZExt, From, To, "", Inst);
246-
} else if (cast<VectorType>(To)->getElementType()->getPrimitiveSizeInBits() ==
247-
genx::QWordBits) {
242+
Res = CastInst::CreateZExtOrBitCast(From, To, "", Inst);
243+
} else if (!m_useGlobalMem &&
244+
m_DL->getTypeSizeInBits(cast<VectorType>(To)->getElementType()) ==
245+
genx::QWordBits) {
246+
if (From->getType()->getScalarType()->isPointerTy()) {
247+
auto *NewType = IGCLLVM::FixedVectorType::get(I64Ty, NumElts);
248+
From = CastInst::Create(CastInst::PtrToInt, From, NewType, "", Inst);
249+
}
248250
NumElts *= 2;
249251
EltSz = I32Ty->getPrimitiveSizeInBits() / genx::ByteBits;
250252
To = IGCLLVM::FixedVectorType::get(I32Ty, NumElts);
251253

252-
Res = CastInst::Create(Instruction::BitCast, From, To, "", Inst);
254+
Res = CastInst::CreateBitOrPointerCast(From, To, "", Inst);
253255
}
254256

255257
return std::make_pair(Res, EltSz);
@@ -258,6 +260,8 @@ GenXThreadPrivateMemory::NormalizeVector(Value *From, Type *To,
258260
Instruction *
259261
GenXThreadPrivateMemory::RestoreVectorAfterNormalization(Instruction *From,
260262
Type *To) {
263+
if (From->getType() == To)
264+
return From;
261265
Instruction *Restored = From;
262266
unsigned EltSz = m_DL->getTypeSizeInBits(To->getScalarType());
263267
IGC_ASSERT(EltSz > 0);
@@ -519,35 +523,19 @@ bool GenXThreadPrivateMemory::replaceLoad(LoadInst *LdI) {
519523
LdTy = IGCLLVM::FixedVectorType::get(LdTy, 1);
520524

521525
unsigned NumEltsToLoad = cast<VectorType>(LdTy)->getNumElements();
522-
unsigned LdEltTySz = m_DL->getTypeSizeInBits(LdEltTy);
523-
if (!(m_useGlobalMem && LdEltTy->isIntegerTy(64)) &&
524-
LdEltTySz == genx::QWordBits)
525-
NumEltsToLoad *= 2;
526+
unsigned ValueEltSz = m_DL->getTypeSizeInBits(LdEltTy) / genx::ByteBits;
526527

527528
Value *PredVal = ConstantInt::get(Type::getInt1Ty(*m_ctx), 1);
528529
Value *Pred = Builder.CreateVectorSplat(NumEltsToLoad, PredVal);
529530

530531
Type *I32Ty = Type::getInt32Ty(*m_ctx);
531532
Type *I64Ty = Type::getInt64Ty(*m_ctx);
532-
Type *TyToLoad = (m_useGlobalMem && LdEltTy->isIntegerTy(64)) ? I64Ty : I32Ty;
533-
if (LdEltTy->isFloatTy())
534-
TyToLoad = LdEltTy;
535-
Type *RealTyToLoad = LdEltTy;
536-
if (!(m_useGlobalMem && LdEltTy->isIntegerTy(64)) &&
537-
m_DL->getTypeSizeInBits(RealTyToLoad) == genx::QWordBits)
538-
RealTyToLoad = I32Ty;
539-
unsigned RealTyToLoadSz =
540-
m_DL->getTypeSizeInBits(RealTyToLoad) / genx::ByteBits;
541-
// we don't want to use improper block sizes for loads of i8/i16
542-
// to make sure we comply with alignment rules for gathers
543-
bool NoExtToDword =
544-
m_useGlobalMem &&
545-
!(LdI->getType()->isAggregateType() || LdI->getType()->isVectorTy()) &&
546-
m_DL->getTypeSizeInBits(LdI->getType()) < genx::DWordBits;
547-
if (NoExtToDword)
548-
TyToLoad = LdI->getType();
549533
Value *OldValOfTheDataRead =
550-
Builder.CreateVectorSplat(NumEltsToLoad, UndefValue::get(TyToLoad));
534+
Builder.CreateVectorSplat(NumEltsToLoad, UndefValue::get(LdEltTy));
535+
std::tie(OldValOfTheDataRead, ValueEltSz) =
536+
NormalizeVector(OldValOfTheDataRead, LdTy, LdI);
537+
NumEltsToLoad =
538+
cast<VectorType>(OldValOfTheDataRead->getType())->getNumElements();
551539

552540
Value *PointerOp = LdI->getPointerOperand();
553541
Value *Offset = lookForPtrReplacement(PointerOp);
@@ -557,10 +545,13 @@ bool GenXThreadPrivateMemory::replaceLoad(LoadInst *LdI) {
557545
? llvm::GenXIntrinsic::genx_svm_gather
558546
: llvm::GenXIntrinsic::genx_gather_scaled;
559547

560-
Value *EltsOffset = FormEltsOffsetVector(NumEltsToLoad, RealTyToLoadSz, LdI);
548+
Value *EltsOffset = FormEltsOffsetVector(NumEltsToLoad, ValueEltSz, LdI);
561549

562-
unsigned SrcSize = genx::log2(RealTyToLoadSz);
563-
Value *logNumBlocks = ConstantInt::get(I32Ty, m_useGlobalMem ? 0 : SrcSize);
550+
unsigned NumBlocks = m_DL->getTypeSizeInBits(LdEltTy) / genx::ByteBits;
551+
// This logic is aligned with the on in CisaBuilder and GenXLowering
552+
// The reason behind check for == 2 is that svm intrinsics don't support
553+
// BlockSize of 2, so for ops with i16s we have to use BlockSize == 1 and NumBlocks == 2
554+
Value *logNumBlocks = ConstantInt::get(I32Ty, genx::log2(NumBlocks == 2 ? NumBlocks : 1));
564555
Value *Scale = ConstantInt::get(Type::getInt16Ty(*m_ctx), 0);
565556
Value *Surface = ConstantInt::get(I32Ty,
566557
visa::getReservedSurfaceIndex(m_stack));
@@ -601,6 +592,10 @@ bool GenXThreadPrivateMemory::replaceLoad(LoadInst *LdI) {
601592
ProperGather = LdVal;
602593
}
603594

595+
Gather->setMetadata(InstMD::SVMBlockType,
596+
MDNode::get(*m_ctx, llvm::ValueAsMetadata::get(
597+
UndefValue::get(LdEltTy))));
598+
604599
LLVM_DEBUG(dbgs() << *Gather << "\n");
605600
LdI->replaceAllUsesWith(ProperGather);
606601
LdI->eraseFromParent();
@@ -647,7 +642,9 @@ bool GenXThreadPrivateMemory::replaceStore(StoreInst *StI) {
647642
{Pred->getType(),
648643
(m_useGlobalMem ? Offset : EltsOffset)->getType(),
649644
ValueOp->getType()});
650-
Value *logNumBlocks = ConstantInt::get(I32Ty, m_useGlobalMem ? 0 : genx::log2(ValueEltSz));
645+
unsigned NumBlocks = m_DL->getTypeSizeInBits(ValueOpTy->getScalarType()) / genx::ByteBits;
646+
// see the comment in replaceLoad above
647+
Value *logNumBlocks = ConstantInt::get(I32Ty, genx::log2(NumBlocks == 2 ? NumBlocks : 1));
651648
Value *Scale = ConstantInt::get(Type::getInt16Ty(*m_ctx), 0);
652649
Value *Surface = ConstantInt::get(I32Ty,
653650
visa::getReservedSurfaceIndex(m_stack));
@@ -662,6 +659,11 @@ bool GenXThreadPrivateMemory::replaceStore(StoreInst *StI) {
662659
Scatter->insertAfter(StI);
663660
StI->eraseFromParent();
664661

662+
Scatter->setMetadata(
663+
InstMD::SVMBlockType,
664+
MDNode::get(*m_ctx, llvm::ValueAsMetadata::get(
665+
UndefValue::get(ValueOpTy->getScalarType()))));
666+
665667
LLVM_DEBUG(dbgs() << *Scatter << "\n");
666668
m_scatter.push_back(Scatter);
667669

@@ -1094,6 +1096,12 @@ void SplitScatter(CallInst *CI) {
10941096
}
10951097
IGC_ASSERT(FirstScatter && SecondScatter);
10961098

1099+
auto *MD = CI->getMetadata(InstMD::SVMBlockType);
1100+
if (MD) {
1101+
FirstScatter->setMetadata(InstMD::SVMBlockType, MD);
1102+
SecondScatter->setMetadata(InstMD::SVMBlockType, MD);
1103+
}
1104+
10971105
FirstScatter->insertAfter(CI);
10981106
SecondScatter->insertAfter(FirstScatter);
10991107

@@ -1163,6 +1171,12 @@ void SplitGather(CallInst *CI) {
11631171
}
11641172
IGC_ASSERT(FirstGather && SecondGather);
11651173

1174+
auto *MD = CI->getMetadata(InstMD::SVMBlockType);
1175+
if (MD) {
1176+
FirstGather->setMetadata(InstMD::SVMBlockType, MD);
1177+
SecondGather->setMetadata(InstMD::SVMBlockType, MD);
1178+
}
1179+
11661180
FirstGather->insertAfter(CI);
11671181
SecondGather->insertAfter(FirstGather);
11681182

@@ -1280,14 +1294,16 @@ bool GenXThreadPrivateMemory::runOnModule(Module &M) {
12801294
m_ST = &getAnalysis<TargetPassConfig>()
12811295
.getTM<GenXTargetMachine>()
12821296
.getGenXSubtarget();
1297+
if (!m_ST->isOCLRuntime())
1298+
m_useGlobalMem = false;
12831299
for (auto &F : M)
12841300
visit(F);
1285-
if (!m_useGlobalMem &&
1286-
std::find_if(m_alloca.begin(), m_alloca.end(), SVMChecker()) !=
1287-
m_alloca.end()) {
1301+
if (m_useGlobalMem ||
1302+
(m_ST->isOCLRuntime() && std::find_if(m_alloca.begin(), m_alloca.end(),
1303+
SVMChecker()) != m_alloca.end())) {
12881304
LLVM_DEBUG(dbgs() << "Switching TPM to SVM\n");
12891305
// TODO: move the name string to vc-intrinsics *MD::useGlobalMem
1290-
M.addModuleFlag(Module::ModFlagBehavior::Error, "genx.useGlobalMem", 1);
1306+
M.addModuleFlag(Module::ModFlagBehavior::Error, ModuleMD::UseSVMStack, 1);
12911307
m_useGlobalMem = true;
12921308
}
12931309
bool Result = false;

0 commit comments

Comments
 (0)