Skip to content

[AArch64][PAC] Lower authenticated calls with ptrauth bundles. #85736

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,11 @@ class CallLowering {
ArgInfo() = default;
};

/// Pointer-authentication data for a call, as carried by its "ptrauth"
/// operand bundle.
struct PtrAuthInfo {
/// Key operand of the bundle, stored as a zero-extended integer.
uint64_t Key;
/// Virtual register holding the bundle's discriminator operand.
Register Discriminator;
};

struct CallLoweringInfo {
/// Calling convention to be used for the call.
CallingConv::ID CallConv = CallingConv::C;
Expand All @@ -125,6 +130,9 @@ class CallLowering {

MDNode *KnownCallees = nullptr;

/// The auth-call information in the "ptrauth" bundle, if present.
std::optional<PtrAuthInfo> PAI;

/// True if the call must be tail call optimized.
bool IsMustTailCall = false;

Expand Down Expand Up @@ -587,7 +595,7 @@ class CallLowering {
bool lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &Call,
ArrayRef<Register> ResRegs,
ArrayRef<ArrayRef<Register>> ArgRegs, Register SwiftErrorVReg,
Register ConvergenceCtrlToken,
std::optional<PtrAuthInfo> PAI, Register ConvergenceCtrlToken,
std::function<unsigned()> GetCalleeReg) const;

/// For targets which want to use big-endian can enable it with
Expand Down
18 changes: 18 additions & 0 deletions llvm/include/llvm/CodeGen/TargetLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -4361,6 +4361,9 @@ class TargetLowering : public TargetLoweringBase {
/// Return true if the target supports kcfi operand bundles.
virtual bool supportKCFIBundles() const { return false; }

/// Return true if the target supports ptrauth operand bundles.
/// The default is false; SelectionDAGBuilder::LowerCallTo reports a fatal
/// error when asked to lower a call with ptrauth info on a target that
/// returns false here.
virtual bool supportPtrAuthBundles() const { return false; }

/// Perform necessary initialization to handle a subset of CSRs explicitly
/// via copies. This function is called at the beginning of instruction
/// selection.
Expand Down Expand Up @@ -4472,6 +4475,14 @@ class TargetLowering : public TargetLoweringBase {
llvm_unreachable("Not Implemented");
}

/// This structure contains the information necessary for lowering
/// pointer-authenticating indirect calls. It is equivalent to the "ptrauth"
/// operand bundle found on the call instruction, if any.
struct PtrAuthInfo {
/// Key operand of the bundle (a constant integer), zero-extended.
uint64_t Key;
/// DAG value of the bundle's discriminator operand.
SDValue Discriminator;
};

/// This structure contains all information that is necessary for lowering
/// calls. It is passed to TLI::LowerCallTo when the SelectionDAG builder
/// needs to lower a call, and targets will see this struct in their LowerCall
Expand Down Expand Up @@ -4511,6 +4522,8 @@ class TargetLowering : public TargetLoweringBase {
const ConstantInt *CFIType = nullptr;
SDValue ConvergenceControlToken;

std::optional<PtrAuthInfo> PAI;

CallLoweringInfo(SelectionDAG &DAG)
: RetSExt(false), RetZExt(false), IsVarArg(false), IsInReg(false),
DoesNotReturn(false), IsReturnValueUsed(true), IsConvergent(false),
Expand Down Expand Up @@ -4633,6 +4646,11 @@ class TargetLowering : public TargetLoweringBase {
return *this;
}

/// Record the pointer-authentication info for this call. Returns *this so
/// that setters can be chained.
CallLoweringInfo &setPtrAuth(PtrAuthInfo Value) {
  PAI.emplace(Value);
  return *this;
}

CallLoweringInfo &setIsPostTypeLegalization(bool Value=true) {
IsPostTypeLegalization = Value;
return *this;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/CallLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
ArrayRef<Register> ResRegs,
ArrayRef<ArrayRef<Register>> ArgRegs,
Register SwiftErrorVReg,
std::optional<PtrAuthInfo> PAI,
Register ConvergenceCtrlToken,
std::function<unsigned()> GetCalleeReg) const {
CallLoweringInfo Info;
Expand Down Expand Up @@ -188,6 +189,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees);
Info.CallConv = CallConv;
Info.SwiftErrorVReg = SwiftErrorVReg;
Info.PAI = PAI;
Info.ConvergenceCtrlToken = ConvergenceCtrlToken;
Info.IsMustTailCall = CB.isMustTailCall();
Info.IsTailCall = CanBeTailCalled;
Expand Down
16 changes: 15 additions & 1 deletion llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2642,6 +2642,20 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
}
}

// Collect the pointer-authentication info from the "ptrauth" operand bundle,
// if the call has one. The bundle is looked up once, by bundle ID, instead of
// being counted by ID and then fetched again by name.
std::optional<CallLowering::PtrAuthInfo> PAI;
if (auto PAB = CB.getOperandBundle(LLVMContext::OB_ptrauth)) {
  // Functions should never be ptrauth-called directly.
  assert(!CB.getCalledFunction() && "invalid direct ptrauth call");

  // Bundle inputs are [ key, discriminator ]. The key must be a constant
  // integer (cast<> asserts this); the discriminator is an arbitrary value
  // and is therefore materialised into a virtual register.
  const auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
  const Value *Discriminator = PAB->Inputs[1];

  PAI = CallLowering::PtrAuthInfo{Key->getZExtValue(),
                                  getOrCreateVReg(*Discriminator)};
}

Register ConvergenceCtrlToken = 0;
if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_convergencectrl)) {
const auto &Token = *Bundle->Inputs[0].get();
Expand All @@ -2652,7 +2666,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
// optimize into tail calls. Instead, we defer that to selection where a final
// scan is done to check if any instructions are calls.
bool Success = CLI->lowerCall(
MIRBuilder, CB, Res, Args, SwiftErrorVReg, ConvergenceCtrlToken,
MIRBuilder, CB, Res, Args, SwiftErrorVReg, PAI, ConvergenceCtrlToken,
[&]() { return getOrCreateVReg(*CB.getCalledOperand()); });

// Check if we just inserted a tail call.
Expand Down
51 changes: 46 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3307,12 +3307,12 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) {
const BasicBlock *EHPadBB = I.getSuccessor(1);
MachineBasicBlock *EHPadMBB = FuncInfo.MBBMap[EHPadBB];

// Deopt bundles are lowered in LowerCallSiteWithDeoptBundle, and we don't
// Deopt and ptrauth bundles are lowered in helper functions, and we don't
// have to do anything here to lower funclet bundles.
assert(!I.hasOperandBundlesOtherThan(
{LLVMContext::OB_deopt, LLVMContext::OB_gc_transition,
LLVMContext::OB_gc_live, LLVMContext::OB_funclet,
LLVMContext::OB_cfguardtarget,
LLVMContext::OB_cfguardtarget, LLVMContext::OB_ptrauth,
LLVMContext::OB_clang_arc_attachedcall}) &&
"Cannot lower invokes with arbitrary operand bundles yet!");

Expand Down Expand Up @@ -3363,6 +3363,8 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) {
// intrinsic, and right now there are no plans to support other intrinsics
// with deopt state.
LowerCallSiteWithDeoptBundle(&I, getValue(Callee), EHPadBB);
} else if (I.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
LowerCallSiteWithPtrAuthBundle(cast<CallBase>(I), EHPadBB);
} else {
LowerCallTo(I, getValue(Callee), false, false, EHPadBB);
}
Expand Down Expand Up @@ -8531,9 +8533,9 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI,
}

