Skip to content

Commit 536ab53

Browse files
committed
[AArch64][PAC] Lower authenticated calls with ptrauth bundles.
This adds codegen support for the "ptrauth" operand bundles, which can be used to augment indirect calls with the equivalent of an `@llvm.ptrauth.auth` intrinsic call on the call target (possibly preceded by an `@llvm.ptrauth.blend` on the auth discriminator if applicable.) This allows the generation of combined authenticating calls on AArch64 (in the BLRA* PAuth instructions), while avoiding the raw just-authenticated function pointer from being exposed to attackers. This is done by threading a PtrAuthInfo descriptor through the call lowering infrastructure. Note that this also applies to the other forms of indirect calls, notably invokes, rvmarker, and tail calls. Tail-calls in particular bring some additional complexity, with the intersecting register constraints of BTI and PAC discriminator computation.
1 parent 75825f3 commit 536ab53

20 files changed

+1292
-59
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ class CallLowering {
9999
ArgInfo() = default;
100100
};
101101

102+
struct PointerAuthInfo {
103+
Register Discriminator;
104+
uint64_t Key;
105+
};
106+
102107
struct CallLoweringInfo {
103108
/// Calling convention to be used for the call.
104109
CallingConv::ID CallConv = CallingConv::C;
@@ -125,6 +130,8 @@ class CallLowering {
125130

126131
MDNode *KnownCallees = nullptr;
127132

133+
std::optional<PointerAuthInfo> PAI;
134+
128135
/// True if the call must be tail call optimized.
129136
bool IsMustTailCall = false;
130137

@@ -587,6 +594,7 @@ class CallLowering {
587594
bool lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &Call,
588595
ArrayRef<Register> ResRegs,
589596
ArrayRef<ArrayRef<Register>> ArgRegs, Register SwiftErrorVReg,
597+
std::optional<PointerAuthInfo> PAI,
590598
Register ConvergenceCtrlToken,
591599
std::function<unsigned()> GetCalleeReg) const;
592600

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4361,6 +4361,9 @@ class TargetLowering : public TargetLoweringBase {
43614361
/// Return true if the target supports kcfi operand bundles.
43624362
virtual bool supportKCFIBundles() const { return false; }
43634363

4364+
/// Return true if the target supports ptrauth operand bundles.
4365+
virtual bool supportPtrAuthBundles() const { return false; }
4366+
43644367
/// Perform necessary initialization to handle a subset of CSRs explicitly
43654368
/// via copies. This function is called at the beginning of instruction
43664369
/// selection.
@@ -4472,6 +4475,14 @@ class TargetLowering : public TargetLoweringBase {
44724475
llvm_unreachable("Not Implemented");
44734476
}
44744477

4478+
/// This structure contains the information necessary for lowering
4479+
/// pointer-authenticating indirect calls. It is equivalent to the "ptrauth"
4480+
/// operand bundle found on the call instruction, if any.
4481+
struct PtrAuthInfo {
4482+
uint64_t Key;
4483+
SDValue Discriminator;
4484+
};
4485+
44754486
/// This structure contains all information that is necessary for lowering
44764487
/// calls. It is passed to TLI::LowerCallTo when the SelectionDAG builder
44774488
/// needs to lower a call, and targets will see this struct in their LowerCall
@@ -4511,6 +4522,8 @@ class TargetLowering : public TargetLoweringBase {
45114522
const ConstantInt *CFIType = nullptr;
45124523
SDValue ConvergenceControlToken;
45134524

4525+
std::optional<PtrAuthInfo> PAI;
4526+
45144527
CallLoweringInfo(SelectionDAG &DAG)
45154528
: RetSExt(false), RetZExt(false), IsVarArg(false), IsInReg(false),
45164529
DoesNotReturn(false), IsReturnValueUsed(true), IsConvergent(false),
@@ -4633,6 +4646,11 @@ class TargetLowering : public TargetLoweringBase {
46334646
return *this;
46344647
}
46354648

4649+
CallLoweringInfo &setPtrAuth(PtrAuthInfo Value) {
4650+
PAI = Value;
4651+
return *this;
4652+
}
4653+
46364654
CallLoweringInfo &setIsPostTypeLegalization(bool Value=true) {
46374655
IsPostTypeLegalization = Value;
46384656
return *this;

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
9292
ArrayRef<Register> ResRegs,
9393
ArrayRef<ArrayRef<Register>> ArgRegs,
9494
Register SwiftErrorVReg,
95+
std::optional<PointerAuthInfo> PAI,
9596
Register ConvergenceCtrlToken,
9697
std::function<unsigned()> GetCalleeReg) const {
9798
CallLoweringInfo Info;
@@ -188,6 +189,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
188189
Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees);
189190
Info.CallConv = CallConv;
190191
Info.SwiftErrorVReg = SwiftErrorVReg;
192+
Info.PAI = PAI;
191193
Info.ConvergenceCtrlToken = ConvergenceCtrlToken;
192194
Info.IsMustTailCall = CB.isMustTailCall();
193195
Info.IsTailCall = CanBeTailCalled;

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2642,6 +2642,20 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
26422642
}
26432643
}
26442644

2645+
std::optional<CallLowering::PointerAuthInfo> PAI;
2646+
if (CB.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
2647+
// Functions should never be ptrauth-called directly.
2648+
assert(!CB.getCalledFunction() && "invalid direct ptrauth call");
2649+
2650+
auto PAB = CB.getOperandBundle("ptrauth");
2651+
Value *Key = PAB->Inputs[0];
2652+
Value *Discriminator = PAB->Inputs[1];
2653+
2654+
Register DiscReg = getOrCreateVReg(*Discriminator);
2655+
PAI = CallLowering::PointerAuthInfo{DiscReg,
2656+
cast<ConstantInt>(Key)->getZExtValue()};
2657+
}
2658+
26452659
Register ConvergenceCtrlToken = 0;
26462660
if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_convergencectrl)) {
26472661
const auto &Token = *Bundle->Inputs[0].get();
@@ -2652,7 +2666,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
26522666
// optimize into tail calls. Instead, we defer that to selection where a final
26532667
// scan is done to check if any instructions are calls.
26542668
bool Success = CLI->lowerCall(
2655-
MIRBuilder, CB, Res, Args, SwiftErrorVReg, ConvergenceCtrlToken,
2669+
MIRBuilder, CB, Res, Args, SwiftErrorVReg, PAI, ConvergenceCtrlToken,
26562670
[&]() { return getOrCreateVReg(*CB.getCalledOperand()); });
26572671

26582672
// Check if we just inserted a tail call.

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3307,12 +3307,12 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) {
33073307
const BasicBlock *EHPadBB = I.getSuccessor(1);
33083308
MachineBasicBlock *EHPadMBB = FuncInfo.MBBMap[EHPadBB];
33093309

3310-
// Deopt bundles are lowered in LowerCallSiteWithDeoptBundle, and we don't
3310+
// Deopt and ptrauth bundles are lowered in helper functions, and we don't
33113311
// have to do anything here to lower funclet bundles.
33123312
assert(!I.hasOperandBundlesOtherThan(
33133313
{LLVMContext::OB_deopt, LLVMContext::OB_gc_transition,
33143314
LLVMContext::OB_gc_live, LLVMContext::OB_funclet,
3315-
LLVMContext::OB_cfguardtarget,
3315+
LLVMContext::OB_cfguardtarget, LLVMContext::OB_ptrauth,
33163316
LLVMContext::OB_clang_arc_attachedcall}) &&
33173317
"Cannot lower invokes with arbitrary operand bundles yet!");
33183318

@@ -3363,6 +3363,8 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) {
33633363
// intrinsic, and right now there are no plans to support other intrinsics
33643364
// with deopt state.
33653365
LowerCallSiteWithDeoptBundle(&I, getValue(Callee), EHPadBB);
3366+
} else if (I.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
3367+
LowerCallSiteWithPtrAuthBundle(cast<CallBase>(I), EHPadBB);
33663368
} else {
33673369
LowerCallTo(I, getValue(Callee), false, false, EHPadBB);
33683370
}
@@ -8531,9 +8533,9 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI,
85318533
}
85328534

