Skip to content

[NVPTX] Use appropriate operands in ReplaceImageHandles (NFC) #127898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 21 additions & 100 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,67 +149,6 @@ void NVPTXAsmPrinter::emitInstruction(const MachineInstr *MI) {
EmitToStreamer(*OutStreamer, Inst);
}

// Handle symbol backtracking for targets that do not support image handles
bool NVPTXAsmPrinter::lowerImageHandleOperand(const MachineInstr *MI,
unsigned OpNo, MCOperand &MCOp) {
const MachineOperand &MO = MI->getOperand(OpNo);
const MCInstrDesc &MCID = MI->getDesc();

if (MCID.TSFlags & NVPTXII::IsTexFlag) {
// This is a texture fetch, so operand 4 is a texref and operand 5 is
// a samplerref
if (OpNo == 4 && MO.isImm()) {
lowerImageHandleSymbol(MO.getImm(), MCOp);
return true;
}
if (OpNo == 5 && MO.isImm() && !(MCID.TSFlags & NVPTXII::IsTexModeUnifiedFlag)) {
lowerImageHandleSymbol(MO.getImm(), MCOp);
return true;
}

return false;
} else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
unsigned VecSize =
1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);

// For a surface load of vector size N, the Nth operand will be the surfref
if (OpNo == VecSize && MO.isImm()) {
lowerImageHandleSymbol(MO.getImm(), MCOp);
return true;
}

return false;
} else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
// This is a surface store, so operand 0 is a surfref
if (OpNo == 0 && MO.isImm()) {
lowerImageHandleSymbol(MO.getImm(), MCOp);
return true;
}

return false;
} else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
// This is a query, so operand 1 is a surfref/texref
if (OpNo == 1 && MO.isImm()) {
lowerImageHandleSymbol(MO.getImm(), MCOp);
return true;
}

return false;
}

return false;
}

void NVPTXAsmPrinter::lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp) {
// Ewwww
TargetMachine &TM = const_cast<TargetMachine &>(MF->getTarget());
NVPTXTargetMachine &nvTM = static_cast<NVPTXTargetMachine &>(TM);
const NVPTXMachineFunctionInfo *MFI = MF->getInfo<NVPTXMachineFunctionInfo>();
StringRef Sym = MFI->getImageHandleSymbol(Index);
StringRef SymName = nvTM.getStrPool().save(Sym);
MCOp = GetSymbolRef(OutContext.getOrCreateSymbol(SymName));
}

void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
OutMI.setOpcode(MI->getOpcode());
// Special: Do not mangle symbol operand of CALL_PROTOTYPE
Expand All @@ -220,67 +159,49 @@ void NVPTXAsmPrinter::lowerToMCInst(const MachineInstr *MI, MCInst &OutMI) {
return;
}

for (unsigned i = 0, e = MI->getNumOperands(); i != e; ++i) {
const MachineOperand &MO = MI->getOperand(i);

MCOperand MCOp;
if (lowerImageHandleOperand(MI, i, MCOp)) {
OutMI.addOperand(MCOp);
continue;
}

if (lowerOperand(MO, MCOp))
OutMI.addOperand(MCOp);
}
for (const auto MO : MI->operands())
OutMI.addOperand(lowerOperand(MO));
}