void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
bool isTailCall,
bool isMustTailCall,
const BasicBlock *EHPadBB) {
bool isTailCall, bool isMustTailCall,
const BasicBlock *EHPadBB,
const TargetLowering::PtrAuthInfo *PAI) {
auto &DL = DAG.getDataLayout();
FunctionType *FTy = CB.getFunctionType();
Type *RetTy = CB.getType();
Expand Down Expand Up @@ -8640,6 +8642,15 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
CB.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0)
.setCFIType(CFIType)
.setConvergenceControlToken(ConvControlToken);

// Set the pointer authentication info if we have it.
if (PAI) {
if (!TLI.supportPtrAuthBundles())
report_fatal_error(
"This target doesn't support calls with ptrauth operand bundles.");
CLI.setPtrAuth(*PAI);
}

std::pair<SDValue, SDValue> Result = lowerInvokable(CLI, EHPadBB);

if (Result.first.getNode()) {
Expand Down Expand Up @@ -9185,6 +9196,11 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
}
}

if (I.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
LowerCallSiteWithPtrAuthBundle(cast<CallBase>(I), /*EHPadBB=*/nullptr);
return;
}

// Deopt bundles are lowered in LowerCallSiteWithDeoptBundle, and we don't
// have to do anything here to lower funclet bundles.
// CFGuardTarget bundles are lowered in LowerCallTo.
Expand All @@ -9206,6 +9222,31 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
LowerCallTo(I, Callee, I.isTailCall(), I.isMustTailCall());
}

