Skip to content

Reapply "[RISCV] Implement tail call optimization in machine outliner" #117700

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ static inline unsigned getVLOpNum(const MCInstrDesc &Desc) {
return Desc.getNumOperands() - Offset;
}

static inline unsigned getTailExpandUseRegNo(const FeatureBitset &FeatureBits) {
// For Zicfilp, PseudoTAIL should be expanded to a software guarded branch.
// It means to use t2(x7) as rs1 of JALR to expand PseudoTAIL.
return FeatureBits[RISCV::FeatureStdExtZicfilp] ? RISCV::X7 : RISCV::X6;
}

static inline unsigned getSEWOpNum(const MCInstrDesc &Desc) {
const uint64_t TSFlags = Desc.TSFlags;
assert(hasSEWOp(TSFlags));
Expand Down
6 changes: 1 addition & 5 deletions llvm/lib/Target/RISCV/MCTargetDesc/RISCVMCCodeEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,7 @@ void RISCVMCCodeEmitter::expandFunctionCall(const MCInst &MI,
MCRegister Ra;
if (MI.getOpcode() == RISCV::PseudoTAIL) {
Func = MI.getOperand(0);
Ra = RISCV::X6;
// For Zicfilp, PseudoTAIL should be expanded to a software guarded branch.
// It means to use t2(x7) as rs1 of JALR to expand PseudoTAIL.
if (STI.hasFeature(RISCV::FeatureStdExtZicfilp))
Ra = RISCV::X7;
Ra = RISCVII::getTailExpandUseRegNo(STI.getFeatureBits());
} else if (MI.getOpcode() == RISCV::PseudoCALLReg) {
Func = MI.getOperand(1);
Ra = MI.getOperand(0).getReg();
Expand Down
143 changes: 110 additions & 33 deletions llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
//===----------------------------------------------------------------------===//

#include "RISCVInstrInfo.h"
#include "MCTargetDesc/RISCVBaseInfo.h"
#include "MCTargetDesc/RISCVMatInt.h"
#include "RISCV.h"
#include "RISCVMachineFunctionInfo.h"
Expand Down Expand Up @@ -2927,6 +2928,7 @@ bool RISCVInstrInfo::isMBBSafeToOutlineFrom(MachineBasicBlock &MBB,

// Enum values indicating how an outlined call should be constructed.
enum MachineOutlinerConstructionID {
MachineOutlinerTailCall,
MachineOutlinerDefault
};

Expand All @@ -2935,46 +2937,118 @@ bool RISCVInstrInfo::shouldOutlineFromFunctionByDefault(
return MF.getFunction().hasMinSize();
}

static bool isCandidatePatchable(const MachineBasicBlock &MBB) {
const MachineFunction *MF = MBB.getParent();
const Function &F = MF->getFunction();
return F.getFnAttribute("fentry-call").getValueAsBool() ||
F.hasFnAttribute("patchable-function-entry");
}

static bool isMIReadsReg(const MachineInstr &MI, const TargetRegisterInfo *TRI,
unsigned RegNo) {
return MI.readsRegister(RegNo, TRI) ||
MI.getDesc().hasImplicitUseOfPhysReg(RegNo);
}

static bool isMIModifiesReg(const MachineInstr &MI,
const TargetRegisterInfo *TRI, unsigned RegNo) {
return MI.modifiesRegister(RegNo, TRI) ||
MI.getDesc().hasImplicitDefOfPhysReg(RegNo);
}

static bool cannotInsertTailCall(const MachineBasicBlock &MBB) {
if (!MBB.back().isReturn())
return true;
if (isCandidatePatchable(MBB))
return true;

// If the candidate reads the pre-set register
// that can be used for expanding PseudoTAIL instruction,
// then we cannot insert tail call.
const TargetSubtargetInfo &STI = MBB.getParent()->getSubtarget();
unsigned TailExpandUseRegNo =
RISCVII::getTailExpandUseRegNo(STI.getFeatureBits());
for (const MachineInstr &MI : MBB) {
if (isMIReadsReg(MI, STI.getRegisterInfo(), TailExpandUseRegNo))
return true;
if (isMIModifiesReg(MI, STI.getRegisterInfo(), TailExpandUseRegNo))
break;
}
return false;
}

static std::optional<MachineOutlinerConstructionID>
analyzeCandidate(outliner::Candidate &C) {
// If last instruction is return then we can rely on
// the verification already performed in the getOutliningTypeImpl.
if (C.back().isReturn()) {
assert(!cannotInsertTailCall(*C.getMBB()) &&
"The candidate who uses return instruction must be outlined "
"using tail call");
return MachineOutlinerTailCall;
}

auto CandidateUsesX5 = [](outliner::Candidate &C) {
const TargetRegisterInfo *TRI = C.getMF()->getSubtarget().getRegisterInfo();
if (std::any_of(C.begin(), C.end(), [TRI](const MachineInstr &MI) {
return isMIModifiesReg(MI, TRI, RISCV::X5);
}))
return true;
return !C.isAvailableAcrossAndOutOfSeq(RISCV::X5, *TRI);
};

if (!CandidateUsesX5(C))
return MachineOutlinerDefault;

return std::nullopt;
}

std::optional<std::unique_ptr<outliner::OutlinedFunction>>
RISCVInstrInfo::getOutliningCandidateInfo(
const MachineModuleInfo &MMI,
std::vector<outliner::Candidate> &RepeatedSequenceLocs,
unsigned MinRepeats) const {

// First we need to filter out candidates where the X5 register (IE t0) can't
// be used to setup the function call.
auto CannotInsertCall = [](outliner::Candidate &C) {
const TargetRegisterInfo *TRI = C.getMF()->getSubtarget().getRegisterInfo();
return !C.isAvailableAcrossAndOutOfSeq(RISCV::X5, *TRI);
};

llvm::erase_if(RepeatedSequenceLocs, CannotInsertCall);
// Each RepeatedSequenceLoc is identical.
outliner::Candidate &Candidate = RepeatedSequenceLocs[0];
auto CandidateInfo = analyzeCandidate(Candidate);
if (!CandidateInfo)
RepeatedSequenceLocs.clear();

// If the sequence doesn't have enough candidates left, then we're done.
if (RepeatedSequenceLocs.size() < MinRepeats)
return std::nullopt;

unsigned SequenceSize = 0;

for (auto &MI : RepeatedSequenceLocs[0])
SequenceSize += getInstSizeInBytes(MI);
unsigned InstrSizeCExt =
Candidate.getMF()->getSubtarget<RISCVSubtarget>().hasStdExtCOrZca() ? 2
: 4;
unsigned CallOverhead = 0, FrameOverhead = 0;

MachineOutlinerConstructionID MOCI = CandidateInfo.value();
switch (MOCI) {
case MachineOutlinerDefault:
// call t0, function = 8 bytes.
CallOverhead = 8;
// jr t0 = 4 bytes, 2 bytes if compressed instructions are enabled.
FrameOverhead = InstrSizeCExt;
break;
case MachineOutlinerTailCall:
// tail call = auipc + jalr in the worst case without linker relaxation.
CallOverhead = 4 + InstrSizeCExt;
// Using tail call we move ret instruction from caller to callee.
FrameOverhead = 0;
break;
}

// call t0, function = 8 bytes.
unsigned CallOverhead = 8;
for (auto &C : RepeatedSequenceLocs)
C.setCallInfo(MachineOutlinerDefault, CallOverhead);
C.setCallInfo(MOCI, CallOverhead);

// jr t0 = 4 bytes, 2 bytes if compressed instructions are enabled.
unsigned FrameOverhead = 4;
if (RepeatedSequenceLocs[0]
.getMF()
->getSubtarget<RISCVSubtarget>()
.hasStdExtCOrZca())
FrameOverhead = 2;
unsigned SequenceSize = 0;
for (auto &MI : Candidate)
SequenceSize += getInstSizeInBytes(MI);

return std::make_unique<outliner::OutlinedFunction>(
RepeatedSequenceLocs, SequenceSize, FrameOverhead,
MachineOutlinerDefault);
RepeatedSequenceLocs, SequenceSize, FrameOverhead, MOCI);
}

outliner::InstrType
Expand All @@ -2995,15 +3069,8 @@ RISCVInstrInfo::getOutliningTypeImpl(const MachineModuleInfo &MMI,
return F.needsUnwindTableEntry() ? outliner::InstrType::Illegal
: outliner::InstrType::Invisible;

// We need support for tail calls to outlined functions before return
// statements can be allowed.
if (MI.isReturn())
return outliner::InstrType::Illegal;

// Don't allow modifying the X5 register which we use for return addresses for
// these outlined functions.
if (MI.modifiesRegister(RISCV::X5, TRI) ||
MI.getDesc().hasImplicitDefOfPhysReg(RISCV::X5))
if (cannotInsertTailCall(*MBB) &&
(MI.isReturn() || isMIModifiesReg(MI, TRI, RISCV::X5)))
return outliner::InstrType::Illegal;

// Make sure the operands don't reference something unsafe.
Expand Down Expand Up @@ -3039,6 +3106,9 @@ void RISCVInstrInfo::buildOutlinedFrame(
}
}

if (OF.FrameConstructionID == MachineOutlinerTailCall)
return;

MBB.addLiveIn(RISCV::X5);

// Add in a return instruction to the end of the outlined frame.
Expand All @@ -3052,6 +3122,13 @@ MachineBasicBlock::iterator RISCVInstrInfo::insertOutlinedCall(
Module &M, MachineBasicBlock &MBB, MachineBasicBlock::iterator &It,
MachineFunction &MF, outliner::Candidate &C) const {

if (C.CallConstructionID == MachineOutlinerTailCall) {
It = MBB.insert(It, BuildMI(MF, DebugLoc(), get(RISCV::PseudoTAIL))
.addGlobalAddress(M.getNamedValue(MF.getName()),
/*Offset=*/0, RISCVII::MO_CALL));
return It;
}

// Add in a call instruction to the outlined function at the given location.
It = MBB.insert(It,
BuildMI(MF, DebugLoc(), get(RISCV::PseudoCALLReg), RISCV::X5)
Expand Down
Loading
Loading