bool NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO,
MCOperand &MCOp) {
MCOperand NVPTXAsmPrinter::lowerOperand(const MachineOperand &MO) {
switch (MO.getType()) {
default: llvm_unreachable("unknown operand type");
default:
llvm_unreachable("unknown operand type");
case MachineOperand::MO_Register:
MCOp = MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
break;
return MCOperand::createReg(encodeVirtualRegister(MO.getReg()));
case MachineOperand::MO_Immediate:
MCOp = MCOperand::createImm(MO.getImm());
break;
return MCOperand::createImm(MO.getImm());
case MachineOperand::MO_MachineBasicBlock:
MCOp = MCOperand::createExpr(MCSymbolRefExpr::create(
MO.getMBB()->getSymbol(), OutContext));
break;
return MCOperand::createExpr(
MCSymbolRefExpr::create(MO.getMBB()->getSymbol(), OutContext));
case MachineOperand::MO_ExternalSymbol:
MCOp = GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
break;
return GetSymbolRef(GetExternalSymbolSymbol(MO.getSymbolName()));
case MachineOperand::MO_GlobalAddress:
MCOp = GetSymbolRef(getSymbol(MO.getGlobal()));
break;
return GetSymbolRef(getSymbol(MO.getGlobal()));
case MachineOperand::MO_FPImmediate: {
const ConstantFP *Cnt = MO.getFPImm();
const APFloat &Val = Cnt->getValueAPF();

switch (Cnt->getType()->getTypeID()) {
default: report_fatal_error("Unsupported FP type"); break;
case Type::HalfTyID:
MCOp = MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
default:
report_fatal_error("Unsupported FP type");
break;
case Type::HalfTyID:
return MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
case Type::BFloatTyID:
MCOp = MCOperand::createExpr(
return MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
break;
case Type::FloatTyID:
MCOp = MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
break;
return MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
case Type::DoubleTyID:
MCOp = MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
break;
return MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPDouble(Val, OutContext));
}
break;
}
}
return true;
}

unsigned NVPTXAsmPrinter::encodeVirtualRegister(unsigned Reg) {
Expand Down
6 changes: 1 addition & 5 deletions llvm/lib/Target/NVPTX/NVPTXAsmPrinter.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {

void emitInstruction(const MachineInstr *) override;
void lowerToMCInst(const MachineInstr *MI, MCInst &OutMI);
bool lowerOperand(const MachineOperand &MO, MCOperand &MCOp);
MCOperand lowerOperand(const MachineOperand &MO);
MCOperand GetSymbolRef(const MCSymbol *Symbol);
unsigned encodeVirtualRegister(unsigned Reg);

Expand Down Expand Up @@ -226,10 +226,6 @@ class LLVM_LIBRARY_VISIBILITY NVPTXAsmPrinter : public AsmPrinter {
void emitDeclarationWithName(const Function *, MCSymbol *, raw_ostream &O);
void emitDemotedVars(const Function *, raw_ostream &);

bool lowerImageHandleOperand(const MachineInstr *MI, unsigned OpNo,
MCOperand &MCOp);
void lowerImageHandleSymbol(unsigned Index, MCOperand &MCOp);

bool isLoopHeaderOfNoUnroll(const MachineBasicBlock &MBB) const;

// Used to control the need to emit .generic() in the initializer of
Expand Down
6 changes: 0 additions & 6 deletions llvm/lib/Target/NVPTX/NVPTXMachineFunctionInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ class NVPTXMachineFunctionInfo : public MachineFunctionInfo {
return ImageHandleList.size()-1;
}

/// Returns the symbol name at the given index.
StringRef getImageHandleSymbol(unsigned Idx) const {
assert(ImageHandleList.size() > Idx && "Bad index");
return ImageHandleList[Idx];
}

/// Check if the symbol has a mapping. Having a mapping means the handle is
/// replaced with a reference
bool checkImageHandleSymbol(StringRef Symbol) const {
Expand Down
51 changes: 16 additions & 35 deletions llvm/lib/Target/NVPTX/NVPTXReplaceImageHandles.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/CodeGen/MachineFunctionPass.h"
#include "llvm/CodeGen/MachineRegisterInfo.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

Expand All @@ -41,10 +40,8 @@ class NVPTXReplaceImageHandles : public MachineFunctionPass {
private:
bool processInstr(MachineInstr &MI);
bool replaceImageHandle(MachineOperand &Op, MachineFunction &MF);
bool findIndexForHandle(MachineOperand &Op, MachineFunction &MF,
unsigned &Idx);
};
}
} // namespace

char NVPTXReplaceImageHandles::ID = 0;

Expand Down Expand Up @@ -1756,9 +1753,11 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
}

return true;
} else if (MCID.TSFlags & NVPTXII::IsSuldMask) {
}
if (MCID.TSFlags & NVPTXII::IsSuldMask) {
unsigned VecSize =
1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) - 1);
1 << (((MCID.TSFlags & NVPTXII::IsSuldMask) >> NVPTXII::IsSuldShift) -
1);

