Skip to content

Commit 62e4436

Browse files
authored
[NVPTX] Use appropriate operands in ReplaceImageHandles (NFC) (#127898)
Prior to this change NVPTXReplaceImageHandles replaced operands with indices and populated a table matching these indices to strings to be used in AsmPrinter. We can clean this up by simply inserting the correct external symbol or global address operands during NVPTXReplaceImageHandles, largely removing the need for the table.
1 parent 95000fd commit 62e4436

File tree

4 files changed

+38
-146
lines changed

4 files changed

+38
-146
lines changed

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 21 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -149,67 +149,6 @@ void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
149149
EmitToStreamer(*OutStreamer, Inst);
150150
}
151151

152-
// Handle symbol backtracking for targets that do not support image handles
153-
bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
154-
unsigned OpNo, MCOperand &MCOp) {
155-
const MachineOperand &MO = MI->getOperand(OpNo);
156-
const MCInstrDesc &MCID = MI->getDesc();
157-
158-
if (MCID.TSFlags & NVPTXII::IsTexFlag) {
159-
// This is a texture fetch, so operand 4 is a texref and operand 5 is
160-
// a samplerref
161-
if (OpNo == 4 && MO.isImm()) {
162-
lowerImageHandleSymbol(MO.getImm(), MCOp);
163-
return true;
164-
}
165-
if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
166-
lowerImageHandleSymbol(MO.getImm(), MCOp);
167-
return true;
168-
}
169-
170-
return false;
171-
} else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
172-
unsigned VecSize =
173-
1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
174-
175-
// For a surface load of vector size N, the Nth operand will be the surfref
176-
if (OpNo == VecSize && MO.isImm()) {
177-
lowerImageHandleSymbol(MO.getImm(), MCOp);
178-
return true;
179-
}
180-
181-
return false;
182-
} else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
183-
// This is a surface store, so operand 0 is a surfref
184-
if (OpNo == 0 && MO.isImm()) {
185-
lowerImageHandleSymbol(MO.getImm(), MCOp);
186-
return true;
187-
}
188-
189-
return false;
190-
} else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
191-
// This is a query, so operand 1 is a surfref/texref
192-
if (OpNo == 1 && MO.isImm()) {
193-
lowerImageHandleSymbol(MO.getImm(), MCOp);
194-
return true;
195-
}
196-
197-
return false;
198-
}
199-
200-
return false;
201-
}
202-
203-
void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
204-
// Ewwww
205-
TargetMachine &TM = const_cast<TargetMachine &>(MF->getTarget());
206-
NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine &>(TM);
207-
const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
208-
StringRef Sym = MFI->getImageHandleSymbol(Index);
209-
StringRef SymName = nvTM.getStrPool().save(Sym);
210-
MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
211-
}
212-
213152
void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
214153
OutMI.setOpcode(MI->getOpcode());
215154
// Special: Do not mangle symbol operand of CALL_PROTOTYPE
@@ -220,67 +159,49 @@ void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
220159
return;
221160
}
222161

223-
for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
224-
const MachineOperand &MO = MI->getOperand(i);
225-
226-
MCOperand MCOp;
227-
if (lowerImageHandleOperand(MI, i, MCOp)) {
228-
OutMI.addOperand(MCOp);
229-
continue;
230-
}
231-
232-
if (lowerOperand(MO, MCOp))
233-
OutMI.addOperand(MCOp);
234-
}
162+
for (const auto MO : MI->operands())
163+
OutMI.addOperand(lowerOperand(MO));
235164
}
236165

