Skip to content

Commit 965c57c

Browse files
committed
[NVPTX] Use appropriate operands in ReplaceImageHandles (NFC)
1 parent e60de25 commit 965c57c

File tree

4 files changed

+38
-145
lines changed

4 files changed

+38
-145
lines changed

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp

Lines changed: 21 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -149,66 +149,6 @@ void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
149149
EmitToStreamer(*OutStreamer, Inst);
150150
}
151151

152-
// Handle symbol backtracking for targets that do not support image handles
153-
bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
154-
unsigned OpNo, MCOperand &MCOp) {
155-
const MachineOperand &MO = MI->getOperand(OpNo);
156-
const MCInstrDesc &MCID = MI->getDesc();
157-
158-
if (MCID.TSFlags & NVPTXII::IsTexFlag) {
159-
// This is a texture fetch, so operand 4 is a texref and operand 5 is
160-
// a samplerref
161-
if (OpNo == 4 && MO.isImm()) {
162-
lowerImageHandleSymbol(MO.getImm(), MCOp);
163-
return true;
164-
}
165-
if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
166-
lowerImageHandleSymbol(MO.getImm(), MCOp);
167-
return true;
168-
}
169-
170-
return false;
171-
} else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
172-
unsigned VecSize =
173-
1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
174-
175-
// For a surface load of vector size N, the Nth operand will be the surfref
176-
if (OpNo == VecSize && MO.isImm()) {
177-
lowerImageHandleSymbol(MO.getImm(), MCOp);
178-
return true;
179-
}
180-
181-
return false;
182-
} else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
183-
// This is a surface store, so operand 0 is a surfref
184-
if (OpNo == 0 && MO.isImm()) {
185-
lowerImageHandleSymbol(MO.getImm(), MCOp);
186-
return true;
187-
}
188-
189-
return false;
190-
} else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
191-
// This is a query, so operand 1 is a surfref/texref
192-
if (OpNo == 1 && MO.isImm()) {
193-
lowerImageHandleSymbol(MO.getImm(), MCOp);
194-
return true;
195-
}
196-
197-
return false;
198-
}
199-
200-
return false;
201-
}
202-
203-
void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
204-
// Ewwww
205-
TargetMachine &TM = const_cast<TargetMachine &>(MF->getTarget());
206-
NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine &>(TM);
207-
const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
208-
StringRef Sym = MFI->getImageHandleSymbol(Index);
209-
StringRef SymName = nvTM.getStrPool().save(Sym);
210-
MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
211-
}
212152

213153
void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
214154
OutMI.setOpcode(MI->getOpcode());
@@ -220,67 +160,49 @@ void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
220160
return;
221161
}
222162

223-
for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
224-
const MachineOperand &MO = MI->getOperand(i);
225-
226-
MCOperand MCOp;
227-
if (lowerImageHandleOperand(MI, i, MCOp)) {
228-
OutMI.addOperand(MCOp);
229-
continue;
230-
}
231-
232-
if (lowerOperand(MO, MCOp))
233-
OutMI.addOperand(MCOp);
234-
}
163+
for (const auto MO : MI->operands())
164+
OutMI.addOperand(lowerOperand(MO));
235165
}
236166

