Skip to content

[llvm][NVPTX] Fix quadratic runtime in ProxyRegErasure #105730

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged: 2 commits were merged on Aug 23, 2024.
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 28 additions & 32 deletions llvm/lib/Target/NVPTX/NVPTXProxyRegErasure.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ void initializeNVPTXProxyRegErasurePass(PassRegistry &);
namespace {

struct NVPTXProxyRegErasure : public MachineFunctionPass {
public:
static char ID;
NVPTXProxyRegErasure() : MachineFunctionPass(ID) {
initializeNVPTXProxyRegErasurePass(*PassRegistry::getPassRegistry());
Expand All @@ -49,23 +48,22 @@ struct NVPTXProxyRegErasure : public MachineFunctionPass {
void getAnalysisUsage(AnalysisUsage &AU) const override {
MachineFunctionPass::getAnalysisUsage(AU);
}

private:
void replaceMachineInstructionUsage(MachineFunction &MF, MachineInstr &MI);

void replaceRegisterUsage(MachineInstr &Instr, MachineOperand &From,
MachineOperand &To);
};

} // namespace

char NVPTXProxyRegErasure::ID = 0;

INITIALIZE_PASS(NVPTXProxyRegErasure, "nvptx-proxyreg-erasure", "NVPTX ProxyReg Erasure", false, false)
INITIALIZE_PASS(NVPTXProxyRegErasure, "nvptx-proxyreg-erasure",
"NVPTX ProxyReg Erasure", false, false)

bool NVPTXProxyRegErasure::runOnMachineFunction(MachineFunction &MF) {
SmallVector<MachineInstr *, 16> RemoveList;

// ProxyReg instructions forward a register as another: `%dst = mov.iN %src`.
// Bulk RAUW the `%dst` registers in two passes over the machine function.
DenseMap<Register, Register> RAUWBatch;

for (auto &BB : MF) {
for (auto &MI : BB) {
switch (MI.getOpcode()) {
Expand All @@ -74,44 +72,42 @@ bool NVPTXProxyRegErasure::runOnMachineFunction(MachineFunction &MF) {
case NVPTX::ProxyRegI32:
case NVPTX::ProxyRegI64:
case NVPTX::ProxyRegF32:
case NVPTX::ProxyRegF64:
replaceMachineInstructionUsage(MF, MI);
case NVPTX::ProxyRegF64: {
auto &InOp = *MI.uses().begin();
auto &OutOp = *MI.defs().begin();
assert(InOp.isReg() && "ProxyReg input should be a register.");
assert(OutOp.isReg() && "ProxyReg output should be a register.");
RemoveList.push_back(&MI);
RAUWBatch.try_emplace(OutOp.getReg(), InOp.getReg());
break;
}
}
}
}

// If there were no proxy instructions, exit early.
if (RemoveList.empty())
return false;

// Erase the proxy instructions first.
for (auto *MI : RemoveList) {
MI->eraseFromParent();
}

return !RemoveList.empty();
}

void NVPTXProxyRegErasure::replaceMachineInstructionUsage(MachineFunction &MF,
MachineInstr &MI) {
auto &InOp = *MI.uses().begin();
auto &OutOp = *MI.defs().begin();

assert(InOp.isReg() && "ProxyReg input operand should be a register.");
assert(OutOp.isReg() && "ProxyReg output operand should be a register.");

// Now go replace the registers.
for (auto &BB : MF) {
for (auto &I : BB) {
replaceRegisterUsage(I, OutOp, InOp);
for (auto &MI : BB) {
for (auto &Op : MI.uses()) {
if (!Op.isReg())
continue;
auto it = RAUWBatch.find(Op.getReg());
if (it != RAUWBatch.end())
Op.setReg(it->second);
}
}
}
}

void NVPTXProxyRegErasure::replaceRegisterUsage(MachineInstr &Instr,
MachineOperand &From,
MachineOperand &To) {
for (auto &Op : Instr.uses()) {
if (Op.isReg() && Op.getReg() == From.getReg()) {
Op.setReg(To.getReg());
}
}
return true;
}

MachineFunctionPass *llvm::createNVPTXProxyRegErasurePass() {
Expand Down
Loading