85338535
void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
8534-
bool isTailCall,
8535-
bool isMustTailCall,
8536-
const BasicBlock *EHPadBB) {
8536+
bool isTailCall, bool isMustTailCall,
8537+
const BasicBlock *EHPadBB,
8538+
const TargetLowering::PtrAuthInfo *PAI) {
85378539
auto &DL = DAG.getDataLayout();
85388540
FunctionType *FTy = CB.getFunctionType();
85398541
Type *RetTy = CB.getType();
@@ -8640,6 +8642,15 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
86408642
CB.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0)
86418643
.setCFIType(CFIType)
86428644
.setConvergenceControlToken(ConvControlToken);
8645+
8646+
// Set the pointer authentication info if we have it.
8647+
if (PAI) {
8648+
if (!TLI.supportPtrAuthBundles())
8649+
report_fatal_error(
8650+
"This target doesn't support calls with ptrauth operand bundles.");
8651+
CLI.setPtrAuth(*PAI);
8652+
}
8653+
86438654
std::pair<SDValue, SDValue> Result = lowerInvokable(CLI, EHPadBB);
86448655

86458656
if (Result.first.getNode()) {
@@ -9185,6 +9196,11 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
91859196
}
91869197
}
91879198

9199+
if (I.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
9200+
LowerCallSiteWithPtrAuthBundle(cast<CallBase>(I), /*EHPadBB=*/nullptr);
9201+
return;
9202+
}
9203+
91889204
// Deopt bundles are lowered in LowerCallSiteWithDeoptBundle, and we don't
91899205
// have to do anything here to lower funclet bundles.
91909206
// CFGuardTarget bundles are lowered in LowerCallTo.
@@ -9206,6 +9222,31 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
92069222
LowerCallTo(I, Callee, I.isTailCall(), I.isMustTailCall());
92079223
}
92089224