237-
bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
238-
MCOperand &MCOp) {
167+
MCOperand NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO) {
239168
switch (MO.getType()) {
240-
default: llvm_unreachable("unknown operand type");
169+
default:
170+
llvm_unreachable("unknown operand type");
241171
case MachineOperand::MO_Register:
242-
MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
243-
break;
172+
return MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
244173
case MachineOperand::MO_Immediate:
245-
MCOp = MCOperand::createImm(MO.getImm());
246-
break;
174+
return MCOperand::createImm(MO.getImm());
247175
case MachineOperand::MO_MachineBasicBlock:
248-
MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
249-
MO.getMBB()->getSymbol(), OutContext));
250-
break;
176+
return MCOperand::createExpr(
177+
MCSymbolRefExpr::create(MO.getMBB()->getSymbol(), OutContext));
251178
case MachineOperand::MO_ExternalSymbol:
252-
MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
253-
break;
179+
return GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
254180
case MachineOperand::MO_GlobalAddress:
255-
MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
256-
break;
181+
return GetSymbolRef(getSymbol(MO.getGlobal()));
257182
case MachineOperand::MO_FPImmediate: {
258183
const ConstantFP *Cnt = MO.getFPImm();
259184
const APFloat &Val = Cnt->getValueAPF();
260185

261186
switch (Cnt->getType()->getTypeID()) {
262-
default: report_fatal_error("Unsupported FP type"); break;
263-
case Type::HalfTyID:
264-
MCOp = MCOperand::createExpr(
265-
NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
187+
default:
188+
report_fatal_error("Unsupported FP type");
266189
break;
190+
case Type::HalfTyID:
191+
return MCOperand::createExpr(
192+
NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
267193
case Type::BFloatTyID:
268-
MCOp = MCOperand::createExpr(
194+
return MCOperand::createExpr(
269195
NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
270-
break;
271196
case Type::FloatTyID:
272-
MCOp = MCOperand::createExpr(
273-
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
274-
break;
197+
return MCOperand::createExpr(
198+
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
275199
case Type::DoubleTyID:
276-
MCOp = MCOperand::createExpr(
277-
NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
278-
break;
200+
return MCOperand::createExpr(
201+
NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
279202
}
280203
break;
281204
}
282205
}
283-
return true;
284206
}
285207

286208
unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {

llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
163163

164164
void emitInstruction(const MachineInstr *) override;
165165
void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI);
166-
bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp);
166+
MCOperand lowerOperand(const MachineOperand &MO);
167167
MCOperand GetSymbolRef(const MCSymbol *Symbol);
168168
unsigned encodeVirtualRegister(unsigned Reg);
169169

@@ -226,10 +226,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
226226
void emitDeclarationWithName(const Function *, MCSymbol *, raw_ostream &O);
227227
void emitDemotedVars(const Function *, raw_ostream &);
228228

229-
bool lowerImageHandleOperand(const MachineInstr *MI, unsigned OpNo,
230-
MCOperand &MCOp);
231-
void lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp);
232-
233229
bool isLoopHeaderOfNoUnroll(const MachineBasicBlock &MBB) const;
234230

235231
// Used to control the need to emit .generic() in the initializer of

llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,6 @@ class NVPTXMachineFunctionInfo : public MachineFunctionInfo {
4747
return ImageHandleList.size()-1;
4848
}
4949

50-
/// Returns the symbol name at the given index.
51-
StringRef getImageHandleSymbol(unsigned Idx) const {
52-
assert(ImageHandleList.size() > Idx && "Bad index");
53-
return ImageHandleList[Idx];
54-
}
55-
5650
/// Check if the symbol has a mapping. Having a mapping means the handle is
5751
/// replaced with a reference
5852
bool checkImageHandleSymbol(StringRef Symbol) const {

llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp

Lines changed: 16 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
#include "llvm/CodeGen/MachineFunction.h"
2121
#include "llvm/CodeGen/MachineFunctionPass.h"
2222
#include "llvm/CodeGen/MachineRegisterInfo.h"
23-
#include "llvm/Support/raw_ostream.h"
2423

2524
using namespace llvm;
2625

@@ -41,10 +40,8 @@ class NVPTXReplaceImageHandles : public MachineFunctionPass {
4140
private:
4241
bool processInstr(MachineInstr &MI);
4342
bool replaceImageHandle(MachineOperand &Op, MachineFunction &MF);
44-
bool findIndexForHandle(MachineOperand &Op, MachineFunction &MF,
45-
unsigned &Idx);
4643
};
47-
}
44+
} // namespace
4845

4946
char NVPTXReplaceImageHandles::ID = 0;
5047

@@ -1756,9 +1753,11 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
17561753
}
17571754

17581755
return true;
1759-
} else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
1756+
}
1757+
if (MCID.TSFlags & NVPTXII::IsSuldMask) {
17601758
unsigned VecSize =
1761-
1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
1759+
1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) -
1760+
1);
17621761