void SelectionDAGBuilder::LowerCallSiteWithPtrAuthBundle(
const CallBase &CB, const BasicBlock *EHPadBB) {
auto PAB = CB.getOperandBundle("ptrauth");
const Value *CalleeV = CB.getCalledOperand();

// Gather the call ptrauth data from the operand bundle:
// [ i32 <key>, i64 <discriminator> ]
const auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
const Value *Discriminator = PAB->Inputs[1];

assert(Key->getType()->isIntegerTy(32) && "Invalid ptrauth key");
assert(Discriminator->getType()->isIntegerTy(64) &&
"Invalid ptrauth discriminator");

// Functions should never be ptrauth-called directly.
assert(!isa<Function>(CalleeV) && "invalid direct ptrauth call");

// Otherwise, do an authenticated indirect call.
TargetLowering::PtrAuthInfo PAI = {Key->getZExtValue(),
getValue(Discriminator)};

LowerCallTo(CB, getValue(CalleeV), CB.isTailCall(), CB.isMustTailCall(),
EHPadBB, &PAI);
}

namespace {

/// AsmOperandInfo - This contains information for each constraint that we are
Expand Down
6 changes: 5 additions & 1 deletion llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,8 @@ class SelectionDAGBuilder {
void CopyToExportRegsIfNeeded(const Value *V);
void ExportFromCurrentBlock(const Value *V);
void LowerCallTo(const CallBase &CB, SDValue Callee, bool IsTailCall,
bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr);
bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr,
const TargetLowering::PtrAuthInfo *PAI = nullptr);

// Lower range metadata from 0 to N to assert zext to an integer of nearest
// floor power of two.
Expand Down Expand Up @@ -490,6 +491,9 @@ class SelectionDAGBuilder {
bool VarArgDisallowed,
bool ForceVoidReturnTy);

/// Lower a call site carrying a "ptrauth" operand bundle as an authenticated
/// indirect call. \p EHPadBB is passed through to LowerCallTo.
void LowerCallSiteWithPtrAuthBundle(const CallBase &CB,
const BasicBlock *EHPadBB);

/// Returns the type of FrameIndex and TargetFrameIndex nodes.
MVT getFrameIndexTy() {
return DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout());
Expand Down
132 changes: 132 additions & 0 deletions llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,12 @@ class AArch64AsmPrinter : public AsmPrinter {

void emitSled(const MachineInstr &MI, SledKind Kind);

// Emit the sequence for BLRA (authenticate + branch).
void emitPtrauthBranch(const MachineInstr *MI);
// Emit the sequence to compute a discriminator into x17, or reuse AddrDisc.
unsigned emitPtrauthDiscriminator(uint16_t Disc, unsigned AddrDisc,
unsigned &InstsEmitted);

/// tblgen'erated driver function for lowering simple MI->MC
/// pseudo instructions.
bool emitPseudoExpansionLowering(MCStreamer &OutStreamer,
Expand Down Expand Up @@ -1504,6 +1510,78 @@ void AArch64AsmPrinter::emitFMov0(const MachineInstr &MI) {
}
}

/// Materialise the discriminator for an authenticated branch and return the
/// register that holds it.
///
/// \p Disc is the constant discriminator and \p AddrDisc the address
/// discriminator register (NoRegister when absent). \p InstsEmitted is
/// incremented once for every instruction streamed out.
///
/// \returns XZR when there is no discriminator at all, the address
/// discriminator register when there is no constant part, and x17 otherwise.
unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
                                                     unsigned AddrDisc,
                                                     unsigned &InstsEmitted) {
  // Pseudos spell "no address discriminator" as NoRegister; from this point
  // on a real register encoding is required, so XZR stands in for it.
  const unsigned AddrDiscReg =
      AddrDisc == AArch64::NoRegister ? AArch64::XZR : AddrDisc;

  // No constant discriminator means no blend: the address discriminator
  // register is used as-is, whether it is XZR or not.
  if (Disc == 0)
    return AddrDiscReg;

  if (AddrDiscReg != AArch64::XZR) {
    // Both parts are present: blend them into x17 by copying the address
    // discriminator (ORR with XZR) and then inserting the constant with a
    // MOVK at shift 48.
    EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ORRXrs)
                                     .addReg(AArch64::X17)
                                     .addReg(AArch64::XZR)
                                     .addReg(AddrDiscReg)
                                     .addImm(0));
    ++InstsEmitted;
    EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVKXi)
                                     .addReg(AArch64::X17)
                                     .addReg(AArch64::X17)
                                     .addImm(Disc)
                                     .addImm(/*shift=*/48));
    ++InstsEmitted;
  } else {
    // Only the constant is present: MOVZ it into x17.
    EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVZXi)
                                     .addReg(AArch64::X17)
                                     .addImm(Disc)
                                     .addImm(/*shift=*/0));
    ++InstsEmitted;
  }

  return AArch64::X17;
}