// For a surface load of vector size N, the Nth operand will be the surfref
MachineOperand &SurfHandle = MI.getOperand(VecSize);
Expand All @@ -1767,15 +1766,17 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {
MI.setDesc(TII->get(suldRegisterToIndexOpcode(MI.getOpcode())));

return true;
} else if (MCID.TSFlags & NVPTXII::IsSustFlag) {
}
if (MCID.TSFlags & NVPTXII::IsSustFlag) {
// This is a surface store, so operand 0 is a surfref
MachineOperand &SurfHandle = MI.getOperand(0);

if (replaceImageHandle(SurfHandle, MF))
MI.setDesc(TII->get(sustRegisterToIndexOpcode(MI.getOpcode())));

return true;
} else if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
}
if (MCID.TSFlags & NVPTXII::IsSurfTexQueryFlag) {
// This is a query, so operand 1 is a surfref/texref
MachineOperand &Handle = MI.getOperand(1);

Expand All @@ -1790,16 +1791,6 @@ bool NVPTXReplaceImageHandles::processInstr(MachineInstr &MI) {

bool NVPTXReplaceImageHandles::replaceImageHandle(MachineOperand &Op,
MachineFunction &MF) {
unsigned Idx;
if (findIndexForHandle(Op, MF, Idx)) {
Op.ChangeToImmediate(Idx);
return true;
}
return false;
}

bool NVPTXReplaceImageHandles::
findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
const MachineRegisterInfo &MRI = MF.getRegInfo();
NVPTXMachineFunctionInfo *MFI = MF.getInfo<NVPTXMachineFunctionInfo>();

Expand All @@ -1812,25 +1803,16 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
case NVPTX::LD_i64_avar: {
// The handle is a parameter value being loaded, replace with the
// parameter symbol
const NVPTXTargetMachine &TM =
static_cast<const NVPTXTargetMachine &>(MF.getTarget());
if (TM.getDrvInterface() == NVPTX::CUDA) {
const auto &TM = static_cast<const NVPTXTargetMachine &>(MF.getTarget());
if (TM.getDrvInterface() == NVPTX::CUDA)
// For CUDA, we preserve the param loads coming from function arguments
return false;
}

assert(TexHandleDef.getOperand(7).isSymbol() && "Load is not a symbol!");
StringRef Sym = TexHandleDef.getOperand(7).getSymbolName();
std::string ParamBaseName = std::string(MF.getName());
ParamBaseName += "_param_";
assert(Sym.starts_with(ParamBaseName) && "Invalid symbol reference");
unsigned Param = atoi(Sym.data()+ParamBaseName.size());
std::string NewSym;
raw_string_ostream NewSymStr(NewSym);
NewSymStr << MF.getName() << "_param_" << Param;

InstrsToRemove.insert(&TexHandleDef);
Idx = MFI->getImageHandleSymbolIndex(NewSymStr.str());
Op.ChangeToES(Sym.data());
MFI->getImageHandleSymbolIndex(Sym);
return true;
}
case NVPTX::texsurf_handles: {
Expand All @@ -1839,15 +1821,14 @@ findIndexForHandle(MachineOperand &Op, MachineFunction &MF, unsigned &Idx) {
const GlobalValue *GV = TexHandleDef.getOperand(1).getGlobal();
assert(GV->hasName() && "Global sampler must be named!");
InstrsToRemove.insert(&TexHandleDef);
Idx = MFI->getImageHandleSymbolIndex(GV->getName());
Op.ChangeToGA(GV, 0);
return true;
}
case NVPTX::nvvm_move_i64:
case TargetOpcode::COPY: {
bool Res = findIndexForHandle(TexHandleDef.getOperand(1), MF, Idx);
if (Res) {
bool Res = replaceImageHandle(TexHandleDef.getOperand(1), MF);
if (Res)
InstrsToRemove.insert(&TexHandleDef);
}
return Res;
}
default:
Expand Down