17631762
// For a surface load of vector size N, the Nth operand will be the surfref
17641763
MachineOperand &SurfHandle = MI.getOperand(VecSize);
@@ -1767,15 +1766,17 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
17671766
MI.setDesc(TII->get(suldRegisterToIndexOpcode(MI.getOpcode())));
17681767

17691768
return true;
1770-
} else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
1769+
}
1770+
if (MCID.TSFlags & NVPTXII::IsSustFlag) {
17711771
// This is a surface store, so operand 0 is a surfref
17721772
MachineOperand &SurfHandle = MI.getOperand(0);
17731773

17741774
if (replaceImageHandle(SurfHandle, MF))
17751775
MI.setDesc(TII->get(sustRegisterToIndexOpcode(MI.getOpcode())));
17761776

17771777
return true;
1778-
} else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
1778+
}
1779+
if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
17791780
// This is a query, so operand 1 is a surfref/texref
17801781
MachineOperand &Handle = MI.getOperand(1);
17811782

@@ -1790,16 +1791,6 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
17901791

17911792
bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op,
17921793
MachineFunction &MF) {
1793-
unsigned Idx;
1794-
if (findIndexForHandle(Op, MF, Idx)) {
1795-
Op.ChangeToImmediate(Idx);
1796-
return true;
1797-
}
1798-
return false;
1799-
}
1800-
1801-
bool NVPTXReplaceImageHandles::
1802-
findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
18031794
const MachineRegisterInfo &MRI = MF.getRegInfo();
18041795
NVPTXMachineFunctionInfo *MFI = MF.getInfo<NVPTXMachineFunctionInfo>();
18051796

@@ -1812,25 +1803,16 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
18121803
case NVPTX::LD_i64_avar: {
18131804
// The handle is a parameter value being loaded, replace with the
18141805
// parameter symbol
1815-
const NVPTXTargetMachine &TM =
1816-
static_cast<const NVPTXTargetMachine &>(MF.getTarget());
1817-
if (TM.getDrvInterface() == NVPTX::CUDA) {
1806+
const auto &TM = static_cast<const NVPTXTargetMachine &>(MF.getTarget());
1807+
if (TM.getDrvInterface() == NVPTX::CUDA)
18181808
// For CUDA, we preserve the param loads coming from function arguments
18191809
return false;
1820-
}
18211810

18221811
assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!");
18231812
StringRef Sym = TexHandleDef.getOperand(7).getSymbolName();
1824-
std::string ParamBaseName = std::string(MF.getName());
1825-
ParamBaseName += "_param_";
1826-
assert(Sym.starts_with(ParamBaseName) && "Invalid symbol reference");
1827-
unsigned Param = atoi(Sym.data()+ParamBaseName.size());
1828-
std::string NewSym;
1829-
raw_string_ostream NewSymStr(NewSym);
1830-
NewSymStr << MF.getName() << "_param_" << Param;
1831-
18321813
InstrsToRemove.insert(&TexHandleDef);
1833-
Idx = MFI->getImageHandleSymbolIndex(NewSymStr.str());
1814+
Op.ChangeToES(Sym.data());
1815+
MFI->getImageHandleSymbolIndex(Sym);
18341816
return true;
18351817
}
18361818
case NVPTX::texsurf_handles: {
@@ -1839,15 +1821,14 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
18391821
const GlobalValue *GV = TexHandleDef.getOperand(1).getGlobal();
18401822
assert(GV->hasName() && "Global sampler must be named!");
18411823
InstrsToRemove.insert(&TexHandleDef);
1842-
Idx = MFI->getImageHandleSymbolIndex(GV->getName());
1824+
Op.ChangeToGA(GV, 0);
18431825
return true;
18441826
}
18451827
case NVPTX::nvvm_move_i64:
18461828
case TargetOpcode::COPY: {
1847-
bool Res = findIndexForHandle(TexHandleDef.getOperand(1), MF, Idx);
1848-
if (Res) {
1829+
bool Res = replaceImageHandle(TexHandleDef.getOperand(1), MF);
1830+
if (Res)
18491831
InstrsToRemove.insert(&TexHandleDef);
1850-
}
18511832
return Res;
18521833
}
18531834
default:

0 commit comments

Comments
 (0)