void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
unsigned InstsEmitted = 0;
unsigned BrTarget = MI->getOperand(0).getReg();

auto Key = (AArch64PACKey::ID)MI->getOperand(1).getImm();
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
"Invalid auth call key");

uint64_t Disc = MI->getOperand(2).getImm();
assert(isUInt<16>(Disc));

unsigned AddrDisc = MI->getOperand(3).getReg();

// Compute discriminator into x17
unsigned DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, InstsEmitted);
bool IsZeroDisc = DiscReg == AArch64::XZR;

unsigned Opc;
if (Key == AArch64PACKey::IA)
Opc = IsZeroDisc ? AArch64::BLRAAZ : AArch64::BLRAA;
else
Opc = IsZeroDisc ? AArch64::BLRABZ : AArch64::BLRAB;

MCInst BRInst;
BRInst.setOpcode(Opc);
BRInst.addOperand(MCOperand::createReg(BrTarget));
if (!IsZeroDisc)
BRInst.addOperand(MCOperand::createReg(DiscReg));
EmitToStreamer(*OutStreamer, BRInst);
++InstsEmitted;

assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
}

// Simple pseudo-instructions have their lowering (with expansion to real
// instructions) auto-generated.
#include "AArch64GenMCPseudoLowering.inc"
Expand Down Expand Up @@ -1639,9 +1717,63 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
return;
}

case AArch64::BLRA:
emitPtrauthBranch(MI);
return;

// Tail calls use pseudo instructions so they have the proper code-gen
// attributes (isCall, isReturn, etc.). We lower them to the real
// instruction here.
// Authenticated tail calls: compute the discriminator (if any) into a scratch
// register, then emit the matching BRAA/BRAB/BRAAZ/BRABZ.
// Operands read below: 0 = branch target, 2 = key, 3 = constant
// discriminator, 4 = address discriminator register.
case AArch64::AUTH_TCRETURN:
case AArch64::AUTH_TCRETURN_BTI: {
const uint64_t Key = MI->getOperand(2).getImm();
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
"Invalid auth key for tail-call return");

const uint64_t Disc = MI->getOperand(3).getImm();
assert(isUInt<16>(Disc) && "Integer discriminator is too wide");

Register AddrDisc = MI->getOperand(4).getReg();

// Pick whichever of x16/x17 is not the branch target as the scratch register.
Register ScratchReg = MI->getOperand(0).getReg() == AArch64::X16
? AArch64::X17
: AArch64::X16;

// With no constant discriminator, the address discriminator register (which
// may be NoRegister) is used unchanged.
unsigned DiscReg = AddrDisc;
if (Disc) {
if (AddrDisc != AArch64::NoRegister) {
// Both parts present: copy the address discriminator into the scratch
// register, then insert the constant with a MOVK at shift 48.
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ORRXrs)
.addReg(ScratchReg)
.addReg(AArch64::XZR)
.addReg(AddrDisc)
.addImm(0));
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVKXi)
.addReg(ScratchReg)
.addReg(ScratchReg)
.addImm(Disc)
.addImm(/*shift=*/48));
} else {
// Constant only: MOVZ it into the scratch register.
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVZXi)
.addReg(ScratchReg)
.addImm(Disc)
.addImm(/*shift=*/0));
}
DiscReg = ScratchReg;
}

// No discriminator at all selects the *Z opcode forms, which get no
// discriminator operand.
const bool IsZero = DiscReg == AArch64::NoRegister;
// Indexed as [Key][IsZero].
// NOTE(review): this relies on AArch64PACKey::IA == 0 and IB == 1, which are
// defined outside this view — confirm.
const unsigned Opcodes[2][2] = {{AArch64::BRAA, AArch64::BRAAZ},
{AArch64::BRAB, AArch64::BRABZ}};

MCInst TmpInst;
TmpInst.setOpcode(Opcodes[Key][IsZero]);
TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg()));
if (!IsZero)
TmpInst.addOperand(MCOperand::createReg(DiscReg));
EmitToStreamer(*OutStreamer, TmpInst);
return;
}

case AArch64::TCRETURNri:
case AArch64::TCRETURNrix16x17:
case AArch64::TCRETURNrix17:
Expand Down
Loading
Loading