Skip to content

Commit 93755da

Browse files
committed
[AArch64][PAC] Lower authenticated calls with ptrauth bundles.
This adds codegen support for the "ptrauth" operand bundles, which can be used to augment indirect calls with the equivalent of an `@llvm.ptrauth.auth` intrinsic call on the call target (possibly preceded by an `@llvm.ptrauth.blend` on the auth discriminator if applicable.) This allows the generation of combined authenticating calls on AArch64 (in the BLRA* PAuth instructions), while avoiding the raw just-authenticated function pointer from being exposed to attackers. This is done by threading a PtrAuthInfo descriptor through the call lowering infrastructure. Note that this also applies to the other forms of indirect calls, notably invokes, rvmarker, and tail calls. Tail-calls in particular bring some additional complexity, with the intersecting register constraints of BTI and PAC discriminator computation.
1 parent 05a7f0e commit 93755da

21 files changed

+1364
-58
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ class CallLowering {
9999
ArgInfo() = default;
100100
};
101101

102+
struct PointerAuthInfo {
103+
Register Discriminator;
104+
uint64_t Key;
105+
};
106+
102107
struct CallLoweringInfo {
103108
/// Calling convention to be used for the call.
104109
CallingConv::ID CallConv = CallingConv::C;
@@ -125,6 +130,8 @@ class CallLowering {
125130

126131
MDNode *KnownCallees = nullptr;
127132

133+
std::optional<PointerAuthInfo> PAI;
134+
128135
/// True if the call must be tail call optimized.
129136
bool IsMustTailCall = false;
130137

@@ -587,6 +594,7 @@ class CallLowering {
587594
bool lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &Call,
588595
ArrayRef<Register> ResRegs,
589596
ArrayRef<ArrayRef<Register>> ArgRegs, Register SwiftErrorVReg,
597+
std::optional<PointerAuthInfo> PAI,
590598
Register ConvergenceCtrlToken,
591599
std::function<unsigned()> GetCalleeReg) const;
592600

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4361,6 +4361,9 @@ class TargetLowering : public TargetLoweringBase {
43614361
/// Return true if the target supports kcfi operand bundles.
43624362
virtual bool supportKCFIBundles() const { return false; }
43634363

4364+
/// Return true if the target supports ptrauth operand bundles.
4365+
virtual bool supportPtrAuthBundles() const { return false; }
4366+
43644367
/// Perform necessary initialization to handle a subset of CSRs explicitly
43654368
/// via copies. This function is called at the beginning of instruction
43664369
/// selection.
@@ -4472,6 +4475,14 @@ class TargetLowering : public TargetLoweringBase {
44724475
llvm_unreachable("Not Implemented");
44734476
}
44744477

4478+
/// This structure contains the information necessary for lowering
4479+
/// pointer-authenticating indirect calls. It is equivalent to the "ptrauth"
4480+
/// operand bundle found on the call instruction, if any.
4481+
struct PtrAuthInfo {
4482+
uint64_t Key;
4483+
SDValue Discriminator;
4484+
};
4485+
44754486
/// This structure contains all information that is necessary for lowering
44764487
/// calls. It is passed to TLI::LowerCallTo when the SelectionDAG builder
44774488
/// needs to lower a call, and targets will see this struct in their LowerCall
@@ -4511,6 +4522,8 @@ class TargetLowering : public TargetLoweringBase {
45114522
const ConstantInt *CFIType = nullptr;
45124523
SDValue ConvergenceControlToken;
45134524

4525+
std::optional<PtrAuthInfo> PAI;
4526+
45144527
CallLoweringInfo(SelectionDAG &DAG)
45154528
: RetSExt(false), RetZExt(false), IsVarArg(false), IsInReg(false),
45164529
DoesNotReturn(false), IsReturnValueUsed(true), IsConvergent(false),
@@ -4633,6 +4646,11 @@ class TargetLowering : public TargetLoweringBase {
46334646
return *this;
46344647
}
46354648

4649+
CallLoweringInfo &setPtrAuth(PtrAuthInfo Value) {
4650+
PAI = Value;
4651+
return *this;
4652+
}
4653+
46364654
CallLoweringInfo &setIsPostTypeLegalization(bool Value=true) {
46374655
IsPostTypeLegalization = Value;
46384656
return *this;

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
9292
ArrayRef<Register> ResRegs,
9393
ArrayRef<ArrayRef<Register>> ArgRegs,
9494
Register SwiftErrorVReg,
95+
std::optional<PointerAuthInfo> PAI,
9596
Register ConvergenceCtrlToken,
9697
std::function<unsigned()> GetCalleeReg) const {
9798
CallLoweringInfo Info;
@@ -188,6 +189,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
188189
Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees);
189190
Info.CallConv = CallConv;
190191
Info.SwiftErrorVReg = SwiftErrorVReg;
192+
Info.PAI = PAI;
191193
Info.ConvergenceCtrlToken = ConvergenceCtrlToken;
192194
Info.IsMustTailCall = CB.isMustTailCall();
193195
Info.IsTailCall = CanBeTailCalled;

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2642,6 +2642,20 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
26422642
}
26432643
}
26442644

2645+
std::optional<CallLowering::PointerAuthInfo> PAI;
2646+
if (CB.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
2647+
// Functions should never be ptrauth-called directly.
2648+
assert(!CB.getCalledFunction() && "invalid direct ptrauth call");
2649+
2650+
auto PAB = CB.getOperandBundle("ptrauth");
2651+
Value *Key = PAB->Inputs[0];
2652+
Value *Discriminator = PAB->Inputs[1];
2653+
2654+
Register DiscReg = getOrCreateVReg(*Discriminator);
2655+
PAI = CallLowering::PointerAuthInfo{DiscReg,
2656+
cast<ConstantInt>(Key)->getZExtValue()};
2657+
}
2658+
26452659
Register ConvergenceCtrlToken = 0;
26462660
if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_convergencectrl)) {
26472661
const auto &Token = *Bundle->Inputs[0].get();
@@ -2652,7 +2666,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
26522666
// optimize into tail calls. Instead, we defer that to selection where a final
26532667
// scan is done to check if any instructions are calls.
26542668
bool Success = CLI->lowerCall(
2655-
MIRBuilder, CB, Res, Args, SwiftErrorVReg, ConvergenceCtrlToken,
2669+
MIRBuilder, CB, Res, Args, SwiftErrorVReg, PAI, ConvergenceCtrlToken,
26562670
[&]() { return getOrCreateVReg(*CB.getCalledOperand()); });
26572671

26582672
// Check if we just inserted a tail call.

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3307,12 +3307,12 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) {
33073307
const BasicBlock *EHPadBB = I.getSuccessor(1);
33083308
MachineBasicBlock *EHPadMBB = FuncInfo.MBBMap[EHPadBB];
33093309

3310-
// Deopt bundles are lowered in LowerCallSiteWithDeoptBundle, and we don't
3310+
// Deopt and ptrauth bundles are lowered in helper functions, and we don't
33113311
// have to do anything here to lower funclet bundles.
33123312
assert(!I.hasOperandBundlesOtherThan(
33133313
{LLVMContext::OB_deopt, LLVMContext::OB_gc_transition,
33143314
LLVMContext::OB_gc_live, LLVMContext::OB_funclet,
3315-
LLVMContext::OB_cfguardtarget,
3315+
LLVMContext::OB_cfguardtarget, LLVMContext::OB_ptrauth,
33163316
LLVMContext::OB_clang_arc_attachedcall}) &&
33173317
"Cannot lower invokes with arbitrary operand bundles yet!");
33183318

@@ -3363,6 +3363,8 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) {
33633363
// intrinsic, and right now there are no plans to support other intrinsics
33643364
// with deopt state.
33653365
LowerCallSiteWithDeoptBundle(&I, getValue(Callee), EHPadBB);
3366+
} else if (I.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
3367+
LowerCallSiteWithPtrAuthBundle(cast<CallBase>(I), EHPadBB);
33663368
} else {
33673369
LowerCallTo(I, getValue(Callee), false, false, EHPadBB);
33683370
}
@@ -8531,9 +8533,9 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI,
85318533
}
85328534

