Skip to content

[AMDGPU] Simplify WaitcntBrackets::getRegInterval with getPhysRegBaseClass #74087

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2023
Merged

[AMDGPU] Simplify WaitcntBrackets::getRegInterval with getPhysRegBaseClass #74087

merged 1 commit into from
Dec 18, 2023

Conversation

jayfoad
Copy link
Contributor

@jayfoad jayfoad commented Dec 1, 2023

This means that getRegInterval no longer depends on the MCInstrDesc, so it could be simplified further to take just a MachineOperand or just a physical register. NFCI.

…Class

This means that getRegInterval no longer depends on the MCInstrDesc, so
it could be simplified further to take just a MachineOperand or just a
physical register. NFCI.
@llvmbot
Copy link
Member

llvmbot commented Dec 1, 2023

@llvm/pr-subscribers-backend-amdgpu

Author: Jay Foad (jayfoad)

Changes

This means that getRegInterval no longer depends on the MCInstrDesc, so it could be simplified further to take just a MachineOperand or just a physical register. NFCI.


Full diff: https://github.com/llvm/llvm-project/pull/74087.diff

1 Files Affected:

  • (modified) llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp (+9-11)
diff --git a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
index ede4841b8a5fd7d..d1061d786706636 100644
--- a/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
+++ b/llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
@@ -238,7 +238,7 @@ class WaitcntBrackets {
 
   bool merge(const WaitcntBrackets &Other);
 
-  RegInterval getRegInterval(const MachineInstr *MI, const SIInstrInfo *TII,
+  RegInterval getRegInterval(const MachineInstr *MI,
                              const MachineRegisterInfo *MRI,
                              const SIRegisterInfo *TRI, unsigned OpNo) const;
 
@@ -491,7 +491,6 @@ class SIInsertWaitcnts : public MachineFunctionPass {
 } // end anonymous namespace
 
 RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,
-                                            const SIInstrInfo *TII,
                                             const MachineRegisterInfo *MRI,
                                             const SIRegisterInfo *TRI,
                                             unsigned OpNo) const {
@@ -525,7 +524,7 @@ RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,
   else
     return {-1, -1};
 
-  const TargetRegisterClass *RC = TII->getOpRegClass(*MI, OpNo);
+  const TargetRegisterClass *RC = TRI->getPhysRegBaseClass(Op.getReg());
   unsigned Size = TRI->getRegSizeInBits(*RC);
   Result.second = Result.first + ((Size + 16) / 32);
 
@@ -537,7 +536,7 @@ void WaitcntBrackets::setExpScore(const MachineInstr *MI,
                                   const SIRegisterInfo *TRI,
                                   const MachineRegisterInfo *MRI, unsigned OpNo,
                                   unsigned Val) {
-  RegInterval Interval = getRegInterval(MI, TII, MRI, TRI, OpNo);
+  RegInterval Interval = getRegInterval(MI, MRI, TRI, OpNo);
   assert(TRI->isVectorRegister(*MRI, MI->getOperand(OpNo).getReg()));
   for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
     setRegScore(RegNo, EXP_CNT, Val);
@@ -673,7 +672,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
        Inst.getOpcode() == AMDGPU::BUFFER_STORE_DWORDX4) {
     MachineOperand *MO = TII->getNamedOperand(Inst, AMDGPU::OpName::data);
     unsigned OpNo;//TODO: find the OpNo for this operand;
-    RegInterval Interval = getRegInterval(&Inst, TII, MRI, TRI, OpNo);
+    RegInterval Interval = getRegInterval(&Inst, MRI, TRI, OpNo);
     for (int RegNo = Interval.first; RegNo < Interval.second;
     ++RegNo) {
       setRegScore(RegNo + NUM_ALL_VGPRS, t, CurrScore);
@@ -685,7 +684,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
       auto &Op = Inst.getOperand(I);
       if (!Op.isReg() || !Op.isDef())
         continue;
-      RegInterval Interval = getRegInterval(&Inst, TII, MRI, TRI, I);
+      RegInterval Interval = getRegInterval(&Inst, MRI, TRI, I);
       if (T == VM_CNT) {
         if (Interval.first >= NUM_ALL_VGPRS)
           continue;
@@ -1136,7 +1135,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
 
       if (MI.getOperand(CallAddrOpIdx).isReg()) {
         RegInterval CallAddrOpInterval =
-          ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, CallAddrOpIdx);
+            ScoreBrackets.getRegInterval(&MI, MRI, TRI, CallAddrOpIdx);
 
         for (int RegNo = CallAddrOpInterval.first;
              RegNo < CallAddrOpInterval.second; ++RegNo)
@@ -1146,7 +1145,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
           AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::dst);
         if (RtnAddrOpIdx != -1) {
           RegInterval RtnAddrOpInterval =
-            ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, RtnAddrOpIdx);
+              ScoreBrackets.getRegInterval(&MI, MRI, TRI, RtnAddrOpIdx);
 
           for (int RegNo = RtnAddrOpInterval.first;
                RegNo < RtnAddrOpInterval.second; ++RegNo)
@@ -1198,8 +1197,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
         if (Op.isTied() && Op.isUse() && TII->doesNotReadTiedSource(MI))
           continue;
 
-        RegInterval Interval =
-            ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, I);
+        RegInterval Interval = ScoreBrackets.getRegInterval(&MI, MRI, TRI, I);
 
         const bool IsVGPR = TRI->isVectorRegister(*MRI, Op.getReg());
         for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
@@ -1775,7 +1773,7 @@ bool SIInsertWaitcnts::shouldFlushVmCnt(MachineLoop *ML,
         MachineOperand &Op = MI.getOperand(I);
         if (!Op.isReg() || !TRI->isVectorRegister(*MRI, Op.getReg()))
           continue;
-        RegInterval Interval = Brackets.getRegInterval(&MI, TII, MRI, TRI, I);
+        RegInterval Interval = Brackets.getRegInterval(&MI, MRI, TRI, I);
         // Vgpr use
         if (Op.isUse()) {
           for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {

Copy link
Contributor

@perlfu perlfu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with one nit. w.r.t. adding assert

@@ -525,7 +524,7 @@ RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,
else
return {-1, -1};

const TargetRegisterClass *RC = TII->getOpRegClass(*MI, OpNo);
const TargetRegisterClass *RC = TRI->getPhysRegBaseClass(Op.getReg());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably be asserting this is a physical register here?
i.e. assert(!Op.getReg().isVirtual())

getPhysRegBaseClass will return null for virtuals, etc, so alternatively/additionally assert RC is valid?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In practice I think this would have failed earlier, at the call to TRI->getEncodingValue above. Do you still think an explicit assert is needed somewhere?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine without an assert, it's not like any behaviour was changed, and misuse is unlikely.

I am only raising it because we can clearly can identify non physicals here based on Register type, where as most of the TRI methods are working on MCRegister and only identify out-of-bound registers.
We can imagine Register to MCRegister truncation yielding valid hardware registers for virtuals.

Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should absolutely try to make use of the MCInstrDesc. Looking up the physical register class is extremely expensive. At one point the waitcnt pass was one of the most expensive backend passes, spending most of its time looking up physical register classes

@jayfoad
Copy link
Contributor Author

jayfoad commented Dec 6, 2023

You should absolutely try to make use of the MCInstrDesc. Looking up the physical register class is extremely expensive. At one point the waitcnt pass was one of the most expensive backend passes, spending most of its time looking up physical register classes

That is the whole point of getPhysRegBaseClass - it is just two table lookups so should be fast: https://reviews.llvm.org/D139616

@jayfoad
Copy link
Contributor Author

jayfoad commented Dec 18, 2023

@arsenm ping

@jayfoad jayfoad merged commit 7e5019e into llvm:main Dec 18, 2023
@jayfoad jayfoad deleted the getreginterval-getphysregbaseclass branch December 18, 2023 14:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants