Skip to content

[AMDGPU] Simplify WaitcntBrackets::getRegInterval with getPhysRegBaseClass #74087

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 9 additions & 11 deletions llvm/lib/Target/AMDGPU/SIInsertWaitcnts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ class WaitcntBrackets {

bool merge(const WaitcntBrackets &Other);

RegInterval getRegInterval(const MachineInstr *MI, const SIInstrInfo *TII,
RegInterval getRegInterval(const MachineInstr *MI,
const MachineRegisterInfo *MRI,
const SIRegisterInfo *TRI, unsigned OpNo) const;

Expand Down Expand Up @@ -491,7 +491,6 @@ class SIInsertWaitcnts : public MachineFunctionPass {
} // end anonymous namespace

RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,
const SIInstrInfo *TII,
const MachineRegisterInfo *MRI,
const SIRegisterInfo *TRI,
unsigned OpNo) const {
Expand Down Expand Up @@ -525,7 +524,7 @@ RegInterval WaitcntBrackets::getRegInterval(const MachineInstr *MI,
else
return {-1, -1};

const TargetRegisterClass *RC = TII->getOpRegClass(*MI, OpNo);
const TargetRegisterClass *RC = TRI->getPhysRegBaseClass(Op.getReg());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably be asserting this is a physical register here?
i.e. assert(!Op.getReg().isVirtual())

getPhysRegBaseClass will return null for virtuals, etc, so alternatively/additionally assert RC is valid?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In practice I think this would have failed earlier, at the call to TRI->getEncodingValue above. Do you still think an explicit assert is needed somewhere?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine without an assert, it's not like any behaviour was changed, and misuse is unlikely.

I am only raising it because we can clearly can identify non physicals here based on Register type, where as most of the TRI methods are working on MCRegister and only identify out-of-bound registers.
We can imagine Register to MCRegister truncation yielding valid hardware registers for virtuals.

unsigned Size = TRI->getRegSizeInBits(*RC);
Result.second = Result.first + ((Size + 16) / 32);

Expand All @@ -537,7 +536,7 @@ void WaitcntBrackets::setExpScore(const MachineInstr *MI,
const SIRegisterInfo *TRI,
const MachineRegisterInfo *MRI, unsigned OpNo,
unsigned Val) {
RegInterval Interval = getRegInterval(MI, TII, MRI, TRI, OpNo);
RegInterval Interval = getRegInterval(MI, MRI, TRI, OpNo);
assert(TRI->isVectorRegister(*MRI, MI->getOperand(OpNo).getReg()));
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
setRegScore(RegNo, EXP_CNT, Val);
Expand Down Expand Up @@ -673,7 +672,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
Inst.getOpcode() == AMDGPU::BUFFER_STORE_DWORDX4) {
MachineOperand *MO = TII->getNamedOperand(Inst, AMDGPU::OpName::data);
unsigned OpNo;//TODO: find the OpNo for this operand;
RegInterval Interval = getRegInterval(&Inst, TII, MRI, TRI, OpNo);
RegInterval Interval = getRegInterval(&Inst, MRI, TRI, OpNo);
for (int RegNo = Interval.first; RegNo < Interval.second;
++RegNo) {
setRegScore(RegNo + NUM_ALL_VGPRS, t, CurrScore);
Expand All @@ -685,7 +684,7 @@ void WaitcntBrackets::updateByEvent(const SIInstrInfo *TII,
auto &Op = Inst.getOperand(I);
if (!Op.isReg() || !Op.isDef())
continue;
RegInterval Interval = getRegInterval(&Inst, TII, MRI, TRI, I);
RegInterval Interval = getRegInterval(&Inst, MRI, TRI, I);
if (T == VM_CNT) {
if (Interval.first >= NUM_ALL_VGPRS)
continue;
Expand Down Expand Up @@ -1136,7 +1135,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,

if (MI.getOperand(CallAddrOpIdx).isReg()) {
RegInterval CallAddrOpInterval =
ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, CallAddrOpIdx);
ScoreBrackets.getRegInterval(&MI, MRI, TRI, CallAddrOpIdx);

for (int RegNo = CallAddrOpInterval.first;
RegNo < CallAddrOpInterval.second; ++RegNo)
Expand All @@ -1146,7 +1145,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
AMDGPU::getNamedOperandIdx(MI.getOpcode(), AMDGPU::OpName::dst);
if (RtnAddrOpIdx != -1) {
RegInterval RtnAddrOpInterval =
ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, RtnAddrOpIdx);
ScoreBrackets.getRegInterval(&MI, MRI, TRI, RtnAddrOpIdx);

for (int RegNo = RtnAddrOpInterval.first;
RegNo < RtnAddrOpInterval.second; ++RegNo)
Expand Down Expand Up @@ -1198,8 +1197,7 @@ bool SIInsertWaitcnts::generateWaitcntInstBefore(MachineInstr &MI,
if (Op.isTied() && Op.isUse() && TII->doesNotReadTiedSource(MI))
continue;

RegInterval Interval =
ScoreBrackets.getRegInterval(&MI, TII, MRI, TRI, I);
RegInterval Interval = ScoreBrackets.getRegInterval(&MI, MRI, TRI, I);

const bool IsVGPR = TRI->isVectorRegister(*MRI, Op.getReg());
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
Expand Down Expand Up @@ -1775,7 +1773,7 @@ bool SIInsertWaitcnts::shouldFlushVmCnt(MachineLoop *ML,
MachineOperand &Op = MI.getOperand(I);
if (!Op.isReg() || !TRI->isVectorRegister(*MRI, Op.getReg()))
continue;
RegInterval Interval = Brackets.getRegInterval(&MI, TII, MRI, TRI, I);
RegInterval Interval = Brackets.getRegInterval(&MI, MRI, TRI, I);
// Vgpr use
if (Op.isUse()) {
for (int RegNo = Interval.first; RegNo < Interval.second; ++RegNo) {
Expand Down