85338535
void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
8534-
bool isTailCall,
8535-
bool isMustTailCall,
8536-
const BasicBlock *EHPadBB) {
8536+
bool isTailCall, bool isMustTailCall,
8537+
const BasicBlock *EHPadBB,
8538+
const TargetLowering::PtrAuthInfo *PAI) {
85378539
auto &DL = DAG.getDataLayout();
85388540
FunctionType *FTy = CB.getFunctionType();
85398541
Type *RetTy = CB.getType();
@@ -8640,6 +8642,15 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
86408642
CB.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0)
86418643
.setCFIType(CFIType)
86428644
.setConvergenceControlToken(ConvControlToken);
8645+
8646+
// Set the pointer authentication info if we have it.
8647+
if (PAI) {
8648+
if (!TLI.supportPtrAuthBundles())
8649+
report_fatal_error(
8650+
"This target doesn't support calls with ptrauth operand bundles.");
8651+
CLI.setPtrAuth(*PAI);
8652+
}
8653+
86438654
std::pair<SDValue, SDValue> Result = lowerInvokable(CLI, EHPadBB);
86448655

86458656
if (Result.first.getNode()) {
@@ -9185,6 +9196,11 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
91859196
}
91869197
}
91879198

9199+
if (I.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
9200+
LowerCallSiteWithPtrAuthBundle(cast<CallBase>(I), /*EHPadBB=*/nullptr);
9201+
return;
9202+
}
9203+
91889204
// Deopt bundles are lowered in LowerCallSiteWithDeoptBundle, and we don't
91899205
// have to do anything here to lower funclet bundles.
91909206
// CFGuardTarget bundles are lowered in LowerCallTo.
@@ -9206,6 +9222,31 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
92069222
LowerCallTo(I, Callee, I.isTailCall(), I.isMustTailCall());
92079223
}
92089224

