Skip to content

Commit 87b8d94

Browse files
authored
[AMDGPU] Fix GCNUpwardRPTracker. (#71186)
Fixed: 1. Maximum register pressure calculation at the instruction level. Previously, max RP included both the defs and the uses of an instruction's registers. Now maximum RP includes _uses_ and _early-clobber defs_. 2. Uses were incorrectly tracked, which resulted in a mismatch between the live-in set reported by LiveIntervals and the tracked live reg set when the beginning of the block is reached. The interface has changed: moveMaxPressure becomes deprecated, and the getMaxPressure and resetMaxPressure functions are added. The reset functions now seem more consistent.
1 parent bd61126 commit 87b8d94

File tree

4 files changed

+193
-138
lines changed

4 files changed

+193
-138
lines changed

llvm/lib/Target/AMDGPU/GCNIterativeScheduler.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ GCNIterativeScheduler::getRegionPressure(MachineBasicBlock::iterator Begin,
251251
assert(UPTracker.isValid() ||
252252
(dbgs() << "Tracked region ",
253253
printRegion(dbgs(), Begin, End, LIS), false));
254-
return UPTracker.moveMaxPressure();
254+
return UPTracker.getMaxPressureAndReset();
255255
}
256256

257257
// returns max pressure for a tentative schedule
@@ -272,7 +272,7 @@ GCNIterativeScheduler::getSchedulePressure(const Region &R,
272272
for (auto I = Schedule.end(), B = Schedule.begin(); I != B;) {
273273
RPTracker.recede(*getMachineInstr(*--I));
274274
}
275-
return RPTracker.moveMaxPressure();
275+
return RPTracker.getMaxPressureAndReset();
276276
}
277277

278278
void GCNIterativeScheduler::enterRegion(MachineBasicBlock *BB, // overridden

llvm/lib/Target/AMDGPU/GCNRegPressure.cpp

Lines changed: 72 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -166,66 +166,60 @@ static LaneBitmask getDefRegMask(const MachineOperand &MO,
166166
MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(MO.getSubReg());
167167
}
168168

169-
static LaneBitmask getUsedRegMask(const MachineOperand &MO,
170-
const MachineRegisterInfo &MRI,
171-
const LiveIntervals &LIS) {
172-
assert(MO.isUse() && MO.isReg() && MO.getReg().isVirtual());
173-
174-
if (auto SubReg = MO.getSubReg())
175-
return MRI.getTargetRegisterInfo()->getSubRegIndexLaneMask(SubReg);
176-
177-
auto MaxMask = MRI.getMaxLaneMaskForVReg(MO.getReg());
178-
if (SIRegisterInfo::getNumCoveredRegs(MaxMask) > 1) // cannot have subregs
179-
return MaxMask;
180-
181-
// For a tentative schedule LIS isn't updated yet but livemask should remain
182-
// the same on any schedule. Subreg defs can be reordered but they all must
183-
// dominate uses anyway.
184-
auto SI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
185-
return getLiveLaneMask(MO.getReg(), SI, LIS, MRI);
186-
}
187-
188-
static SmallVector<RegisterMaskPair, 8>
189-
collectVirtualRegUses(const MachineInstr &MI, const LiveIntervals &LIS,
169+
static void
170+
collectVirtualRegUses(SmallVectorImpl<RegisterMaskPair> &RegMaskPairs,
171+
const MachineInstr &MI, const LiveIntervals &LIS,
190172
const MachineRegisterInfo &MRI) {
191-
SmallVector<RegisterMaskPair, 8> Res;
173+
SlotIndex InstrSI;
192174
for (const auto &MO : MI.operands()) {
193175
if (!MO.isReg() || !MO.getReg().isVirtual())
194176
continue;
195177
if (!MO.isUse() || !MO.readsReg())
196178
continue;
197179

198-
auto const UsedMask = getUsedRegMask(MO, MRI, LIS);
180+
Register Reg = MO.getReg();
181+
if (llvm::any_of(RegMaskPairs, [Reg](const RegisterMaskPair &RM) {
182+
return RM.RegUnit == Reg;
183+
}))
184+
continue;
185+
186+
LaneBitmask UseMask;
187+
auto &LI = LIS.getInterval(Reg);
188+
if (!LI.hasSubRanges())
189+
UseMask = MRI.getMaxLaneMaskForVReg(Reg);
190+
else {
191+
// For a tentative schedule LIS isn't updated yet but livemask should
192+
// remain the same on any schedule. Subreg defs can be reordered but they
193+
// all must dominate uses anyway.
194+
if (!InstrSI)
195+
InstrSI = LIS.getInstructionIndex(*MO.getParent()).getBaseIndex();
196+
UseMask = getLiveLaneMask(LI, InstrSI, MRI);
197+
}
199198

200-
auto Reg = MO.getReg();
201-
auto I = llvm::find_if(
202-
Res, [Reg](const RegisterMaskPair &RM) { return RM.RegUnit == Reg; });
203-
if (I != Res.end())
204-
I->LaneMask |= UsedMask;
205-
else
206-
Res.push_back(RegisterMaskPair(Reg, UsedMask));
199+
RegMaskPairs.emplace_back(Reg, UseMask);
207200
}
208-
return Res;
209201
}
210202

211203
///////////////////////////////////////////////////////////////////////////////
212204
// GCNRPTracker
213205

214-
LaneBitmask llvm::getLiveLaneMask(unsigned Reg,
215-
SlotIndex SI,
206+
LaneBitmask llvm::getLiveLaneMask(unsigned Reg, SlotIndex SI,
216207
const LiveIntervals &LIS,
217208
const MachineRegisterInfo &MRI) {
209+
return getLiveLaneMask(LIS.getInterval(Reg), SI, MRI);
210+
}
211+
212+
LaneBitmask llvm::getLiveLaneMask(const LiveInterval &LI, SlotIndex SI,
213+
const MachineRegisterInfo &MRI) {
218214
LaneBitmask LiveMask;
219-
const auto &LI = LIS.getInterval(Reg);
220215
if (LI.hasSubRanges()) {
221216
for (const auto &S : LI.subranges())
222217
if (S.liveAt(SI)) {
223218
LiveMask |= S.LaneMask;
224-
assert(LiveMask < MRI.getMaxLaneMaskForVReg(Reg) ||
225-
LiveMask == MRI.getMaxLaneMaskForVReg(Reg));
219+
assert(LiveMask == (LiveMask & MRI.getMaxLaneMaskForVReg(LI.reg())));
226220
}
227221
} else if (LI.liveAt(SI)) {
228-
LiveMask = MRI.getMaxLaneMaskForVReg(Reg);
222+
LiveMask = MRI.getMaxLaneMaskForVReg(LI.reg());
229223
}
230224
return LiveMask;
231225
}
@@ -261,15 +255,14 @@ void GCNRPTracker::reset(const MachineInstr &MI,
261255
MaxPressure = CurPressure = getRegPressure(*MRI, LiveRegs);
262256
}
263257

264-
void GCNUpwardRPTracker::reset(const MachineInstr &MI,
265-
const LiveRegSet *LiveRegsCopy) {
266-
GCNRPTracker::reset(MI, LiveRegsCopy, true);
267-
}
258+
////////////////////////////////////////////////////////////////////////////////
259+
// GCNUpwardRPTracker
268260

269261
void GCNUpwardRPTracker::reset(const MachineRegisterInfo &MRI_,
270262
const LiveRegSet &LiveRegs_) {
271263
MRI = &MRI_;
272264
LiveRegs = LiveRegs_;
265+
LastTrackedMI = nullptr;
273266
MaxPressure = CurPressure = getRegPressure(MRI_, LiveRegs_);
274267
}
275268

@@ -281,41 +274,55 @@ void GCNUpwardRPTracker::recede(const MachineInstr &MI) {
281274
if (MI.isDebugInstr())
282275
return;
283276

284-
auto const RegUses = collectVirtualRegUses(MI, LIS, *MRI);
285-
286-
// calc pressure at the MI (defs + uses)
287-
auto AtMIPressure = CurPressure;
288-
for (const auto &U : RegUses) {
289-
auto LiveMask = LiveRegs[U.RegUnit];
290-
AtMIPressure.inc(U.RegUnit, LiveMask, LiveMask | U.LaneMask, *MRI);
291-
}
292-
// update max pressure
293-
MaxPressure = max(AtMIPressure, MaxPressure);
294-
295-
for (const auto &MO : MI.all_defs()) {
296-
if (!MO.getReg().isVirtual() || MO.isDead())
297-
continue;
298-
299-
auto Reg = MO.getReg();
277+
auto DecrementDef = [this](const MachineOperand &MO) {
278+
Register Reg = MO.getReg();
300279
auto I = LiveRegs.find(Reg);
301280
if (I == LiveRegs.end())
302-
continue;
303-
auto &LiveMask = I->second;
304-
auto PrevMask = LiveMask;
281+
return;
282+
283+
LaneBitmask &LiveMask = I->second;
284+
LaneBitmask PrevMask = LiveMask;
305285
LiveMask &= ~getDefRegMask(MO, *MRI);
306286
CurPressure.inc(Reg, PrevMask, LiveMask, *MRI);
307287
if (LiveMask.none())
308288
LiveRegs.erase(I);
289+
};
290+
291+
// Decrement non-early-clobber defs.
292+
SmallVector<const MachineOperand *, 2> EarlyClobberDefs;
293+
for (const MachineOperand &MO : MI.all_defs()) {
294+
if (!MO.getReg().isVirtual())
295+
continue;
296+
if (!MO.isEarlyClobber())
297+
DecrementDef(MO);
298+
else
299+
EarlyClobberDefs.push_back(&MO);
309300
}
310-
for (const auto &U : RegUses) {
311-
auto &LiveMask = LiveRegs[U.RegUnit];
312-
auto PrevMask = LiveMask;
301+
302+
// Increment uses.
303+
SmallVector<RegisterMaskPair, 8> RegUses;
304+
collectVirtualRegUses(RegUses, MI, LIS, *MRI);
305+
for (const RegisterMaskPair &U : RegUses) {
306+
LaneBitmask &LiveMask = LiveRegs[U.RegUnit];
307+
LaneBitmask PrevMask = LiveMask;
313308
LiveMask |= U.LaneMask;
314309
CurPressure.inc(U.RegUnit, PrevMask, LiveMask, *MRI);
315310
}
311+
312+
// Point of maximum pressure: non-early-clobber defs are decremented and uses
313+
// are incremented.
314+
MaxPressure = max(CurPressure, MaxPressure);
315+
316+
// Now decrement early clobber defs.
317+
for (const MachineOperand *MO : EarlyClobberDefs)
318+
DecrementDef(*MO);
319+
316320
assert(CurPressure == getRegPressure(*MRI, LiveRegs));
317321
}
318322

323+
////////////////////////////////////////////////////////////////////////////////
324+
// GCNDownwardRPTracker
325+
319326
bool GCNDownwardRPTracker::reset(const MachineInstr &MI,
320327
const LiveRegSet *LiveRegsCopy) {
321328
MRI = &MI.getParent()->getParent()->getRegInfo();
@@ -562,15 +569,15 @@ bool GCNRegPressurePrinter::runOnMachineFunction(MachineFunction &MF) {
562569
} else {
563570
GCNUpwardRPTracker RPT(LIS);
564571
RPT.reset(MRI, MBBEndSlot);
565-
RPT.moveMaxPressure(); // Clear max pressure.
566572

567573
LiveOut = RPT.getLiveRegs();
568574
RPAtMBBEnd = RPT.getPressure();
569575

570576
for (auto &MI : reverse(MBB)) {
577+
RPT.resetMaxPressure();
571578
RPT.recede(MI);
572579
if (!MI.isDebugInstr())
573-
RP.emplace_back(RPT.getPressure(), RPT.moveMaxPressure());
580+
RP.emplace_back(RPT.getPressure(), RPT.getMaxPressure());
574581
}
575582

576583
LiveIn = RPT.getLiveRegs();

llvm/lib/Target/AMDGPU/GCNRegPressure.h

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -130,13 +130,6 @@ class GCNRPTracker {
130130

131131
GCNRegPressure getPressure() const { return CurPressure; }
132132

133-
// returns MaxPressure, resetting it
134-
decltype(MaxPressure) moveMaxPressure() {
135-
auto Res = MaxPressure;
136-
MaxPressure.clear();
137-
return Res;
138-
}
139-
140133
decltype(LiveRegs) moveLiveRegs() {
141134
return std::move(LiveRegs);
142135
}
@@ -149,24 +142,41 @@ class GCNUpwardRPTracker : public GCNRPTracker {
149142
public:
150143
GCNUpwardRPTracker(const LiveIntervals &LIS_) : GCNRPTracker(LIS_) {}
151144

152-
// reset tracker to the point just below MI
153-
// filling live regs upon this point using LIS
154-
void reset(const MachineInstr &MI, const LiveRegSet *LiveRegs = nullptr);
155-
156145
// reset tracker and set live register set to the specified value.
157146
void reset(const MachineRegisterInfo &MRI_, const LiveRegSet &LiveRegs_);
158147

159148
// reset tracker at the specified slot index.
160-
void reset(const MachineRegisterInfo &MRI_, SlotIndex SI) {
161-
reset(MRI_, llvm::getLiveRegs(SI, LIS, MRI_));
149+
void reset(const MachineRegisterInfo &MRI, SlotIndex SI) {
150+
reset(MRI, llvm::getLiveRegs(SI, LIS, MRI));
162151
}
163152

164-
// move to the state just above the MI
153+
// reset tracker to the end of the MBB.
154+
void reset(const MachineBasicBlock &MBB) {
155+
reset(MBB.getParent()->getRegInfo(),
156+
LIS.getSlotIndexes()->getMBBEndIdx(&MBB));
157+
}
158+
159+
// reset tracker to the point just after MI (in program order).
160+
void reset(const MachineInstr &MI) {
161+
reset(MI.getMF()->getRegInfo(), LIS.getInstructionIndex(MI).getDeadSlot());
162+
}
163+
164+
// move to the state just before the MI (in program order).
165165
void recede(const MachineInstr &MI);
166166

167167
// checks whether the tracker's state after receding MI corresponds
168-
// to reported by LIS
168+
// to reported by LIS.
169169
bool isValid() const;
170+
171+
const GCNRegPressure &getMaxPressure() const { return MaxPressure; }
172+
173+
void resetMaxPressure() { MaxPressure = CurPressure; }
174+
175+
GCNRegPressure getMaxPressureAndReset() {
176+
GCNRegPressure RP = MaxPressure;
177+
resetMaxPressure();
178+
return RP;
179+
}
170180
};
171181

172182
class GCNDownwardRPTracker : public GCNRPTracker {
@@ -180,6 +190,13 @@ class GCNDownwardRPTracker : public GCNRPTracker {
180190

181191
MachineBasicBlock::const_iterator getNext() const { return NextMI; }
182192

193+
// Return MaxPressure and clear it.
194+
GCNRegPressure moveMaxPressure() {
195+
auto Res = MaxPressure;
196+
MaxPressure.clear();
197+
return Res;
198+
}
199+
183200
// Reset tracker to the point before the MI
184201
// filling live regs upon this point using LIS.
185202
// Returns false if block is empty except debug values.
@@ -209,6 +226,12 @@ LaneBitmask getLiveLaneMask(unsigned Reg,
209226
const LiveIntervals &LIS,
210227
const MachineRegisterInfo &MRI);
211228

229+
LaneBitmask getLiveLaneMask(const LiveInterval &LI, SlotIndex SI,
230+
const MachineRegisterInfo &MRI);
231+
232+
GCNRPTracker::LiveRegSet getLiveRegs(SlotIndex SI, const LiveIntervals &LIS,
233+
const MachineRegisterInfo &MRI);
234+
212235
/// creates a map MachineInstr -> LiveRegSet
213236
/// R - range of iterators on instructions
214237
/// After - upon entry or exit of every instruction

0 commit comments

Comments
 (0)