-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[AMDGPU] Simplify WaitcntBrackets::getRegInterval with getPhysRegBaseClass #74087
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…Class This means that getRegInterval no longer depends on the MCInstrDesc, so it could be simplified further to take just a MachineOperand or just a physical register. NFCI.
@llvm/pr-subscribers-backend-amdgpu Author: Jay Foad (jayfoad) ChangesThis means that getRegInterval no longer depends on the MCInstrDesc, so it could be simplified further to take just a MachineOperand or just a physical register. NFCI. Full diff: https://github.com/llvm/llvm-project/pull/74087.diff 1 Files Affected:
diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
index ede4841b8a5fd7d..d1061d786706636 100644
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -238,7 +238,7 @@ class WaitcntBrackets {
bool merge(const WaitcntBrackets &Other);
- RegInterval getRegInterval(const MachineInstr *MI, const SIInstrInfo *TII,
+ RegInterval getRegInterval(const MachineInstr *MI,
const MachineRegisterInfo *MRI,
const SIRegisterInfo *TRI, unsigned OpNo) const;
@@ -491,7 +491,6 @@ class SIInsertWaitcnts : public MachineFunctionPass {
} // end anonymous namespace
RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,
- const SIInstrInfo *TII,
const MachineRegisterInfo *MRI,
const SIRegisterInfo *TRI,
unsigned OpNo) const {
@@ -525,7 +524,7 @@ RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,
else
return {-1, -1};
- const TargetRegisterClass *RC = TII->getOpRegClass(*MI, OpNo);
+ const TargetRegisterClass *RC = TRI->getPhysRegBaseClass(Op.getReg());
unsigned Size = TRI->getRegSizeInBits(*RC);
Result.second = Result.first + ((Size + 16) / 32);
@@ -537,7 +536,7 @@ void WaitcntBrackets::setExpScore(const MachineInstr *MI,
const SIRegisterInfo *TRI,
const MachineRegisterInfo *MRI, unsigned OpNo,
unsigned Val) {
- RegInterval Interval = getRegInterval(MI, TII, MRI, TRI, OpNo);
+ RegInterval Interval = getRegInterval(MI, MRI, TRI, OpNo);
assert(TRI->isVectorRegister(*MRI, MI->getOperand(OpNo).getReg()));
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
setRegScore(RegNo, EXP_CNT, Val);
@@ -673,7 +672,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
Inst.getOpcode() == AMDGPU::BUFFER_STORE_DWORDX4) {
MachineOperand *MO = TII->getNamedOperand(Inst, AMDGPU::OpName::data);
unsigned OpNo;//TODO: find the OpNo for this operand;
- RegInterval Interval = getRegInterval(&Inst, TII, MRI, TRI, OpNo);
+ RegInterval Interval = getRegInterval(&Inst, MRI, TRI, OpNo);
for (int RegNo = Interval.first; RegNo < Interval.second;
++RegNo) {
setRegScore(RegNo + NUM_ALL_VGPRS, t, CurrScore);
@@ -685,7 +684,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
auto &Op = Inst.getOperand(I);
if (!Op.isReg() || !Op.isDef())
continue;
- RegInterval Interval = getRegInterval(&Inst, TII, MRI, TRI, I);
+ RegInterval Interval = getRegInterval(&Inst, MRI, TRI, I);
if (T == VM_CNT) {
if (Interval.first >= NUM_ALL_VGPRS)
continue;
@@ -1136,7 +1135,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
if (MI.getOperand(CallAddrOpIdx).isReg()) {
RegInterval CallAddrOpInterval =
- ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, CallAddrOpIdx);
+ ScoreBrackets.getRegInterval(&MI, MRI, TRI, CallAddrOpIdx);
for (int RegNo = CallAddrOpInterval.first;
RegNo < CallAddrOpInterval.second; ++RegNo)
@@ -1146,7 +1145,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::dst);
if (RtnAddrOpIdx != -1) {
RegInterval RtnAddrOpInterval =
- ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, RtnAddrOpIdx);
+ ScoreBrackets.getRegInterval(&MI, MRI, TRI, RtnAddrOpIdx);
for (int RegNo = RtnAddrOpInterval.first;
RegNo < RtnAddrOpInterval.second; ++RegNo)
@@ -1198,8 +1197,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
if (Op.isTied() && Op.isUse() && TII->doesNotReadTiedSource(MI))
continue;
- RegInterval Interval =
- ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, I);
+ RegInterval Interval = ScoreBrackets.getRegInterval(&MI, MRI, TRI, I);
const bool IsVGPR = TRI->isVectorRegister(*MRI, Op.getReg());
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
@@ -1775,7 +1773,7 @@ bool SIInsertWaitcnts::shouldFlushVmCnt(MachineLoop *ML,
MachineOperand &Op = MI.getOperand(I);
if (!Op.isReg() || !TRI->isVectorRegister(*MRI, Op.getReg()))
continue;
- RegInterval Interval = Brackets.getRegInterval(&MI, TII, MRI, TRI, I);
+ RegInterval Interval = Brackets.getRegInterval(&MI, MRI, TRI, I);
// Vgpr use
if (Op.isUse()) {
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with one nit. w.r.t. adding assert
@@ -525,7 +524,7 @@ RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI, | |||
else | |||
return {-1, -1}; | |||
|
|||
const TargetRegisterClass *RC = TII->getOpRegClass(*MI, OpNo); | |||
const TargetRegisterClass *RC = TRI->getPhysRegBaseClass(Op.getReg()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should probably be asserting this is a physical register here?
i.e. assert(!Op.getReg().isVirtual())
getPhysRegBaseClass
will return null for virtuals, etc, so alternatively/additionally assert RC is valid?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In practice I think this would have failed earlier, at the call to TRI->getEncodingValue
above. Do you still think an explicit assert is needed somewhere?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's fine without an assert, it's not like any behaviour was changed, and misuse is unlikely.
I am only raising it because we can clearly can identify non physicals here based on Register
type, where as most of the TRI methods are working on MCRegister
and only identify out-of-bound registers.
We can imagine Register
to MCRegister
truncation yielding valid hardware registers for virtuals.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You should absolutely try to make use of the MCInstrDesc. Looking up the physical register class is extremely expensive. At one point the waitcnt pass was one of the most expensive backend passes, spending most of its time looking up physical register classes
That is the whole point of getPhysRegBaseClass - it is just two table lookups so should be fast: https://reviews.llvm.org/D139616 |
@arsenm ping |
This means that getRegInterval no longer depends on the MCInstrDesc, so it could be simplified further to take just a MachineOperand or just a physical register. NFCI.