9225+
void SelectionDAGBuilder::LowerCallSiteWithPtrAuthBundle(
9226+
const CallBase &CB, const BasicBlock *EHPadBB) {
9227+
auto PAB = CB.getOperandBundle("ptrauth");
9228+
auto *CalleeV = CB.getCalledOperand();
9229+
9230+
// Gather the call ptrauth data from the operand bundle:
9231+
// [ i32 <key>, i64 <discriminator> ]
9232+
auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
9233+
Value *Discriminator = PAB->Inputs[1];
9234+
9235+
assert(Key->getType()->isIntegerTy(32) && "Invalid ptrauth key");
9236+
assert(Discriminator->getType()->isIntegerTy(64) &&
9237+
"Invalid ptrauth discriminator");
9238+
9239+
// Functions should never be ptrauth-called directly.
9240+
assert(!isa<Function>(CalleeV) && "invalid direct ptrauth call");
9241+
9242+
// Otherwise, do an authenticated indirect call.
9243+
TargetLowering::PtrAuthInfo PAI = {Key->getZExtValue(),
9244+
getValue(Discriminator)};
9245+
9246+
LowerCallTo(CB, getValue(CalleeV), CB.isTailCall(), CB.isMustTailCall(),
9247+
EHPadBB, &PAI);
9248+
}
9249+
92099250
namespace {
92109251

92119252
/// AsmOperandInfo - This contains information for each constraint that we are

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,8 @@ class SelectionDAGBuilder {
406406
void CopyToExportRegsIfNeeded(const Value *V);
407407
void ExportFromCurrentBlock(const Value *V);
408408
void LowerCallTo(const CallBase &CB, SDValue Callee, bool IsTailCall,
409-
bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr);
409+
bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr,
410+
const TargetLowering::PtrAuthInfo *PAI = nullptr);
410411

411412
// Lower range metadata from 0 to N to assert zext to an integer of nearest
412413
// floor power of two.
@@ -490,6 +491,9 @@ class SelectionDAGBuilder {
490491
bool VarArgDisallowed,
491492
bool ForceVoidReturnTy);
492493

494+
void LowerCallSiteWithPtrAuthBundle(const CallBase &CB,
495+
const BasicBlock *EHPadBB);
496+
493497
/// Returns the type of FrameIndex and TargetFrameIndex nodes.
494498
MVT getFrameIndexTy() {
495499
return DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout());

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ class AArch64AsmPrinter : public AsmPrinter {
125125

126126
void emitSled(const MachineInstr &MI, SledKind Kind);
127127

128+
// Emit the sequence for BLRA (authenticate + branch).
129+
void emitPtrauthBranch(const MachineInstr *MI);
130+
// Emit the sequence to compute a discriminator into x17, or reuse AddrDisc.
131+
unsigned emitPtrauthDiscriminator(uint16_t Disc, unsigned AddrDisc,
132+
unsigned &InstsEmitted);
133+
128134
/// tblgen'erated driver function for lowering simple MI->MC
129135
/// pseudo instructions.
130136
bool emitPseudoExpansionLowering(MCStreamer &OutStreamer,
@@ -1504,6 +1510,77 @@ void AArch64AsmPrinter::emitFMov0(const MachineInstr &MI) {
15041510
}
15051511
}
15061512

1513+
unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
1514+
unsigned AddrDisc,
1515+
unsigned &InstsEmitted) {
1516+
// So far we've used NoRegister in pseudos. Now we need real encodings.
1517+
if (AddrDisc == AArch64::NoRegister)
1518+
AddrDisc = AArch64::XZR;
1519+
1520+
// If there is no constant discriminator, there's no blend involved:
1521+
// just use the address discriminator register as-is (XZR or not).
1522+
if (!Disc)
1523+
return AddrDisc;
1524+
1525+
// If there's only a constant discriminator, MOV it into x17.
1526+
if (AddrDisc == AArch64::XZR) {
1527+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVZXi)
1528+
.addReg(AArch64::X17)
1529+
.addImm(Disc)
1530+
.addImm(/*shift=*/0));
1531+
++InstsEmitted;
1532+
return AArch64::X17;
1533+
}
1534+
1535+
// If there are both, emit a blend into x17.
1536+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ORRXrs)
1537+
.addReg(AArch64::X17)
1538+
.addReg(AArch64::XZR)
1539+
.addReg(AddrDisc)
1540+
.addImm(0));
1541+
++InstsEmitted;
1542+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVKXi)
1543+
.addReg(AArch64::X17)
1544+
.addReg(AArch64::X17)
1545+
.addImm(Disc)
1546+
.addImm(/*shift=*/48));
1547+
++InstsEmitted;
1548+
return AArch64::X17;
1549+
}
1550+
1551+
void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
1552+
unsigned InstsEmitted = 0;
1553+
1554+
unsigned BrTarget = MI->getOperand(0).getReg();
1555+
auto Key = (AArch64PACKey::ID)MI->getOperand(1).getImm();
1556+
uint64_t Disc = MI->getOperand(2).getImm();
1557+
unsigned AddrDisc = MI->getOperand(3).getReg();
1558+
1559+
// Compute discriminator into x17
1560+
assert(isUInt<16>(Disc));
1561+
unsigned DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, InstsEmitted);
1562+
bool IsZeroDisc = DiscReg == AArch64::XZR;
1563+
1564+
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
1565+
"Invalid auth call key");
1566+
1567+
unsigned Opc;
1568+
if (Key == AArch64PACKey::IA)
1569+
Opc = IsZeroDisc ? AArch64::BLRAAZ : AArch64::BLRAA;
1570+
else
1571+
Opc = IsZeroDisc ? AArch64::BLRABZ : AArch64::BLRAB;
1572+
1573+
MCInst BRInst;
1574+
BRInst.setOpcode(Opc);
1575+
BRInst.addOperand(MCOperand::createReg(BrTarget));
1576+
if (!IsZeroDisc)
1577+
BRInst.addOperand(MCOperand::createReg(DiscReg));
1578+
EmitToStreamer(*OutStreamer, BRInst);
1579+
++InstsEmitted;
1580+
1581+
assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
1582+
}
1583+
15071584
// Simple pseudo-instructions have their lowering (with expansion to real
15081585
// instructions) auto-generated.
15091586
#include "AArch64GenMCPseudoLowering.inc"
@@ -1639,9 +1716,61 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
16391716
return;
16401717
}
16411718