9225+
void SelectionDAGBuilder::LowerCallSiteWithPtrAuthBundle(
9226+
const CallBase &CB, const BasicBlock *EHPadBB) {
9227+
auto PAB = CB.getOperandBundle("ptrauth");
9228+
auto *CalleeV = CB.getCalledOperand();
9229+
9230+
// Gather the call ptrauth data from the operand bundle:
9231+
// [ i32 <key>, i64 <discriminator> ]
9232+
auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
9233+
Value *Discriminator = PAB->Inputs[1];
9234+
9235+
assert(Key->getType()->isIntegerTy(32) && "Invalid ptrauth key");
9236+
assert(Discriminator->getType()->isIntegerTy(64) &&
9237+
"Invalid ptrauth discriminator");
9238+
9239+
// Functions should never be ptrauth-called directly.
9240+
assert(!isa<Function>(CalleeV) && "invalid direct ptrauth call");
9241+
9242+
// Otherwise, do an authenticated indirect call.
9243+
TargetLowering::PtrAuthInfo PAI = {Key->getZExtValue(),
9244+
getValue(Discriminator)};
9245+
9246+
LowerCallTo(CB, getValue(CalleeV), CB.isTailCall(), CB.isMustTailCall(),
9247+
EHPadBB, &PAI);
9248+
}
9249+
92099250
namespace {
92109251

92119252
/// AsmOperandInfo - This contains information for each constraint that we are

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,8 @@ class SelectionDAGBuilder {
406406
void CopyToExportRegsIfNeeded(const Value *V);
407407
void ExportFromCurrentBlock(const Value *V);
408408
void LowerCallTo(const CallBase &CB, SDValue Callee, bool IsTailCall,
409-
bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr);
409+
bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr,
410+
const TargetLowering::PtrAuthInfo *PAI = nullptr);
410411

411412
// Lower range metadata from 0 to N to assert zext to an integer of nearest
412413
// floor power of two.
@@ -490,6 +491,9 @@ class SelectionDAGBuilder {
490491
bool VarArgDisallowed,
491492
bool ForceVoidReturnTy);
492493

494+
void LowerCallSiteWithPtrAuthBundle(const CallBase &CB,
495+
const BasicBlock *EHPadBB);
496+
493497
/// Returns the type of FrameIndex and TargetFrameIndex nodes.
494498
MVT getFrameIndexTy() {
495499
return DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout());

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,8 @@ class AArch64AsmPrinter : public AsmPrinter {
134134

135135
void emitSled(const MachineInstr &MI, SledKind Kind);
136136

137+
// Emit the sequence for BLRA (authenticate + branch).
138+
void emitPtrauthBranch(const MachineInstr *MI);
137139
// Emit the sequence for AUT or AUTPAC.
138140
void emitPtrauthAuthResign(const MachineInstr *MI);
139141
// Emit the sequence to compute a discriminator into x17, or reuse AddrDisc.
@@ -1522,6 +1524,10 @@ void AArch64AsmPrinter::emitFMov0(const MachineInstr &MI) {
15221524
unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
15231525
unsigned AddrDisc,
15241526
unsigned &InstsEmitted) {
1527+
// So far we've used NoRegister in pseudos. Now we need real encodings.
1528+
if (AddrDisc == AArch64::NoRegister)
1529+
AddrDisc = AArch64::XZR;
1530+
15251531
// If there is no constant discriminator, there's no blend involved:
15261532
// just use the address discriminator register as-is (XZR or not).
15271533
if (!Disc)
@@ -1769,6 +1775,39 @@ void AArch64AsmPrinter::emitPtrauthAuthResign(const MachineInstr *MI) {
17691775
OutStreamer->emitLabel(EndSym);
17701776
}
17711777

1778+
void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
1779+
unsigned InstsEmitted = 0;
1780+
1781+
unsigned BrTarget = MI->getOperand(0).getReg();
1782+
auto Key = (AArch64PACKey::ID)MI->getOperand(1).getImm();
1783+
uint64_t Disc = MI->getOperand(2).getImm();
1784+
unsigned AddrDisc = MI->getOperand(3).getReg();
1785+
1786+
// Compute discriminator into x17
1787+
assert(isUInt<16>(Disc));
1788+
unsigned DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, InstsEmitted);
1789+
bool IsZeroDisc = DiscReg == AArch64::XZR;
1790+
1791+
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
1792+
"Invalid auth call key");
1793+
1794+
unsigned Opc;
1795+
if (Key == AArch64PACKey::IA)
1796+
Opc = IsZeroDisc ? AArch64::BLRAAZ : AArch64::BLRAA;
1797+
else
1798+
Opc = IsZeroDisc ? AArch64::BLRABZ : AArch64::BLRAB;
1799+
1800+
MCInst BRInst;
1801+
BRInst.setOpcode(Opc);
1802+
BRInst.addOperand(MCOperand::createReg(BrTarget));
1803+
if (!IsZeroDisc)
1804+
BRInst.addOperand(MCOperand::createReg(DiscReg));
1805+
EmitToStreamer(*OutStreamer, BRInst);
1806+
++InstsEmitted;
1807+
1808+
assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
1809+
}
1810+
17721811
// Simple pseudo-instructions have their lowering (with expansion to real
17731812
// instructions) auto-generated.
17741813
#include "AArch64GenMCPseudoLowering.inc"
@@ -1909,9 +1948,60 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
19091948
emitPtrauthAuthResign(MI);
19101949
return;
19111950

