Skip to content

Commit 1c502bc

Browse files
committed
[AMDGPU] Add IR LiveReg type-based optimization
Change-Id: I90c3d13d69425d6a2b9dbbeadb3a414983559667
1 parent a7b5122 commit 1c502bc

File tree

3 files changed

+428
-60
lines changed

3 files changed

+428
-60
lines changed

llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class AMDGPUCodeGenPrepareImpl
106106
Module *Mod = nullptr;
107107
const DataLayout *DL = nullptr;
108108
bool HasUnsafeFPMath = false;
109+
bool UsesGlobalISel = false;
109110
bool HasFP32DenormalFlush = false;
110111
bool FlowChanged = false;
111112
mutable Function *SqrtF32 = nullptr;
@@ -341,6 +342,85 @@ class AMDGPUCodeGenPrepare : public FunctionPass {
341342
StringRef getPassName() const override { return "AMDGPU IR optimizations"; }
342343
};
343344

345+
class LiveRegConversion {
346+
private:
347+
// The instruction which defined the original virtual register used across
348+
// blocks
349+
Instruction *LiveRegDef;
350+
// The original type
351+
Type *OriginalType;
352+
// The desired type
353+
Type *NewType;
354+
// The instruction sequence that converts the virtual register, to be used
355+
// instead of the original
356+
std::optional<Instruction *> Converted;
357+
// The builder used to build the conversion instruction
358+
IRBuilder<> ConvertBuilder;
359+
360+
public:
361+
// The instruction which defined the original virtual register used across
362+
// blocks
363+
Instruction *getLiveRegDef() { return LiveRegDef; }
364+
// The original type
365+
Type *getOriginalType() { return OriginalType; }
366+
// The desired type
367+
Type *getNewType() { return NewType; }
368+
void setNewType(Type *NewType) { this->NewType = NewType; }
369+
// The instruction that conerts the virtual register, to be used instead of
370+
// the original
371+
std::optional<Instruction *> &getConverted() { return Converted; }
372+
void setConverted(Instruction *Converted) { this->Converted = Converted; }
373+
// The builder used to build the conversion instruction
374+
IRBuilder<> &getConverBuilder() { return ConvertBuilder; }
375+
// Do we have a instruction sequence which convert the original virtual
376+
// register
377+
bool hasConverted() { return Converted.has_value(); }
378+
379+
LiveRegConversion(Instruction *LiveRegDef, BasicBlock *InsertBlock,
380+
BasicBlock::iterator InsertPt)
381+
: LiveRegDef(LiveRegDef), OriginalType(LiveRegDef->getType()),
382+
ConvertBuilder(InsertBlock, InsertPt) {}
383+
LiveRegConversion(Instruction *LiveRegDef, Type *NewType,
384+
BasicBlock *InsertBlock, BasicBlock::iterator InsertPt)
385+
: LiveRegDef(LiveRegDef), OriginalType(LiveRegDef->getType()),
386+
NewType(NewType), ConvertBuilder(InsertBlock, InsertPt) {}
387+
};
388+
389+
class LiveRegOptimizer {
390+
private:
391+
Module *Mod = nullptr;
392+
// The scalar type to convert to
393+
Type *ConvertToScalar;
394+
// Holds the collection of PHIs with their pending new operands
395+
SmallVector<std::pair<Instruction *,
396+
SmallVector<std::pair<Instruction *, BasicBlock *>, 4>>,
397+
4>
398+
PHIUpdater;
399+
400+
public:
401+
// Should the def of the instruction be converted if it is live across blocks
402+
bool shouldReplaceUses(const Instruction &I);
403+
// Convert the virtual register to the compatible vector of legal type
404+
void convertToOptType(LiveRegConversion &LR);
405+
// Convert the virtual register back to the original type, stripping away
406+
// the MSBs in cases where there was an imperfect fit (e.g. v2i32 -> v7i8)
407+
void convertFromOptType(LiveRegConversion &LR);
408+
// Get a vector of desired scalar type that is compatible with the original
409+
// vector. In cases where there is no bitsize equivalent using a legal vector
410+
// type, we pad the MSBs (e.g. v7i8 -> v2i32)
411+
Type *getCompatibleType(Instruction *InstToConvert);
412+
// Find and replace uses of the virtual register in different block with a
413+
// newly produced virtual register of legal type
414+
bool replaceUses(Instruction &I);
415+
// Replace the collected PHIs with newly produced incoming values. Replacement
416+
// is only done if we have a replacement for each original incoming value.
417+
bool replacePHIs();
418+
419+
LiveRegOptimizer(Module *Mod) : Mod(Mod) {
420+
ConvertToScalar = Type::getInt32Ty(Mod->getContext());
421+
}
422+
};
423+
344424
} // end anonymous namespace
345425