237-
bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
238-
MCOperand &MCOp) {
166+
MCOperand NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO) {
239167
switch (MO.getType()) {
240-
default: llvm_unreachable("unknown operand type");
168+
default:
169+
llvm_unreachable("unknown operand type");
241170
case MachineOperand::MO_Register:
242-
MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
243-
break;
171+
return MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
244172
case MachineOperand::MO_Immediate:
245-
MCOp = MCOperand::createImm(MO.getImm());
246-
break;
173+
return MCOperand::createImm(MO.getImm());
247174
case MachineOperand::MO_MachineBasicBlock:
248-
MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
249-
MO.getMBB()->getSymbol(), OutContext));
250-
break;
175+
return MCOperand::createExpr(
176+
MCSymbolRefExpr::create(MO.getMBB()->getSymbol(), OutContext));
251177
case MachineOperand::MO_ExternalSymbol:
252-
MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
253-
break;
178+
return GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
254179
case MachineOperand::MO_GlobalAddress:
255-
MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
256-
break;
180+
return GetSymbolRef(getSymbol(MO.getGlobal()));
257181
case MachineOperand::MO_FPImmediate: {
258182
const ConstantFP *Cnt = MO.getFPImm();
259183
const APFloat &Val = Cnt->getValueAPF();
260184

261185
switch (Cnt->getType()->getTypeID()) {
262-
default: report_fatal_error("Unsupported FP type"); break;
263-
case Type::HalfTyID:
264-
MCOp = MCOperand::createExpr(
265-
NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
186+
default:
187+
report_fatal_error("Unsupported FP type");
266188
break;
189+
case Type::HalfTyID:
190+
return MCOperand::createExpr(
191+
NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
267192
case Type::BFloatTyID:
268-
MCOp = MCOperand::createExpr(
193+
return MCOperand::createExpr(
269194
NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
270-
break;
271195
case Type::FloatTyID:
272-
MCOp = MCOperand::createExpr(
273-
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
274-
break;
196+
return MCOperand::createExpr(
197+
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
275198
case Type::DoubleTyID:
276-
MCOp = MCOperand::createExpr(
277-
NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
278-
break;
199+
return MCOperand::createExpr(
200+
NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
279201
}
280202
break;
281203
}
282204
}
283-
return true;
284205
}
285206

286207
unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
163163

164164
void emitInstruction(const MachineInstr *) override;
165165
void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI);
166-
bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp);
166+
MCOperand lowerOperand(const MachineOperand &MO);
167167
MCOperand GetSymbolRef(const MCSymbol *Symbol);
168168
unsigned encodeVirtualRegister(unsigned Reg);
169169

@@ -226,10 +226,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
226226
void emitDeclarationWithName(const Function *, MCSymbol *, raw_ostream &O);
227227
void emitDemotedVars(const Function *, raw_ostream &);
228228

229-
bool lowerImageHandleOperand(const MachineInstr *MI, unsigned OpNo,
230-
MCOperand &MCOp);
231-
void lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp);
232-
233229
bool isLoopHeaderOfNoUnroll(const MachineBasicBlock &MBB) const;
234230

235231
// Used to control the need to emit .generic() in the initializer of

llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,6 @@ class NVPTXMachineFunctionInfo : public MachineFunctionInfo {
4747
return ImageHandleList.size()-1;
4848
}
4949

50-
/// Returns the symbol name at the given index.
51-
StringRef getImageHandleSymbol(unsigned Idx) const {
52-
assert(ImageHandleList.size() > Idx && "Bad index");
53-
return ImageHandleList[Idx];
54-
}
55-
5650
/// Check if the symbol has a mapping. Having a mapping means the handle is
5751
/// replaced with a reference
5852
bool checkImageHandleSymbol(StringRef Symbol) const {

llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include "llvm/CodeGen/MachineFunction.h"
2121
#include "llvm/CodeGen/MachineFunctionPass.h"
2222
#include "llvm/CodeGen/MachineRegisterInfo.h"
23-
#include "llvm/Support/raw_ostream.h"
2423

2524
using namespace llvm;
2625

@@ -41,10 +40,8 @@ class NVPTXReplaceImageHandles : public MachineFunctionPass {
4140
private:
4241
bool processInstr(MachineInstr &MI);
4342
bool replaceImageHandle(MachineOperand &Op, MachineFunction &MF);
44-
bool findIndexForHandle(MachineOperand &Op, MachineFunction &MF,
45-
unsigned &Idx);
4643
};
47-
}
44+
} // namespace
4845

4946
char NVPTXReplaceImageHandles::ID = 0;
5047

@@ -1756,9 +1753,11 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
17561753
}
17571754

17581755
return true;
1759-
} else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
1756+
}
1757+
if (MCID.TSFlags & NVPTXII::IsSuldMask) {
17601758
unsigned VecSize =
1761-
1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
1759+
1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) -
1760+
1);
17621761

