Skip to content

Commit ca4bc37

Browse files
jrbyrnes, bcahoon
authored and committed
[AMDGPU] Optionally Use GCNRPTrackers during scheduling (llvm#93090)
This adds the ability to use the GCNRPTrackers during scheduling. These trackers have several advantages over the generic trackers: 1. global live-thru trackers, 2. subregister-based RP deltas, and 3. flexible vreg -> PressureSet mappings. This feature is off by default to ease the roll-out process. In particular, when the optional trackers are in use, the scheduler still maintains the generic trackers, which costs unnecessary compile time. (cherry picked from commit 17bc959) Change-Id: Ia45099682b69df7113b7736f052e2bde6f926ca4
1 parent 90a99ab commit ca4bc37

12 files changed

+1673
-79
lines changed

llvm/lib/Target/AMDGPU/GCNIterativeScheduler.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ void GCNIterativeScheduler::scheduleLegacyMaxOccupancy(
480480
LLVM_DEBUG(dbgs() << "Scheduling using default scheduler, "
481481
"target occupancy = "
482482
<< TgtOcc << '\n');
483-
GCNMaxOccupancySchedStrategy LStrgy(Context);
483+
GCNMaxOccupancySchedStrategy LStrgy(Context, /*IsLegacyScheduler=*/true);
484484
unsigned FinalOccupancy = std::min(Occ, MFI->getOccupancy());
485485

486486
for (int I = 0; I < NumPasses; ++I) {

llvm/lib/Target/AMDGPU/GCNRegPressure.cpp

Lines changed: 176 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,63 @@ collectVirtualRegUses(SmallVectorImpl<RegisterMaskPair> &RegMaskPairs,
295295
}
296296
}
297297

298+
/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
///
/// Returns the lanes of \p RegUnit for which \p Property holds at \p Pos.
///
/// For a virtual register, \p Property is evaluated on the register's live
/// interval from \p LIS: per subrange when \p TrackLaneMasks is set and the
/// interval has subranges, otherwise on the main range as a whole.
/// For any other register, \p Property is evaluated on the cached register-unit
/// live range; \p SafeDefault is returned when no such range is cached.
299+
static LaneBitmask getLanesWithProperty(
300+
const LiveIntervals &LIS, const MachineRegisterInfo &MRI,
301+
bool TrackLaneMasks, Register RegUnit, SlotIndex Pos,
302+
LaneBitmask SafeDefault,
303+
function_ref<bool(const LiveRange &LR, SlotIndex Pos)> Property) {
304+
if (RegUnit.isVirtual()) {
305+
const LiveInterval &LI = LIS.getInterval(RegUnit);
306+
// Default-constructed; presumably the empty mask -- TODO confirm.
LaneBitmask Result;
307+
if (TrackLaneMasks && LI.hasSubRanges()) {
308+
// Accumulate the lane mask of every subrange that satisfies Property.
for (const LiveInterval::SubRange &SR : LI.subranges()) {
309+
if (Property(SR, Pos))
310+
Result |= SR.LaneMask;
311+
}
312+
} else if (Property(LI, Pos)) {
313+
// No per-lane information is consulted here: report the register's
// maximal lane mask when tracking lanes, or every lane otherwise.
Result = TrackLaneMasks ? MRI.getMaxLaneMaskForVReg(RegUnit)
314+
: LaneBitmask::getAll();
315+
}
316+
317+
return Result;
318+
}
319+
320+
// Non-virtual path: only a cached register-unit range is consulted.
const LiveRange *LR = LIS.getCachedRegUnit(RegUnit);
321+
if (LR == nullptr)
322+
return SafeDefault;
323+
return Property(*LR, Pos) ? LaneBitmask::getAll() : LaneBitmask::getNone();
324+
}
325+
326+
/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
327+
/// Helper to find a vreg use between two indices {PriorUseIdx, NextUseIdx}.
328+
/// The query starts with a lane bitmask which gets lanes/bits removed for every
329+
/// use we find.
///
/// Only non-debug, non-undef use operands of \p Reg are considered. The
/// interval tested is (PriorUseIdx, NextUseIdx] when \p Upward is true and
/// [PriorUseIdx, NextUseIdx) otherwise, compared against the register slot of
/// each using instruction.
///
/// \returns the lanes of \p LastUseMask that no use in the interval touches;
/// the empty mask as soon as every lane has been removed.
330+
static LaneBitmask findUseBetween(unsigned Reg, LaneBitmask LastUseMask,
331+
SlotIndex PriorUseIdx, SlotIndex NextUseIdx,
332+
const MachineRegisterInfo &MRI,
333+
const SIRegisterInfo *TRI,
334+
const LiveIntervals *LIS,
335+
bool Upward = false) {
336+
for (const MachineOperand &MO : MRI.use_nodbg_operands(Reg)) {
337+
if (MO.isUndef())
338+
continue;
339+
const MachineInstr *MI = MO.getParent();
340+
SlotIndex InstSlot = LIS->getInstructionIndex(*MI).getRegSlot();
341+
// Which end of the interval is inclusive depends on the direction.
bool InRange = Upward ? (InstSlot > PriorUseIdx && InstSlot <= NextUseIdx)
342+
: (InstSlot >= PriorUseIdx && InstSlot < NextUseIdx);
343+
if (!InRange)
344+
continue;
345+
346+
// Drop the lanes covered by this operand's subregister index.
unsigned SubRegIdx = MO.getSubReg();
347+
LaneBitmask UseMask = TRI->getSubRegIndexLaneMask(SubRegIdx);
348+
LastUseMask &= ~UseMask;
349+
// Nothing left to clear; stop scanning the remaining uses.
if (LastUseMask.none())
350+
return LaneBitmask::getNone();
351+
}
352+
return LastUseMask;
353+
}
354+
298355
///////////////////////////////////////////////////////////////////////////////
299356
// GCNRPTracker
300357

@@ -353,17 +410,28 @@ void GCNRPTracker::reset(const MachineInstr &MI,
353410
MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
354411
}
355412

356-
////////////////////////////////////////////////////////////////////////////////
357-
// GCNUpwardRPTracker
358-
359-
void GCNUpwardRPTracker::reset(const MachineRegisterInfo &MRI_,
360-
const LiveRegSet &LiveRegs_) {
413+
/// Re-initializes the tracker from an explicit live register set.
///
/// Stores a pointer to \p MRI_, copies \p LiveRegs_ into the tracker, clears
/// the last tracked instruction, and sets both the current and the maximum
/// pressure to the pressure computed for \p LiveRegs_.
void GCNRPTracker::reset(const MachineRegisterInfo &MRI_,
414+
const LiveRegSet &LiveRegs_) {
361415
MRI = &MRI_;
362416
LiveRegs = LiveRegs_;
363417
LastTrackedMI = nullptr;
364418
MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_);
365419
}
366420

421+
/// Mostly copy/paste from CodeGen/RegisterPressure.cpp
///
/// Returns the lanes of \p RegUnit whose live segment containing the base
/// index of \p Pos ends exactly at the corresponding register slot. The query
/// always tracks lane masks and uses the empty mask as the safe default.
422+
LaneBitmask GCNRPTracker::getLastUsedLanes(Register RegUnit,
423+
SlotIndex Pos) const {
424+
return getLanesWithProperty(
425+
LIS, *MRI, true, RegUnit, Pos.getBaseIndex(), LaneBitmask::getNone(),
426+
// Property: a segment covers Pos and terminates at Pos's register slot.
[](const LiveRange &LR, SlotIndex Pos) {
427+
const LiveRange::Segment *S = LR.getSegmentContaining(Pos);
428+
return S != nullptr && S->end == Pos.getRegSlot();
429+
});
430+
}
431+
432+
////////////////////////////////////////////////////////////////////////////////
433+
// GCNUpwardRPTracker
434+
367435
void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
368436
assert(MRI && "call reset first");
369437

@@ -440,25 +508,37 @@ bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
440508
return true;
441509
}
442510

443-
bool GCNDownwardRPTracker::advanceBeforeNext() {
511+
bool GCNDownwardRPTracker::advanceBeforeNext(MachineInstr *MI,
512+
bool UseInternalIterator) {
444513
assert(MRI && "call reset first");
445-
if (!LastTrackedMI)
446-
return NextMI == MBBEnd;
447-
448-
assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
514+
SlotIndex SI;
515+
const MachineInstr *CurrMI;
516+
if (UseInternalIterator) {
517+
if (!LastTrackedMI)
518+
return NextMI == MBBEnd;
519+
520+
assert(NextMI == MBBEnd || !NextMI->isDebugInstr());
521+
CurrMI = LastTrackedMI;
522+
523+
SI = NextMI == MBBEnd
524+
? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
525+
: LIS.getInstructionIndex(*NextMI).getBaseIndex();
526+
} else { //! UseInternalIterator
527+
SI = LIS.getInstructionIndex(*MI).getBaseIndex();
528+
CurrMI = MI;
529+
}
449530

450-
SlotIndex SI = NextMI == MBBEnd
451-
? LIS.getInstructionIndex(*LastTrackedMI).getDeadSlot()
452-
: LIS.getInstructionIndex(*NextMI).getBaseIndex();
453531
assert(SI.isValid());
454532

455533
// Remove dead registers or mask bits.
456534
SmallSet<Register, 8> SeenRegs;
457-
for (auto &MO : LastTrackedMI->operands()) {
535+
for (auto &MO : CurrMI->operands()) {
458536
if (!MO.isReg() || !MO.getReg().isVirtual())
459537
continue;
460538
if (MO.isUse() && !MO.readsReg())
461539
continue;
540+
if (!UseInternalIterator && MO.isDef())
541+
continue;
462542
if (!SeenRegs.insert(MO.getReg()).second)
463543
continue;
464544
const LiveInterval &LI = LIS.getInterval(MO.getReg());
@@ -491,15 +571,22 @@ bool GCNDownwardRPTracker::advanceBeforeNext() {
491571

492572
LastTrackedMI = nullptr;
493573

494-
return NextMI == MBBEnd;
574+
return UseInternalIterator && (NextMI == MBBEnd);
495575
}
496576

497-
void GCNDownwardRPTracker::advanceToNext() {
498-
LastTrackedMI = &*NextMI++;
499-
NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
577+
void GCNDownwardRPTracker::advanceToNext(MachineInstr *MI,
578+
bool UseInternalIterator) {
579+
if (UseInternalIterator) {
580+
LastTrackedMI = &*NextMI++;
581+
NextMI = skipDebugInstructionsForward(NextMI, MBBEnd);
582+
} else {
583+
LastTrackedMI = MI;
584+
}
585+
586+
const MachineInstr *CurrMI = LastTrackedMI;
500587

501588
// Add new registers or mask bits.
502-
for (const auto &MO : LastTrackedMI->all_defs()) {
589+
for (const auto &MO : CurrMI->all_defs()) {
503590
Register Reg = MO.getReg();
504591
if (!Reg.isVirtual())
505592
continue;
@@ -512,11 +599,16 @@ void GCNDownwardRPTracker::advanceToNext() {
512599
MaxPressure = max(MaxPressure, CurPressure);
513600
}
514601

515-
bool GCNDownwardRPTracker::advance() {
516-
if (NextMI == MBBEnd)
602+
bool GCNDownwardRPTracker::advance(MachineInstr *MI, bool UseInternalIterator) {
603+
if (UseInternalIterator && NextMI == MBBEnd)
517604
return false;
518-
advanceBeforeNext();
519-
advanceToNext();
605+
606+
advanceBeforeNext(MI, UseInternalIterator);
607+
advanceToNext(MI, UseInternalIterator);
608+
if (!UseInternalIterator) {
609+
// We must remove any dead def lanes from the current RP
610+
advanceBeforeNext(MI, true);
611+
}
520612
return true;
521613
}
522614

@@ -558,6 +650,67 @@ Printable llvm::reportMismatch(const GCNRPTracker::LiveRegSet &LISLR,
558650
});
559651
}
560652

653+
/// Computes the register pressure obtained by applying the uses and defs of
/// \p MI to a copy of the tracker's current pressure. The tracker itself is
/// not modified (the method is const and works on a local copy).
///
/// \param MI  instruction to evaluate; must not be a debug or pseudo
///            instruction (asserted).
/// \param TRI register info used to collect operands and to resolve
///            subregister lane masks.
/// \returns the adjusted copy of the current pressure.
GCNRegPressure
654+
GCNDownwardRPTracker::bumpDownwardPressure(const MachineInstr *MI,
655+
const SIRegisterInfo *TRI) const {
656+
assert(!MI->isDebugOrPseudoInstr() && "Expect a nondebug instruction.");
657+
658+
SlotIndex SlotIdx;
659+
SlotIdx = LIS.getInstructionIndex(*MI).getRegSlot();
660+
661+
// Account for register pressure similar to RegPressureTracker::recede().
662+
RegisterOperands RegOpers;
663+
// NOTE(review): the bare `true` is presumably the lane-mask tracking flag of
// RegisterOperands::collect -- confirm against its declaration.
RegOpers.collect(*MI, *TRI, *MRI, true, /*IgnoreDead=*/false);
664+
RegOpers.adjustLaneLiveness(LIS, *MRI, SlotIdx);
665+
// Work on a copy so CurPressure is left untouched.
GCNRegPressure TempPressure = CurPressure;
666+
667+
for (const RegisterMaskPair &Use : RegOpers.Uses) {
668+
Register Reg = Use.RegUnit;
669+
// Only virtual registers are considered.
if (!Reg.isVirtual())
670+
continue;
671+
LaneBitmask LastUseMask = getLastUsedLanes(Reg, SlotIdx);
672+
if (LastUseMask.none())
673+
continue;
674+
// The LastUseMask is queried from the liveness information of instruction
675+
// which may be further down the schedule. Some lanes may actually not be
676+
// last uses for the current position.
677+
// FIXME: allow the caller to pass in the list of vreg uses that remain
678+
// to be bottom-scheduled to avoid searching uses at each query.
679+
SlotIndex CurrIdx;
680+
const MachineBasicBlock *MBB = MI->getParent();
681+
// Start from the last tracked instruction if there is one, otherwise from
// the beginning of MI's block, skipping debug instructions.
MachineBasicBlock::const_iterator IdxPos = skipDebugInstructionsForward(
682+
LastTrackedMI ? LastTrackedMI : MBB->begin(), MBB->end());
683+
if (IdxPos == MBB->end()) {
684+
CurrIdx = LIS.getMBBEndIdx(MBB);
685+
} else {
686+
CurrIdx = LIS.getInstructionIndex(*IdxPos).getRegSlot();
687+
}
688+
689+
// Remove lanes that have another use in [CurrIdx, SlotIdx).
LastUseMask =
690+
findUseBetween(Reg, LastUseMask, CurrIdx, SlotIdx, *MRI, TRI, &LIS);
691+
if (LastUseMask.none())
692+
continue;
693+
694+
// Registers absent from LiveRegs are treated as having no live lanes.
LaneBitmask LiveMask =
695+
LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0);
696+
// Clear the last-used lanes from the live mask.
LaneBitmask NewMask = LiveMask & ~LastUseMask;
697+
// NOTE(review): inc() presumably applies the LiveMask -> NewMask transition
// to the pressure -- confirm in GCNRegPressure.
TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
698+
}
699+
700+
// Generate liveness for defs.
701+
for (const RegisterMaskPair &Def : RegOpers.Defs) {
702+
Register Reg = Def.RegUnit;
703+
if (!Reg.isVirtual())
704+
continue;
705+
LaneBitmask LiveMask =
706+
LiveRegs.contains(Reg) ? LiveRegs.at(Reg) : LaneBitmask(0);
707+
// Add the defined lanes to the live mask.
LaneBitmask NewMask = LiveMask | Def.LaneMask;
708+
TempPressure.inc(Reg, LiveMask, NewMask, *MRI);
709+
}
710+
711+
return TempPressure;
712+
}
713+
561714
bool GCNUpwardRPTracker::isValid() const {
562715
const auto &SI = LIS.getInstructionIndex(*LastTrackedMI).getBaseIndex();
563716
const auto LISLR = llvm::getLiveRegs(SI, LIS, *MRI);

0 commit comments

Comments
 (0)