346426
bool AMDGPUCodeGenPrepareImpl::run(Function &F) {
@@ -358,6 +438,7 @@ bool AMDGPUCodeGenPrepareImpl::run(Function &F) {
358438
Next = std::next(I);
359439

360440
MadeChange |= visit(*I);
441+
I->getType();
361442

362443
if (Next != E) { // Control flow changed
363444
BasicBlock *NextInstBB = Next->getParent();
@@ -369,9 +450,269 @@ bool AMDGPUCodeGenPrepareImpl::run(Function &F) {
369450
}
370451
}
371452
}
453+
454+
// GlobalISel should directly use the values, and do not need to emit
455+
// CopyTo/CopyFrom Regs across blocks
456+
if (UsesGlobalISel)
457+
return MadeChange;
458+
459+
// "Optimize" the virtual regs that cross basic block boundaries. In such
460+
// cases, vectors of illegal types will be scalarized and widened, with each
461+
// scalar living in its own physical register. The optimization converts the
462+
  // vectors to equivalent vectors of legal type (which are converted back
  // before uses in subsequent blocks), to pack the bits into fewer physical
  // registers (used in CopyToReg/CopyFromReg pairs).
465+
LiveRegOptimizer LRO(Mod);
466+
for (auto &BB : F) {
467+
for (auto &I : BB) {
468+
if (!LRO.shouldReplaceUses(I))
469+
continue;
470+
MadeChange |= LRO.replaceUses(I);
471+
}
472+
}
473+
474+
MadeChange |= LRO.replacePHIs();
475+
return MadeChange;
476+
}
477+
478+
bool LiveRegOptimizer::replaceUses(Instruction &I) {
  bool MadeChange = false;

  // Per destination block: the back-conversion produced there plus the users
  // that should be rewired to it.
  struct ConvertUseInfo {
    Instruction *Converted;
    SmallVector<Instruction *, 4> Users;
  };
  DenseMap<BasicBlock *, ConvertUseInfo> UseConvertTracker;

  // The conversion of I to the compatible legal type is inserted directly
  // after I; it is only materialized on the first cross-block use.
  LiveRegConversion FromLRC(
      &I, I.getParent(),
      static_cast<BasicBlock::iterator>(std::next(I.getIterator())));
  FromLRC.setNewType(getCompatibleType(FromLRC.getLiveRegDef()));

  for (User *CurUser : I.users()) {
    auto *UserInst = dyn_cast<Instruction>(CurUser);
    // Only uses in a different block need a conversion.
    if (!UserInst || UserInst->getParent() == I.getParent())
      continue;

    LLVM_DEBUG(dbgs() << *UserInst << "\n\tUses "
                      << *FromLRC.getOriginalType()
                      << " from previous block. Needs conversion\n");
    convertToOptType(FromLRC);
    if (!FromLRC.hasConverted())
      continue;

    // If it is a PHI node, just create and collect the new operand. We can
    // only replace the PHI node once we have converted all the operands.
    if (auto *PhiInst = dyn_cast<PHINode>(UserInst)) {
      for (unsigned Idx = 0; Idx < PhiInst->getNumIncomingValues(); Idx++) {
        Value *IncVal = PhiInst->getIncomingValue(Idx);
        if (&I != dyn_cast<Instruction>(IncVal))
          continue;
        BasicBlock *IncBlock = PhiInst->getIncomingBlock(Idx);
        // Append to the existing entry for this PHI if there is one.
        auto PHIOps = find_if(
            PHIUpdater,
            [&UserInst](
                std::pair<Instruction *,
                          SmallVector<std::pair<Instruction *, BasicBlock *>,
                                      4>> &Entry) {
              return Entry.first == UserInst;
            });

        if (PHIOps == PHIUpdater.end())
          PHIUpdater.push_back(
              {UserInst, {{*FromLRC.getConverted(), IncBlock}}});
        else
          PHIOps->second.push_back({*FromLRC.getConverted(), IncBlock});

        break;
      }
      continue;
    }

    // Do not create multiple conversion sequences if there are multiple
    // uses in the same block.
    BasicBlock *UserBB = UserInst->getParent();
    if (UseConvertTracker.contains(UserBB)) {
      UseConvertTracker[UserBB].Users.push_back(UserInst);
      LLVM_DEBUG(dbgs() << "\tUser already has access to converted def\n");
      continue;
    }

    // Build the back-conversion at the top of the user's block (after any
    // PHIs) and remember it for every later use in that block.
    LiveRegConversion ToLRC(*FromLRC.getConverted(), I.getType(), UserBB,
                            static_cast<BasicBlock::iterator>(
                                UserBB->getFirstNonPHIIt()));
    convertFromOptType(ToLRC);
    assert(ToLRC.hasConverted());
    UseConvertTracker[UserBB] = {*ToLRC.getConverted(), {UserInst}};
  }

  // Rewire the uses in a separate loop so the replacement does not interfere
  // with the use-list iteration above.
  for (auto &Entry : UseConvertTracker) {
    for (auto &UserInst : Entry.second.Users) {
      LLVM_DEBUG(dbgs() << *UserInst
                        << "\n\tNow uses: " << *Entry.second.Converted << "\n");
      UserInst->replaceUsesOfWith(&I, Entry.second.Converted);
      MadeChange = true;
    }
  }
  return MadeChange;
}
560+
561+
bool LiveRegOptimizer::replacePHIs() {
562+
bool MadeChange = false;
563+
for (auto Ele : PHIUpdater) {
564+
auto ThePHINode = dyn_cast<PHINode>(Ele.first);
565+
assert(ThePHINode);
566+
auto NewPHINodeOps = Ele.second;
567+
LLVM_DEBUG(dbgs() << "Attempting to replace: " << *ThePHINode << "\n");
568+
// If we have conveted all the required operands, then do the replacement
569+
if (ThePHINode->getNumIncomingValues() == NewPHINodeOps.size()) {
570+
IRBuilder<> Builder(Ele.first);
571+
auto NPHI = Builder.CreatePHI(NewPHINodeOps[0].first->getType(),
572+
NewPHINodeOps.size());
573+
for (auto IncVals : NewPHINodeOps) {
574+
NPHI->addIncoming(IncVals.first, IncVals.second);
575+
LLVM_DEBUG(dbgs() << " Using: " << *IncVals.first
576+
<< " For: " << IncVals.second->getName() << "\n");
577+
}
578+
LLVM_DEBUG(dbgs() << "Sucessfully replaced with " << *NPHI << "\n");
579+
LiveRegConversion ToLRC(NPHI, ThePHINode->getType(),
580+
ThePHINode->getParent(),
581+
static_cast<BasicBlock::iterator>(
582+
ThePHINode->getParent()->getFirstNonPHIIt()));
583+
convertFromOptType(ToLRC);
584+
assert(ToLRC.hasConverted());
585+
Ele.first->replaceAllUsesWith(*ToLRC.getConverted());
586+
// The old PHI is no longer used
587+
ThePHINode->eraseFromParent();
588+
MadeChange = true;
589+
}
590+
}
372591
return MadeChange;
373592
}
374593