17631762
// For a surface load of vector size N, the Nth operand will be the surfref
17641763
MachineOperand &SurfHandle = MI.getOperand(VecSize);
@@ -1767,15 +1766,17 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
17671766
MI.setDesc(TII->get(suldRegisterToIndexOpcode(MI.getOpcode())));
17681767

17691768
return true;
1770-
} else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
1769+
}
1770+
if (MCID.TSFlags & NVPTXII::IsSustFlag) {
17711771
// This is a surface store, so operand 0 is a surfref
17721772
MachineOperand &SurfHandle = MI.getOperand(0);
17731773

17741774
if (replaceImageHandle(SurfHandle, MF))
17751775
MI.setDesc(TII->get(sustRegisterToIndexOpcode(MI.getOpcode())));
17761776

17771777
return true;
1778-
} else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
1778+
}
1779+
if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
17791780
// This is a query, so operand 1 is a surfref/texref
17801781
MachineOperand &Handle = MI.getOperand(1);
17811782

@@ -1790,16 +1791,6 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
17901791

17911792
bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op,
17921793
MachineFunction &MF) {
1793-
unsigned Idx;
1794-
if (findIndexForHandle(Op, MF, Idx)) {
1795-
Op.ChangeToImmediate(Idx);
1796-
return true;
1797-
}
1798-
return false;
1799-
}
1800-
1801-
bool NVPTXReplaceImageHandles::
1802-
findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
18031794
const MachineRegisterInfo &MRI = MF.getRegInfo();
18041795
NVPTXMachineFunctionInfo *MFI = MF.getInfo<NVPTXMachineFunctionInfo>();
18051796

@@ -1812,25 +1803,16 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
18121803
case NVPTX::LD_i64_avar: {
18131804
// The handle is a parameter value being loaded, replace with the
18141805
// parameter symbol
1815-
const NVPTXTargetMachine &TM =
1816-
static_cast<const NVPTXTargetMachine &>(MF.getTarget());
1817-
if (TM.getDrvInterface() == NVPTX::CUDA) {
1806+
const auto &TM = static_cast<const NVPTXTargetMachine &>(MF.getTarget());
1807+
if (TM.getDrvInterface() == NVPTX::CUDA)
18181808
// For CUDA, we preserve the param loads coming from function arguments
18191809
return false;
1820-
}
18211810

18221811
assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!");
18231812
StringRef Sym = TexHandleDef.getOperand(7).getSymbolName();
1824-
std::string ParamBaseName = std::string(MF.getName());
1825-
ParamBaseName += "_param_";
1826-
assert(Sym.starts_with(ParamBaseName) && "Invalid symbol reference");
1827-
unsigned Param = atoi(Sym.data()+ParamBaseName.size());
1828-
std::string NewSym;
1829-
raw_string_ostream NewSymStr(NewSym);
1830-
NewSymStr << MF.getName() << "_param_" << Param;
1831-
18321813
InstrsToRemove.insert(&TexHandleDef);
1833-
Idx = MFI->getImageHandleSymbolIndex(NewSymStr.str());
1814+
Op.ChangeToES(Sym.data());
1815+
MFI->getImageHandleSymbolIndex(Sym);
18341816
return true;
18351817
}
18361818
case NVPTX::texsurf_handles: {
@@ -1839,15 +1821,14 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
18391821
const GlobalValue *GV = TexHandleDef.getOperand(1).getGlobal();
18401822
assert(GV->hasName() && "Global sampler must be named!");
18411823
InstrsToRemove.insert(&TexHandleDef);
1842-
Idx = MFI->getImageHandleSymbolIndex(GV->getName());
1824+
Op.ChangeToGA(GV, 0);
18431825
return true;
18441826
}
18451827
case NVPTX::nvvm_move_i64:
18461828
case TargetOpcode::COPY: {
1847-
bool Res = findIndexForHandle(TexHandleDef.getOperand(1), MF, Idx);
1848-
if (Res) {
1829+
bool Res = replaceImageHandle(TexHandleDef.getOperand(1), MF);
1830+
if (Res)
18491831
InstrsToRemove.insert(&TexHandleDef);
1850-
}
18511832
return Res;
18521833
}
18531834
default:

0 commit comments

Comments
 (0)