1951+
case AArch64::BLRA:
1952+
emitPtrauthBranch(MI);
1953+
return;
1954+
19121955
// Tail calls use pseudo instructions so they have the proper code-gen
19131956
// attributes (isCall, isReturn, etc.). We lower them to the real
19141957
// instruction here.
1958+
case AArch64::AUTH_TCRETURN:
1959+
case AArch64::AUTH_TCRETURN_BTI: {
1960+
const uint64_t Key = MI->getOperand(2).getImm();
1961+
assert(Key < 2 && "Unknown key kind for authenticating tail-call return");
1962+
const uint64_t Disc = MI->getOperand(3).getImm();
1963+
Register AddrDisc = MI->getOperand(4).getReg();
1964+
1965+
Register ScratchReg = MI->getOperand(0).getReg() == AArch64::X16
1966+
? AArch64::X17
1967+
: AArch64::X16;
1968+
1969+
unsigned DiscReg = AddrDisc;
1970+
if (Disc) {
1971+
assert(isUInt<16>(Disc) && "Integer discriminator is too wide");
1972+
1973+
if (AddrDisc != AArch64::NoRegister) {
1974+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ORRXrs)
1975+
.addReg(ScratchReg)
1976+
.addReg(AArch64::XZR)
1977+
.addReg(AddrDisc)
1978+
.addImm(0));
1979+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVKXi)
1980+
.addReg(ScratchReg)
1981+
.addReg(ScratchReg)
1982+
.addImm(Disc)
1983+
.addImm(/*shift=*/48));
1984+
} else {
1985+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVZXi)
1986+
.addReg(ScratchReg)
1987+
.addImm(Disc)
1988+
.addImm(/*shift=*/0));
1989+
}
1990+
DiscReg = ScratchReg;
1991+
}
1992+
1993+
const bool isZero = DiscReg == AArch64::NoRegister;
1994+
const unsigned Opcodes[2][2] = {{AArch64::BRAA, AArch64::BRAAZ},
1995+
{AArch64::BRAB, AArch64::BRABZ}};
1996+
1997+
MCInst TmpInst;
1998+
TmpInst.setOpcode(Opcodes[Key][isZero]);
1999+
TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg()));
2000+
if (!isZero)
2001+
TmpInst.addOperand(MCOperand::createReg(DiscReg));
2002+
EmitToStreamer(*OutStreamer, TmpInst);
2003+
return;
2004+
}
19152005
case AArch64::TCRETURNri:
19162006
case AArch64::TCRETURNrix16x17:
19172007
case AArch64::TCRETURNrix17:

0 commit comments

Comments
 (0)