
[x64][win] Add compiler support for x64 import call optimization (equivalent to MSVC /d2guardretpoline) #126631


Merged: 1 commit, May 20, 2025
3 changes: 3 additions & 0 deletions llvm/include/llvm/Transforms/CFGuard.h
@@ -16,6 +16,7 @@
namespace llvm {

class FunctionPass;
class GlobalValue;

class CFGuardPass : public PassInfoMixin<CFGuardPass> {
public:
@@ -34,6 +35,8 @@ FunctionPass *createCFGuardCheckPass();
/// Insert Control Flow Guard dispatches on indirect function calls.
FunctionPass *createCFGuardDispatchPass();

bool isCFGuardFunction(const GlobalValue *GV);

} // namespace llvm

#endif
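The new isCFGuardFunction helper lets backend code recognize calls that target the Control Flow Guard check/dispatch thunks. How the X86 asm printer actually uses it is not visible in this excerpt (the call sites live in X86MCInstLower.cpp), so the sketch below is an assumption: classify a direct call site into one of the ImportCallKind values this patch adds to X86AsmPrinter.h.

#include "llvm/IR/GlobalValue.h"
#include "llvm/Transforms/CFGuard.h"

// Hedged sketch only: route CF Guard thunk calls to the *_CFG_* record kind,
// everything else to the plain import-call kind. The numeric values are the
// ImportCallKind members declared later in this diff.
static unsigned classifyDirectCallKind(const llvm::GlobalValue *GV) {
  if (llvm::isCFGuardFunction(GV))
    return 0x09; // IMAGE_RETPOLINE_AMD64_CFG_CALL
  return 0x03;   // IMAGE_RETPOLINE_AMD64_IMPORT_CALL
}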
5 changes: 5 additions & 0 deletions llvm/lib/MC/MCObjectFileInfo.cpp
@@ -599,6 +599,11 @@ void MCObjectFileInfo::initCOFFMCObjectFileInfo(const Triple &T) {
if (T.getArch() == Triple::aarch64) {
ImportCallSection =
Ctx->getCOFFSection(".impcall", COFF::IMAGE_SCN_LNK_INFO);
} else if (T.getArch() == Triple::x86_64) {
// Import Call Optimization on x64 leverages the same metadata as the
// retpoline mitigation, hence the unusual section name.
ImportCallSection =
Ctx->getCOFFSection(".retplne", COFF::IMAGE_SCN_LNK_INFO);
}

// Debug info.
35 changes: 34 additions & 1 deletion llvm/lib/Target/X86/X86AsmPrinter.cpp
@@ -464,7 +464,8 @@ static bool isIndirectBranchOrTailCall(const MachineInstr &MI) {
Opc == X86::TAILJMPr64 || Opc == X86::TAILJMPm64 ||
Opc == X86::TCRETURNri || Opc == X86::TCRETURNmi ||
Opc == X86::TCRETURNri64 || Opc == X86::TCRETURNmi64 ||
Opc == X86::TAILJMPr64_REX || Opc == X86::TAILJMPm64_REX;
Opc == X86::TCRETURNri64_ImpCall || Opc == X86::TAILJMPr64_REX ||
Opc == X86::TAILJMPm64_REX;
}

void X86AsmPrinter::emitBasicBlockEnd(const MachineBasicBlock &MBB) {
@@ -912,6 +913,9 @@ void X86AsmPrinter::emitStartOfAsmFile(Module &M) {
if (TT.isOSBinFormatCOFF()) {
emitCOFFFeatureSymbol(M);
emitCOFFReplaceableFunctionData(M);

if (M.getModuleFlag("import-call-optimization"))
EnableImportCallOptimization = true;
}
OutStreamer->emitSyntaxDirective();
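Everything in this patch keys off a single module flag; there is no subtarget feature or command-line override in this diff. A minimal sketch, assuming a front end that wants to opt a module in (the flag name matches the check above; the merge behavior and value are illustrative, since the asm printer only tests for the flag's presence):

#include "llvm/IR/Module.h"

// Mark the module so the X86 backend emits import call optimization
// metadata. Any value works: the backend checks only that the flag exists.
void enableImportCallOptimization(llvm::Module &M) {
  M.addModuleFlag(llvm::Module::Warning, "import-call-optimization", 1);
}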

@@ -1016,6 +1020,35 @@ void X86AsmPrinter::emitEndOfAsmFile(Module &M) {
// safe to set.
OutStreamer->emitSubsectionsViaSymbols();
} else if (TT.isOSBinFormatCOFF()) {
// If import call optimization is enabled, emit the appropriate section.
// We do this whether or not we recorded any items.
if (EnableImportCallOptimization) {
OutStreamer->switchSection(getObjFileLowering().getImportCallSection());

// Section always starts with some magic.
constexpr char ImpCallMagic[12] = "RetpolineV1";
OutStreamer->emitBytes(StringRef{ImpCallMagic, sizeof(ImpCallMagic)});

// Layout of this section is:
// Per section that contains an item to record:
// uint32_t SectionSize: Size in bytes for information in this section.
// uint32_t SectionNumber: the COFF section number the entries refer to.
// Per call to imported function in section:
// uint32_t Kind: the kind of item.
// uint32_t InstOffset: the offset of the instr in its parent section.
for (auto &[Section, CallsToImportedFuncs] :
SectionToImportedFunctionCalls) {
unsigned SectionSize =
sizeof(uint32_t) * (2 + 2 * CallsToImportedFuncs.size());
OutStreamer->emitInt32(SectionSize);
OutStreamer->emitCOFFSecNumber(Section->getBeginSymbol());
for (auto &[CallsiteSymbol, Kind] : CallsToImportedFuncs) {
OutStreamer->emitInt32(Kind);
OutStreamer->emitCOFFSecOffset(CallsiteSymbol);
}
}
}

if (usesMSVCFloatingPoint(TT, M)) {
// In Windows' libcmt.lib, there is a file which is linked in only if the
// symbol _fltused is referenced. Linking this in causes some
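The layout comment in emitEndOfAsmFile can be made concrete. Modeled as packed records (struct names are mine, not part of the patch), each per-section chunk that follows the 12-byte magic ("RetpolineV1" is 11 characters plus its NUL terminator) looks like the sketch below, and the SectionSize expression above follows directly from it.

#include <cstdint>

// Hedged model of one per-section chunk in .retplne.
struct ImpCallEntry {
  uint32_t Kind;       // an ImportCallKind value
  uint32_t InstOffset; // offset of the marked instruction in its section
};

struct ImpCallChunk {
  uint32_t SectionSize;   // bytes in this chunk: 4 * (2 + 2 * NumEntries)
  uint32_t SectionNumber; // COFF section number the entries belong to
  // ImpCallEntry Entries[NumEntries]; // records follow immediately
};

// Example: a section with 3 recorded calls gives
// SectionSize = 4 * (2 + 2 * 3) = 32 bytes.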
28 changes: 27 additions & 1 deletion llvm/lib/Target/X86/X86AsmPrinter.h
@@ -35,6 +35,26 @@ class LLVM_LIBRARY_VISIBILITY X86AsmPrinter : public AsmPrinter {
bool EmitFPOData = false;
bool ShouldEmitWeakSwiftAsyncExtendedFramePointerFlags = false;
bool IndCSPrefix = false;
bool EnableImportCallOptimization = false;

enum ImportCallKind : unsigned {
IMAGE_RETPOLINE_AMD64_IMPORT_BR = 0x02,
IMAGE_RETPOLINE_AMD64_IMPORT_CALL = 0x03,
IMAGE_RETPOLINE_AMD64_INDIR_BR = 0x04,
IMAGE_RETPOLINE_AMD64_INDIR_CALL = 0x05,
IMAGE_RETPOLINE_AMD64_INDIR_BR_REX = 0x06,
IMAGE_RETPOLINE_AMD64_CFG_BR = 0x08,
IMAGE_RETPOLINE_AMD64_CFG_CALL = 0x09,
IMAGE_RETPOLINE_AMD64_CFG_BR_REX = 0x0A,
IMAGE_RETPOLINE_AMD64_SWITCHTABLE_FIRST = 0x010,
IMAGE_RETPOLINE_AMD64_SWITCHTABLE_LAST = 0x01F,
};
struct ImportCallInfo {
MCSymbol *CalleeSymbol;
ImportCallKind Kind;
};
DenseMap<MCSection *, std::vector<ImportCallInfo>>
SectionToImportedFunctionCalls;

// This utility class tracks the length of a stackmap instruction's 'shadow'.
// It is used by the X86AsmPrinter to ensure that the stackmap shadow
@@ -49,7 +69,7 @@ class LLVM_LIBRARY_VISIBILITY X86AsmPrinter : public AsmPrinter {
void startFunction(MachineFunction &MF) {
this->MF = &MF;
}
void count(MCInst &Inst, const MCSubtargetInfo &STI,
void count(const MCInst &Inst, const MCSubtargetInfo &STI,
MCCodeEmitter *CodeEmitter);

// Called to signal the start of a shadow of RequiredSize bytes.
@@ -130,6 +150,12 @@ class LLVM_LIBRARY_VISIBILITY X86AsmPrinter : public AsmPrinter {
void emitMachOIFuncStubHelperBody(Module &M, const GlobalIFunc &GI,
MCSymbol *LazyPointer) override;

void emitCallInstruction(const llvm::MCInst &MCI);

// Emits a label to mark the next instruction as being relevant to Import Call
// Optimization.
void emitLabelAndRecordForImportCallOptimization(ImportCallKind Kind);

public:
X86AsmPrinter(TargetMachine &TM, std::unique_ptr<MCStreamer> Streamer);

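emitLabelAndRecordForImportCallOptimization is only declared here; its definition lives elsewhere in the PR (X86MCInstLower.cpp, not part of this excerpt). Based on the data structures declared above, a recorder of this shape would suffice; treat the body as a sketch, not the PR's exact code:

void X86AsmPrinter::emitLabelAndRecordForImportCallOptimization(
    ImportCallKind Kind) {
  // Drop a temporary label on the upcoming instruction so its
  // section-relative offset can be emitted in emitEndOfAsmFile, then
  // remember the (symbol, kind) pair under the current section.
  MCSymbol *CallSiteSymbol = OutContext.createTempSymbol("impcall");
  OutStreamer->emitLabel(CallSiteSymbol);
  SectionToImportedFunctionCalls[OutStreamer->getCurrentSectionOnly()]
      .push_back({CallSiteSymbol, Kind});
}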
8 changes: 7 additions & 1 deletion llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -274,6 +274,7 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
case X86::TCRETURNdi64:
case X86::TCRETURNdi64cc:
case X86::TCRETURNri64:
case X86::TCRETURNri64_ImpCall:
case X86::TCRETURNmi64: {
bool isMem = Opcode == X86::TCRETURNmi || Opcode == X86::TCRETURNmi64;
MachineOperand &JumpTarget = MBBI->getOperand(0);
@@ -345,12 +346,14 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
MachineInstrBuilder MIB = BuildMI(MBB, MBBI, DL, TII->get(Op));
for (unsigned i = 0; i != X86::AddrNumOperands; ++i)
MIB.add(MBBI->getOperand(i));
} else if (Opcode == X86::TCRETURNri64) {
} else if ((Opcode == X86::TCRETURNri64) ||
(Opcode == X86::TCRETURNri64_ImpCall)) {
JumpTarget.setIsKill();
BuildMI(MBB, MBBI, DL,
TII->get(IsWin64 ? X86::TAILJMPr64_REX : X86::TAILJMPr64))
.add(JumpTarget);
} else {
assert(!IsWin64 && "Win64 requires REX for indirect jumps.");
JumpTarget.setIsKill();
BuildMI(MBB, MBBI, DL, TII->get(X86::TAILJMPr))
.add(JumpTarget);
@@ -875,6 +878,9 @@ bool X86ExpandPseudo::expandMI(MachineBasicBlock &MBB,
case X86::CALL64m_RVMARKER:
expandCALL_RVMARKER(MBB, MBBI);
return true;
case X86::CALL64r_ImpCall:
MI.setDesc(TII->get(X86::CALL64r));
return true;
case X86::ADD32mi_ND:
case X86::ADD64mi32_ND:
case X86::SUB32mi_ND:
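Both new pseudos collapse to ordinary instructions at expansion time; the import-call variants presumably exist only so earlier phases can keep these call sites distinguishable (and register-constrained) until they are recorded. Summarized from the hunks above:

// Pseudo                  Expansion (per X86ExpandPseudo::expandMI above)
// CALL64r_ImpCall      -> CALL64r (descriptor swap only)
// TCRETURNri64_ImpCall -> TAILJMPr64_REX on Win64, else TAILJMPr64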
6 changes: 6 additions & 0 deletions llvm/lib/Target/X86/X86FastISel.cpp
@@ -34,6 +34,7 @@
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/IntrinsicsX86.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Operator.h"
#include "llvm/MC/MCAsmInfo.h"
#include "llvm/MC/MCSymbol.h"
@@ -3316,6 +3317,11 @@ bool X86FastISel::fastLowerCall(CallLoweringInfo &CLI) {
if (Flag.isSwiftError() || Flag.isPreallocated())
return false;

// Can't handle import call optimization.
if (Is64Bit &&
MF->getFunction().getParent()->getModuleFlag("import-call-optimization"))
return false;

SmallVector<MVT, 16> OutVTs;
SmallVector<Register, 16> ArgRegs;

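Returning false from fastLowerCall makes FastISel give up on the call and fall back to SelectionDAG, which is the only path taught to produce X86ISD::IMP_CALL. The gating condition, pulled out into a helper for readability (the helper name is mine, not the PR's):

#include "llvm/CodeGen/MachineFunction.h"
#include "llvm/IR/Module.h"

// True when the parent module opted in to import call optimization; the
// flag's presence is the only signal the backend consults.
static bool usesImportCallOptimization(const llvm::MachineFunction &MF) {
  return MF.getFunction().getParent()->getModuleFlag(
             "import-call-optimization") != nullptr;
}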
3 changes: 2 additions & 1 deletion llvm/lib/Target/X86/X86FrameLowering.cpp
@@ -2399,7 +2399,8 @@ X86FrameLowering::getWinEHFuncletFrameSize(const MachineFunction &MF) const {
static bool isTailCallOpcode(unsigned Opc) {
return Opc == X86::TCRETURNri || Opc == X86::TCRETURNdi ||
Opc == X86::TCRETURNmi || Opc == X86::TCRETURNri64 ||
Opc == X86::TCRETURNdi64 || Opc == X86::TCRETURNmi64;
Opc == X86::TCRETURNri64_ImpCall || Opc == X86::TCRETURNdi64 ||
Opc == X86::TCRETURNmi64;
}

void X86FrameLowering::emitEpilogue(MachineFunction &MF,
19 changes: 16 additions & 3 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -19179,7 +19179,7 @@ SDValue X86TargetLowering::LowerJumpTable(SDValue Op, SelectionDAG &DAG) const {

SDValue X86TargetLowering::LowerExternalSymbol(SDValue Op,
SelectionDAG &DAG) const {
return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false);
return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false, nullptr);
}

SDValue
@@ -19207,7 +19208,8 @@ X86TargetLowering::LowerBlockAddress(SDValue Op, SelectionDAG &DAG) const {
/// Creates target global address or external symbol nodes for calls or
/// other uses.
SDValue X86TargetLowering::LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,
bool ForCall) const {
bool ForCall,
bool *IsImpCall) const {
// Unpack the global address or external symbol.
SDLoc dl(Op);
const GlobalValue *GV = nullptr;
@@ -19257,6 +19258,16 @@ SDValue X86TargetLowering::LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,
if (ForCall && !NeedsLoad && !HasPICReg && Offset == 0)
return Result;

// If Import Call Optimization is enabled and this is an imported function
// then make a note of it and return the global address without wrapping.
if (IsImpCall && (OpFlags == X86II::MO_DLLIMPORT) &&
Mod.getModuleFlag("import-call-optimization")) {
assert(ForCall && "Should only enable import call optimization if we are "
"lowering a call");
*IsImpCall = true;
return Result;
}

Result = DAG.getNode(getGlobalWrapperKind(GV, OpFlags), dl, PtrVT, Result);

// With PIC, the address is actually $g + Offset.
@@ -19282,7 +19293,7 @@ SDValue X86TargetLowering::LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,

SDValue
X86TargetLowering::LowerGlobalAddress(SDValue Op, SelectionDAG &DAG) const {
return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false);
return LowerGlobalOrExternal(Op, DAG, /*ForCall=*/false, nullptr);
}

static SDValue GetTLSADDR(SelectionDAG &DAG, GlobalAddressSDNode *GA,
@@ -34821,6 +34832,7 @@ const char *X86TargetLowering::getTargetNodeName(unsigned Opcode) const {
NODE_NAME_CASE(FST)
NODE_NAME_CASE(CALL)
NODE_NAME_CASE(CALL_RVMARKER)
NODE_NAME_CASE(IMP_CALL)
NODE_NAME_CASE(BT)
NODE_NAME_CASE(CMP)
NODE_NAME_CASE(FCMP)
@@ -62092,6 +62104,7 @@ X86TargetLowering::EmitKCFICheck(MachineBasicBlock &MBB,
Register TargetReg;
switch (MBBI->getOpcode()) {
case X86::CALL64r:
case X86::CALL64r_ImpCall:
case X86::CALL64r_NT:
case X86::TAILJMPr64:
case X86::TAILJMPr64_REX:
8 changes: 6 additions & 2 deletions llvm/lib/Target/X86/X86ISelLowering.h
@@ -90,6 +90,10 @@ namespace llvm {
/// POP_FROM_X87_REG (which may remove a required FPU stack pop).
POP_FROM_X87_REG,

// Pseudo for a call to an imported function to ensure the correct machine
// instruction is emitted for Import Call Optimization.
IMP_CALL,

/// X86 compare and logical compare instructions.
CMP,
FCMP,
@@ -1746,8 +1750,8 @@ namespace llvm {

/// Creates target global address or external symbol nodes for calls or
/// other uses.
SDValue LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG,
bool ForCall) const;
SDValue LowerGlobalOrExternal(SDValue Op, SelectionDAG &DAG, bool ForCall,
bool *IsImpCall) const;

SDValue LowerSINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
SDValue LowerUINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
13 changes: 11 additions & 2 deletions llvm/lib/Target/X86/X86ISelLoweringCall.cpp
@@ -2050,6 +2050,12 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
if (CallConv == CallingConv::X86_INTR)
report_fatal_error("X86 interrupts may not be called directly");

if (IsIndirectCall && !IsWin64 &&
M->getModuleFlag("import-call-optimization"))
errorUnsupported(DAG, dl,
"Indirect calls must have a normal calling convention if "
"Import Call Optimization is enabled");

// Analyze operands of the call, assigning locations to each operand.
SmallVector<CCValAssign, 16> ArgLocs;
CCState CCInfo(CallConv, isVarArg, MF, ArgLocs, *DAG.getContext());
@@ -2421,6 +2427,7 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
InGlue = Chain.getValue(1);
}

bool IsImpCall = false;
if (DAG.getTarget().getCodeModel() == CodeModel::Large) {
assert(Is64Bit && "Large code model is only legal in 64-bit mode.");
// In the 64-bit large code model, we have to make all calls
@@ -2433,7 +2440,7 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// ForCall to true here has the effect of removing WrapperRIP when possible
// to allow direct calls to be selected without first materializing the
// address into a register.
Callee = LowerGlobalOrExternal(Callee, DAG, /*ForCall=*/true);
Callee = LowerGlobalOrExternal(Callee, DAG, /*ForCall=*/true, &IsImpCall);
} else if (Subtarget.isTarget64BitILP32() &&
Callee.getValueType() == MVT::i32) {
// Zero-extend the 32-bit Callee address into a 64-bit according to x32 ABI
@@ -2555,7 +2562,9 @@ X86TargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,

// Returns a chain & a glue for retval copy to use.
SDVTList NodeTys = DAG.getVTList(MVT::Other, MVT::Glue);
if (IsNoTrackIndirectCall) {
if (IsImpCall) {
Chain = DAG.getNode(X86ISD::IMP_CALL, dl, NodeTys, Ops);
} else if (IsNoTrackIndirectCall) {
Chain = DAG.getNode(X86ISD::NT_CALL, dl, NodeTys, Ops);
} else if (CLI.CB && objcarc::hasAttachedCallOpBundle(CLI.CB)) {
// Calls with a "clang.arc.attachedcall" bundle are special. They should be
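Read together with the ISelLowering hunks, the direct-call path for an imported function is now a straight line through the pieces this patch adds; a schematic (comments only, and the final loader-rewrite step is my reading of the .retplne comment in MCObjectFileInfo.cpp, not something this diff states):

// call to a dllimport function, with "import-call-optimization" set:
//   LowerCall
//     -> LowerGlobalOrExternal(..., ForCall=true, &IsImpCall) // sets IsImpCall
//     -> X86ISD::IMP_CALL node
//     -> X86imp_call pattern (X86InstrCompiler.td, next file)
//     -> CALL64pcrel32, a plain direct call whose site the asm printer
//        labels and records in .retplne so the loader can rewrite it.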
8 changes: 7 additions & 1 deletion llvm/lib/Target/X86/X86InstrCompiler.td
@@ -1313,6 +1313,8 @@ def : Pat<(X86call_rvmarker (i64 tglobaladdr:$rvfunc), (i64 texternalsym:$dst)),
def : Pat<(X86call_rvmarker (i64 tglobaladdr:$rvfunc), (i64 tglobaladdr:$dst)),
(CALL64pcrel32_RVMARKER tglobaladdr:$rvfunc, tglobaladdr:$dst)>;

def : Pat<(X86imp_call (i64 tglobaladdr:$dst)),
(CALL64pcrel32 tglobaladdr:$dst)>;

// Tailcall stuff. The TCRETURN instructions execute after the epilog, so they
// can never use callee-saved registers. That is the purpose of the GR64_TC
@@ -1344,7 +1346,11 @@ def : Pat<(X86tcret (i32 texternalsym:$dst), timm:$off),

def : Pat<(X86tcret ptr_rc_tailcall:$dst, timm:$off),
(TCRETURNri64 ptr_rc_tailcall:$dst, timm:$off)>,
Requires<[In64BitMode, NotUseIndirectThunkCalls]>;
Requires<[In64BitMode, NotUseIndirectThunkCalls, ImportCallOptimizationDisabled]>;

def : Pat<(X86tcret ptr_rc_tailcall:$dst, timm:$off),
(TCRETURNri64_ImpCall ptr_rc_tailcall:$dst, timm:$off)>,
Requires<[In64BitMode, NotUseIndirectThunkCalls, ImportCallOptimizationEnabled]>;

// Don't fold loads into X86tcret requiring more than 6 regs.
// There wouldn't be enough scratch registers for base+index.
10 changes: 9 additions & 1 deletion llvm/lib/Target/X86/X86InstrControl.td
@@ -327,7 +327,7 @@ let isCall = 1, Uses = [RSP, SSP], SchedRW = [WriteJump] in {
Requires<[In64BitMode]>;
def CALL64r : I<0xFF, MRM2r, (outs), (ins GR64:$dst),
"call{q}\t{*}$dst", [(X86call GR64:$dst)]>,
Requires<[In64BitMode,NotUseIndirectThunkCalls]>;
Requires<[In64BitMode,NotUseIndirectThunkCalls,ImportCallOptimizationDisabled]>;
def CALL64m : I<0xFF, MRM2m, (outs), (ins i64mem:$dst),
"call{q}\t{*}$dst", [(X86call (loadi64 addr:$dst))]>,
Requires<[In64BitMode,FavorMemIndirectCall,
@@ -357,6 +357,10 @@ let isCall = 1, isTerminator = 1, isReturn = 1, isBarrier = 1,
def TCRETURNri64 : PseudoI<(outs),
(ins ptr_rc_tailcall:$dst, i32imm:$offset),
[]>, Sched<[WriteJump]>;
def TCRETURNri64_ImpCall : PseudoI<(outs),
(ins GR64_A:$dst, i32imm:$offset),
[]>, Sched<[WriteJump]>;

let mayLoad = 1 in
def TCRETURNmi64 : PseudoI<(outs),
(ins i64mem_TC:$dst, i32imm:$offset),
@@ -418,6 +422,10 @@ let isPseudo = 1, isCall = 1, isCodeGenOnly = 1,
def CALL64pcrel32_RVMARKER :
PseudoI<(outs), (ins i64imm:$rvfunc, i64i32imm_brtarget:$dst), []>,
Requires<[In64BitMode]>;

def CALL64r_ImpCall :
PseudoI<(outs), (ins GR64_A:$dst), [(X86call GR64_A:$dst)]>,
Requires<[In64BitMode,NotUseIndirectThunkCalls,ImportCallOptimizationEnabled]>;
}

// Conditional tail calls are similar to the above, but they are branches
3 changes: 3 additions & 0 deletions llvm/lib/Target/X86/X86InstrFragments.td
@@ -210,6 +210,9 @@ def X86call_rvmarker : SDNode<"X86ISD::CALL_RVMARKER", SDT_X86Call,
[SDNPHasChain, SDNPOutGlue, SDNPOptInGlue,
SDNPVariadic]>;

def X86imp_call : SDNode<"X86ISD::IMP_CALL", SDT_X86Call,
[SDNPHasChain, SDNPOutGlue, SDNPOptInGlue,
SDNPVariadic]>;

def X86NoTrackCall : SDNode<"X86ISD::NT_CALL", SDT_X86Call,
[SDNPHasChain, SDNPOutGlue, SDNPOptInGlue,