1719+
case AArch64::BLRA:
1720+
emitPtrauthBranch(MI);
1721+
return;
1722+
16421723
// Tail calls use pseudo instructions so they have the proper code-gen
16431724
// attributes (isCall, isReturn, etc.). We lower them to the real
16441725
// instruction here.
1726+
case AArch64::AUTH_TCRETURN:
1727+
case AArch64::AUTH_TCRETURN_BTI: {
1728+
const uint64_t Key = MI->getOperand(2).getImm();
1729+
assert(Key < 2 && "Unknown key kind for authenticating tail-call return");
1730+
const uint64_t Disc = MI->getOperand(3).getImm();
1731+
Register AddrDisc = MI->getOperand(4).getReg();
1732+
1733+
Register ScratchReg = MI->getOperand(0).getReg() == AArch64::X16
1734+
? AArch64::X17
1735+
: AArch64::X16;
1736+
1737+
unsigned DiscReg = AddrDisc;
1738+
if (Disc) {
1739+
assert(isUInt<16>(Disc) && "Integer discriminator is too wide");
1740+
1741+
if (AddrDisc != AArch64::NoRegister) {
1742+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ORRXrs)
1743+
.addReg(ScratchReg)
1744+
.addReg(AArch64::XZR)
1745+
.addReg(AddrDisc)
1746+
.addImm(0));
1747+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVKXi)
1748+
.addReg(ScratchReg)
1749+
.addReg(ScratchReg)
1750+
.addImm(Disc)
1751+
.addImm(/*shift=*/48));
1752+
} else {
1753+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVZXi)
1754+
.addReg(ScratchReg)
1755+
.addImm(Disc)
1756+
.addImm(/*shift=*/0));
1757+
}
1758+
DiscReg = ScratchReg;
1759+
}
1760+
1761+
const bool isZero = DiscReg == AArch64::NoRegister;
1762+
const unsigned Opcodes[2][2] = {{AArch64::BRAA, AArch64::BRAAZ},
1763+
{AArch64::BRAB, AArch64::BRABZ}};
1764+
1765+
MCInst TmpInst;
1766+
TmpInst.setOpcode(Opcodes[Key][isZero]);
1767+
TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg()));
1768+
if (!isZero)
1769+
TmpInst.addOperand(MCOperand::createReg(DiscReg));
1770+
EmitToStreamer(*OutStreamer, TmpInst);
1771+
return;
1772+
}
1773+
16451774
case AArch64::TCRETURNri:
16461775
case AArch64::TCRETURNrix16x17:
16471776
case AArch64::TCRETURNrix17:

0 commit comments

Comments
 (0)