594+
Type *LiveRegOptimizer::getCompatibleType(Instruction *InstToConvert) {
  Type *OriginalType = InstToConvert->getType();
  assert(OriginalType->getScalarSizeInBits() <=
         ConvertToScalar->getScalarSizeInBits());
  auto *VTy = dyn_cast<VectorType>(OriginalType);
  // Non-vector defs are simply mapped to the target scalar type.
  if (!VTy)
    return ConvertToScalar;

  // Round the total bit width up to a whole number of target-scalar elements
  // (e.g. v7i8: 56 bits -> 2 x i32).
  unsigned OriginalSize =
      VTy->getScalarSizeInBits() * VTy->getElementCount().getFixedValue();
  unsigned ConvertScalarSize = ConvertToScalar->getScalarSizeInBits();
  unsigned ConvertEltCount =
      (OriginalSize + ConvertScalarSize - 1) / ConvertScalarSize;

  return VectorType::get(Type::getIntNTy(Mod->getContext(), ConvertScalarSize),
                         llvm::ElementCount::getFixed(ConvertEltCount));
}
611+
612+
void LiveRegOptimizer::convertToOptType(LiveRegConversion &LR) {
  // The conversion is built at most once; later callers reuse it.
  if (LR.hasConverted()) {
    LLVM_DEBUG(dbgs() << "\tAlready has converted def\n");
    return;
  }

  auto *VTy = dyn_cast<VectorType>(LR.getOriginalType());
  assert(VTy);
  auto *NewVTy = dyn_cast<VectorType>(LR.getNewType());
  assert(NewVTy);

  Value *V = LR.getLiveRegDef();
  unsigned OriginalSize =
      VTy->getScalarSizeInBits() * VTy->getElementCount().getFixedValue();
  unsigned NewSize =
      NewVTy->getScalarSizeInBits() * NewVTy->getElementCount().getFixedValue();

  IRBuilder<> &Builder = LR.getConverBuilder();

  // A bitsize match means the old vector fits into the new vector type with a
  // plain bitcast.
  // NOTE(review): if the cast constant-folds to a non-Instruction, dyn_cast
  // yields null and Converted holds nullptr — confirm inputs are always
  // instructions.
  if (OriginalSize == NewSize) {
    LR.setConverted(dyn_cast<Instruction>(Builder.CreateBitCast(V, NewVTy)));
    LLVM_DEBUG(dbgs() << "\tConverted def to "
                      << *(*LR.getConverted())->getType() << "\n");
    return;
  }

  // On a bitsize mismatch we must widen first: pad the original vector out
  // with extra lanes until its bit width matches, then bitcast.
  assert(NewSize > OriginalSize);
  auto ExpandedVecElementCount =
      llvm::ElementCount::getFixed(NewSize / VTy->getScalarSizeInBits());

  unsigned OrigEltCount = VTy->getElementCount().getFixedValue();
  SmallVector<int, 8> ShuffleMask;
  // Keep the original lanes in place...
  for (unsigned Idx = 0; Idx < OrigEltCount; Idx++)
    ShuffleMask.push_back(Idx);
  // ...and pad the tail with the first lane of the implicit second operand.
  for (uint64_t Idx = OrigEltCount;
       Idx < ExpandedVecElementCount.getFixedValue(); Idx++)
    ShuffleMask.push_back(OrigEltCount);

  auto *ExpandedVec =
      dyn_cast<Instruction>(Builder.CreateShuffleVector(V, ShuffleMask));
  LR.setConverted(
      dyn_cast<Instruction>(Builder.CreateBitCast(ExpandedVec, NewVTy)));
  LLVM_DEBUG(dbgs() << "\tConverted def to " << *(*LR.getConverted())->getType()
                    << "\n");
}
661+
662+
void LiveRegOptimizer::convertFromOptType(LiveRegConversion &LRC) {
  // Here OriginalType is the (possibly wider) legal type and NewType is the
  // type we are converting back to.
  auto *VTy = dyn_cast<VectorType>(LRC.getOriginalType());
  assert(VTy);
  auto *NewVTy = dyn_cast<VectorType>(LRC.getNewType());
  assert(NewVTy);

  Value *V = LRC.getLiveRegDef();
  unsigned OriginalSize =
      VTy->getScalarSizeInBits() * VTy->getElementCount().getFixedValue();
  unsigned NewSize =
      NewVTy->getScalarSizeInBits() * NewVTy->getElementCount().getFixedValue();

  IRBuilder<> &Builder = LRC.getConverBuilder();

  // A bitsize match needs only a bitcast back to the original type.
  if (OriginalSize == NewSize) {
    LRC.setConverted(dyn_cast<Instruction>(Builder.CreateBitCast(V, NewVTy)));
    LLVM_DEBUG(dbgs() << "\tProduced for user: " << **LRC.getConverted()
                      << "\n");
    return;
  }

  // A wider vector was used; strip the padded MSB lanes to recover the
  // original type.
  assert(OriginalSize > NewSize);
  auto ExpandedVecElementCount = llvm::ElementCount::getFixed(
      OriginalSize / NewVTy->getScalarSizeInBits());
  auto *ExpandedVT = VectorType::get(
      Type::getIntNTy(Mod->getContext(), NewVTy->getScalarSizeInBits()),
      ExpandedVecElementCount);
  // First reinterpret the wide value with the narrow element type...
  auto *Converted = dyn_cast<Instruction>(Builder.CreateBitCast(V, ExpandedVT));

  // ...then keep only the leading lanes that make up the original vector.
  unsigned NarrowElementCount = NewVTy->getElementCount().getFixedValue();
  SmallVector<int, 8> ShuffleMask;
  for (unsigned Idx = 0; Idx < NarrowElementCount; Idx++)
    ShuffleMask.push_back(Idx);

  auto *NarrowVec = dyn_cast<Instruction>(
      Builder.CreateShuffleVector(Converted, ShuffleMask));
  LRC.setConverted(NarrowVec);
  LLVM_DEBUG(dbgs() << "\tProduced for user: " << **LRC.getConverted() << "\n");
}
706+
707+
bool LiveRegOptimizer::shouldReplaceUses(const Instruction &I) {
708+
// Vectors of illegal types are copied across blocks in an efficient manner.
709+
// They are scalarized and widened to legal scalars. In such cases, we can do
710+
// better by using legal vector types
711+
auto IType = I.getType();
712+
return IType->isVectorTy() && IType->getScalarSizeInBits() < 16 &&
713+
!I.getType()->getScalarType()->isPointerTy();
714+
}
715+
375716
unsigned AMDGPUCodeGenPrepareImpl::getBaseElementBitWidth(const Type *T) const {
376717
assert(needsPromotionToI32(T) && "T does not need promotion to i32");
377718

@@ -2200,6 +2541,7 @@ bool AMDGPUCodeGenPrepare::runOnFunction(Function &F) {
22002541
Impl.ST = &TM.getSubtarget<GCNSubtarget>(F);
22012542
Impl.AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
22022543
Impl.UA = &getAnalysis<UniformityInfoWrapperPass>().getUniformityInfo();
2544+
Impl.UsesGlobalISel = TM.Options.EnableGlobalISel;
22032545
auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
22042546
Impl.DT = DTWP ? &DTWP->getDomTree() : nullptr;
22052547
Impl.HasUnsafeFPMath = hasUnsafeFPMath(F);
@@ -2221,6 +2563,7 @@ PreservedAnalyses AMDGPUCodeGenPreparePass::run(Function &F,
22212563
Impl.DT = FAM.getCachedResult<DominatorTreeAnalysis>(F);
22222564
Impl.HasUnsafeFPMath = hasUnsafeFPMath(F);
22232565
SIModeRegisterDefaults Mode(F, *Impl.ST);
2566+
Impl.UsesGlobalISel = TM.Options.EnableGlobalISel;
22242567
Impl.HasFP32DenormalFlush =
22252568
Mode.FP32Denormals == DenormalMode::getPreserveSign();
22262569
PreservedAnalyses PA = PreservedAnalyses::none();

0 commit comments

Comments
 (0)