Skip to content

Commit cc548ec

Browse files
[AArch64][PAC] Lower authenticated calls with ptrauth bundles. (llvm#85736)
This adds codegen support for the "ptrauth" operand bundles, which can be used to augment indirect calls with the equivalent of an `@llvm.ptrauth.auth` intrinsic call on the call target (possibly preceded by an `@llvm.ptrauth.blend` on the auth discriminator, if applicable). This allows the generation of combined authenticating calls on AArch64 (the BLRA* PAuth instructions), while preventing the raw, just-authenticated function pointer from being exposed to attackers. This is done by threading a PtrAuthInfo descriptor through the call lowering infrastructure, eventually selecting a BLRA pseudo. The pseudo encapsulates the safe discriminator computation, which is emitted together with the real BLRA* call during late pseudo expansion in AsmPrinter. Note that this also applies to the other forms of indirect calls, notably invokes, rvmarker, and tail calls. Tail calls in particular bring some additional complexity, because the register constraints of BTI and of the PAC discriminator computation intersect. However, this doesn't currently support the PAuth_LR tail-call variants. This also adopts an x8+ allocation order for GPR64noip, matching GPR64.
1 parent 68fdc1c commit cc548ec

21 files changed

+1380
-76
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,11 @@ class CallLowering {
9999
ArgInfo() = default;
100100
};
101101

102+
/// Pointer-authentication data for an indirect call, mirroring the call
/// site's "ptrauth" operand bundle (GlobalISel counterpart of
/// TargetLowering::PtrAuthInfo).
struct PtrAuthInfo {
103+
/// Key identifier: the zero-extended value of the bundle's first operand,
/// which is a constant integer.
uint64_t Key;
104+
/// Virtual register holding the bundle's second operand (the
/// discriminator).
Register Discriminator;
105+
};
106+
102107
struct CallLoweringInfo {
103108
/// Calling convention to be used for the call.
104109
CallingConv::ID CallConv = CallingConv::C;
@@ -125,6 +130,9 @@ class CallLowering {
125130

126131
MDNode *KnownCallees = nullptr;
127132

133+
/// The auth-call information in the "ptrauth" bundle, if present.
134+
std::optional<PtrAuthInfo> PAI;
135+
128136
/// True if the call must be tail call optimized.
129137
bool IsMustTailCall = false;
130138

@@ -587,7 +595,7 @@ class CallLowering {
587595
bool lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &Call,
588596
ArrayRef<Register> ResRegs,
589597
ArrayRef<ArrayRef<Register>> ArgRegs, Register SwiftErrorVReg,
590-
Register ConvergenceCtrlToken,
598+
std::optional<PtrAuthInfo> PAI, Register ConvergenceCtrlToken,
591599
std::function<unsigned()> GetCalleeReg) const;
592600

593601
/// For targets which want to use big-endian can enable it with

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4367,6 +4367,9 @@ class TargetLowering : public TargetLoweringBase {
43674367
/// Return true if the target supports kcfi operand bundles.
43684368
virtual bool supportKCFIBundles() const { return false; }
43694369

4370+
/// Return true if the target supports ptrauth operand bundles.
4371+
virtual bool supportPtrAuthBundles() const { return false; }
4372+
43704373
/// Perform necessary initialization to handle a subset of CSRs explicitly
43714374
/// via copies. This function is called at the beginning of instruction
43724375
/// selection.
@@ -4478,6 +4481,14 @@ class TargetLowering : public TargetLoweringBase {
44784481
llvm_unreachable("Not Implemented");
44794482
}
44804483

4484+
/// This structure contains the information necessary for lowering
4485+
/// pointer-authenticating indirect calls. It is equivalent to the "ptrauth"
4486+
/// operand bundle found on the call instruction, if any.
4487+
struct PtrAuthInfo {
4488+
/// Key identifier: the zero-extended value of the bundle's first operand,
/// which is a constant integer.
uint64_t Key;
4489+
/// DAG value for the bundle's second operand (the discriminator).
SDValue Discriminator;
4490+
};
4491+
44814492
/// This structure contains all information that is necessary for lowering
44824493
/// calls. It is passed to TLI::LowerCallTo when the SelectionDAG builder
44834494
/// needs to lower a call, and targets will see this struct in their LowerCall
@@ -4517,6 +4528,8 @@ class TargetLowering : public TargetLoweringBase {
45174528
const ConstantInt *CFIType = nullptr;
45184529
SDValue ConvergenceControlToken;
45194530

4531+
std::optional<PtrAuthInfo> PAI;
4532+
45204533
CallLoweringInfo(SelectionDAG &DAG)
45214534
: RetSExt(false), RetZExt(false), IsVarArg(false), IsInReg(false),
45224535
DoesNotReturn(false), IsReturnValueUsed(true), IsConvergent(false),
@@ -4639,6 +4652,11 @@ class TargetLowering : public TargetLoweringBase {
46394652
return *this;
46404653
}
46414654

4655+
/// Attach pointer-authentication info (key and discriminator) to this call
/// by storing it in PAI. Returns *this so setter calls can be chained.
CallLoweringInfo &setPtrAuth(PtrAuthInfo Value) {
4656+
PAI = Value;
4657+
return *this;
4658+
}
4659+
46424660
CallLoweringInfo &setIsPostTypeLegalization(bool Value=true) {
46434661
IsPostTypeLegalization = Value;
46444662
return *this;

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
9292
ArrayRef<Register> ResRegs,
9393
ArrayRef<ArrayRef<Register>> ArgRegs,
9494
Register SwiftErrorVReg,
95+
std::optional<PtrAuthInfo> PAI,
9596
Register ConvergenceCtrlToken,
9697
std::function<unsigned()> GetCalleeReg) const {
9798
CallLoweringInfo Info;
@@ -188,6 +189,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
188189
Info.KnownCallees = CB.getMetadata(LLVMContext::MD_callees);
189190
Info.CallConv = CallConv;
190191
Info.SwiftErrorVReg = SwiftErrorVReg;
192+
Info.PAI = PAI;
191193
Info.ConvergenceCtrlToken = ConvergenceCtrlToken;
192194
Info.IsMustTailCall = CB.isMustTailCall();
193195
Info.IsTailCall = CanBeTailCalled;

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2644,6 +2644,20 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
26442644
}
26452645
}
26462646

2647+
std::optional<CallLowering::PtrAuthInfo> PAI;
2648+
if (CB.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
2649+
// Functions should never be ptrauth-called directly.
2650+
assert(!CB.getCalledFunction() && "invalid direct ptrauth call");
2651+
2652+
auto PAB = CB.getOperandBundle("ptrauth");
2653+
const Value *Key = PAB->Inputs[0];
2654+
const Value *Discriminator = PAB->Inputs[1];
2655+
2656+
Register DiscReg = getOrCreateVReg(*Discriminator);
2657+
PAI = CallLowering::PtrAuthInfo{cast<ConstantInt>(Key)->getZExtValue(),
2658+
DiscReg};
2659+
}
2660+
26472661
Register ConvergenceCtrlToken = 0;
26482662
if (auto Bundle = CB.getOperandBundle(LLVMContext::OB_convergencectrl)) {
26492663
const auto &Token = *Bundle->Inputs[0].get();
@@ -2654,7 +2668,7 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
26542668
// optimize into tail calls. Instead, we defer that to selection where a final
26552669
// scan is done to check if any instructions are calls.
26562670
bool Success = CLI->lowerCall(
2657-
MIRBuilder, CB, Res, Args, SwiftErrorVReg, ConvergenceCtrlToken,
2671+
MIRBuilder, CB, Res, Args, SwiftErrorVReg, PAI, ConvergenceCtrlToken,
26582672
[&]() { return getOrCreateVReg(*CB.getCalledOperand()); });
26592673

26602674
// Check if we just inserted a tail call.

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3307,12 +3307,12 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) {
33073307
const BasicBlock *EHPadBB = I.getSuccessor(1);
33083308
MachineBasicBlock *EHPadMBB = FuncInfo.MBBMap[EHPadBB];
33093309

3310-
// Deopt bundles are lowered in LowerCallSiteWithDeoptBundle, and we don't
3310+
// Deopt and ptrauth bundles are lowered in helper functions, and we don't
33113311
// have to do anything here to lower funclet bundles.
33123312
assert(!I.hasOperandBundlesOtherThan(
33133313
{LLVMContext::OB_deopt, LLVMContext::OB_gc_transition,
33143314
LLVMContext::OB_gc_live, LLVMContext::OB_funclet,
3315-
LLVMContext::OB_cfguardtarget,
3315+
LLVMContext::OB_cfguardtarget, LLVMContext::OB_ptrauth,
33163316
LLVMContext::OB_clang_arc_attachedcall}) &&
33173317
"Cannot lower invokes with arbitrary operand bundles yet!");
33183318

@@ -3363,6 +3363,8 @@ void SelectionDAGBuilder::visitInvoke(const InvokeInst &I) {
33633363
// intrinsic, and right now there are no plans to support other intrinsics
33643364
// with deopt state.
33653365
LowerCallSiteWithDeoptBundle(&I, getValue(Callee), EHPadBB);
3366+
} else if (I.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
3367+
LowerCallSiteWithPtrAuthBundle(cast<CallBase>(I), EHPadBB);
33663368
} else {
33673369
LowerCallTo(I, getValue(Callee), false, false, EHPadBB);
33683370
}
@@ -8598,9 +8600,9 @@ SelectionDAGBuilder::lowerInvokable(TargetLowering::CallLoweringInfo &CLI,
85988600
}
85998601

86008602
void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
8601-
bool isTailCall,
8602-
bool isMustTailCall,
8603-
const BasicBlock *EHPadBB) {
8603+
bool isTailCall, bool isMustTailCall,
8604+
const BasicBlock *EHPadBB,
8605+
const TargetLowering::PtrAuthInfo *PAI) {
86048606
auto &DL = DAG.getDataLayout();
86058607
FunctionType *FTy = CB.getFunctionType();
86068608
Type *RetTy = CB.getType();
@@ -8707,6 +8709,15 @@ void SelectionDAGBuilder::LowerCallTo(const CallBase &CB, SDValue Callee,
87078709
CB.countOperandBundlesOfType(LLVMContext::OB_preallocated) != 0)
87088710
.setCFIType(CFIType)
87098711
.setConvergenceControlToken(ConvControlToken);
8712+
8713+
// Set the pointer authentication info if we have it.
8714+
if (PAI) {
8715+
if (!TLI.supportPtrAuthBundles())
8716+
report_fatal_error(
8717+
"This target doesn't support calls with ptrauth operand bundles.");
8718+
CLI.setPtrAuth(*PAI);
8719+
}
8720+
87108721
std::pair<SDValue, SDValue> Result = lowerInvokable(CLI, EHPadBB);
87118722

87128723
if (Result.first.getNode()) {
@@ -9252,6 +9263,11 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
92529263
}
92539264
}
92549265

9266+
if (I.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
9267+
LowerCallSiteWithPtrAuthBundle(cast<CallBase>(I), /*EHPadBB=*/nullptr);
9268+
return;
9269+
}
9270+
92559271
// Deopt bundles are lowered in LowerCallSiteWithDeoptBundle, and we don't
92569272
// have to do anything here to lower funclet bundles.
92579273
// CFGuardTarget bundles are lowered in LowerCallTo.
@@ -9273,6 +9289,31 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
92739289
LowerCallTo(I, Callee, I.isTailCall(), I.isMustTailCall());
92749290
}
92759291

9292+
void SelectionDAGBuilder::LowerCallSiteWithPtrAuthBundle(
9293+
const CallBase &CB, const BasicBlock *EHPadBB) {
9294+
auto PAB = CB.getOperandBundle("ptrauth");
9295+
const Value *CalleeV = CB.getCalledOperand();
9296+
9297+
// Gather the call ptrauth data from the operand bundle:
9298+
// [ i32 <key>, i64 <discriminator> ]
9299+
const auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
9300+
const Value *Discriminator = PAB->Inputs[1];
9301+
9302+
assert(Key->getType()->isIntegerTy(32) && "Invalid ptrauth key");
9303+
assert(Discriminator->getType()->isIntegerTy(64) &&
9304+
"Invalid ptrauth discriminator");
9305+
9306+
// Functions should never be ptrauth-called directly.
9307+
assert(!isa<Function>(CalleeV) && "invalid direct ptrauth call");
9308+
9309+
// Otherwise, do an authenticated indirect call.
9310+
TargetLowering::PtrAuthInfo PAI = {Key->getZExtValue(),
9311+
getValue(Discriminator)};
9312+
9313+
LowerCallTo(CB, getValue(CalleeV), CB.isTailCall(), CB.isMustTailCall(),
9314+
EHPadBB, &PAI);
9315+
}
9316+
92769317
namespace {
92779318

92789319
/// AsmOperandInfo - This contains information for each constraint that we are

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,8 @@ class SelectionDAGBuilder {
406406
void CopyToExportRegsIfNeeded(const Value *V);
407407
void ExportFromCurrentBlock(const Value *V);
408408
void LowerCallTo(const CallBase &CB, SDValue Callee, bool IsTailCall,
409-
bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr);
409+
bool IsMustTailCall, const BasicBlock *EHPadBB = nullptr,
410+
const TargetLowering::PtrAuthInfo *PAI = nullptr);
410411

411412
// Lower range metadata from 0 to N to assert zext to an integer of nearest
412413
// floor power of two.
@@ -490,6 +491,9 @@ class SelectionDAGBuilder {
490491
bool VarArgDisallowed,
491492
bool ForceVoidReturnTy);
492493

494+
void LowerCallSiteWithPtrAuthBundle(const CallBase &CB,
495+
const BasicBlock *EHPadBB);
496+
493497
/// Returns the type of FrameIndex and TargetFrameIndex nodes.
494498
MVT getFrameIndexTy() {
495499
return DAG.getTargetLoweringInfo().getFrameIndexTy(DAG.getDataLayout());

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,12 @@ class AArch64AsmPrinter : public AsmPrinter {
125125

126126
void emitSled(const MachineInstr &MI, SledKind Kind);
127127

128+
// Emit the sequence for BLRA (authenticate + branch).
129+
void emitPtrauthBranch(const MachineInstr *MI);
130+
// Emit the sequence to compute a discriminator into x17, or reuse AddrDisc.
131+
unsigned emitPtrauthDiscriminator(uint16_t Disc, unsigned AddrDisc,
132+
unsigned &InstsEmitted);
133+
128134
/// tblgen'erated driver function for lowering simple MI->MC
129135
/// pseudo instructions.
130136
bool emitPseudoExpansionLowering(MCStreamer &OutStreamer,
@@ -1497,6 +1503,78 @@ void AArch64AsmPrinter::emitFMov0(const MachineInstr &MI) {
14971503
}
14981504
}
14991505

1506+
/// Materialize the discriminator for an authenticated branch.
///
/// \param Disc         Constant discriminator; zero means "none".
/// \param AddrDisc     Address-discriminator register, or NoRegister.
/// \param InstsEmitted Incremented once per instruction emitted here.
/// \returns The register to use as the discriminator: the address
///          discriminator itself (XZR if there was none) when \p Disc is
///          zero, otherwise x17.
unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
                                                     unsigned AddrDisc,
                                                     unsigned &InstsEmitted) {
  // Pseudos use NoRegister for "no address discriminator"; from here on a
  // real encoding is needed, so substitute XZR.
  const unsigned AddrReg =
      AddrDisc == AArch64::NoRegister ? AArch64::XZR : AddrDisc;

  // No constant part means no blend: hand back the address discriminator
  // register unchanged, whether or not it is XZR.
  if (Disc == 0)
    return AddrReg;

  if (AddrReg != AArch64::XZR) {
    // Both parts are present: copy the address discriminator into x17, then
    // insert the constant into its top 16 bits.
    EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ORRXrs)
                                     .addReg(AArch64::X17)
                                     .addReg(AArch64::XZR)
                                     .addReg(AddrReg)
                                     .addImm(0));
    ++InstsEmitted;
    EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVKXi)
                                     .addReg(AArch64::X17)
                                     .addReg(AArch64::X17)
                                     .addImm(Disc)
                                     .addImm(/*shift=*/48));
    ++InstsEmitted;
  } else {
    // Only the constant part is present: a single MOVZ into x17 suffices.
    EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVZXi)
                                     .addReg(AArch64::X17)
                                     .addImm(Disc)
                                     .addImm(/*shift=*/0));
    ++InstsEmitted;
  }

  return AArch64::X17;
}
1543+
1544+
void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
1545+
unsigned InstsEmitted = 0;
1546+
unsigned BrTarget = MI->getOperand(0).getReg();
1547+
1548+
auto Key = (AArch64PACKey::ID)MI->getOperand(1).getImm();
1549+
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
1550+
"Invalid auth call key");
1551+
1552+
uint64_t Disc = MI->getOperand(2).getImm();
1553+
assert(isUInt<16>(Disc));
1554+
1555+
unsigned AddrDisc = MI->getOperand(3).getReg();
1556+
1557+
// Compute discriminator into x17
1558+
unsigned DiscReg = emitPtrauthDiscriminator(Disc, AddrDisc, InstsEmitted);
1559+
bool IsZeroDisc = DiscReg == AArch64::XZR;
1560+
1561+
unsigned Opc;
1562+
if (Key == AArch64PACKey::IA)
1563+
Opc = IsZeroDisc ? AArch64::BLRAAZ : AArch64::BLRAA;
1564+
else
1565+
Opc = IsZeroDisc ? AArch64::BLRABZ : AArch64::BLRAB;
1566+
1567+
MCInst BRInst;
1568+
BRInst.setOpcode(Opc);
1569+
BRInst.addOperand(MCOperand::createReg(BrTarget));
1570+
if (!IsZeroDisc)
1571+
BRInst.addOperand(MCOperand::createReg(DiscReg));
1572+
EmitToStreamer(*OutStreamer, BRInst);
1573+
++InstsEmitted;
1574+
1575+
assert(STI->getInstrInfo()->getInstSizeInBytes(*MI) >= InstsEmitted * 4);
1576+
}
1577+
15001578
// Simple pseudo-instructions have their lowering (with expansion to real
15011579
// instructions) auto-generated.
15021580
#include "AArch64GenMCPseudoLowering.inc"
@@ -1632,9 +1710,63 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
16321710
return;
16331711
}
16341712

1713+
case AArch64::BLRA:
1714+
emitPtrauthBranch(MI);
1715+
return;
1716+
16351717
// Tail calls use pseudo instructions so they have the proper code-gen
16361718
// attributes (isCall, isReturn, etc.). We lower them to the real
16371719
// instruction here.
1720+
case AArch64::AUTH_TCRETURN:
1721+
case AArch64::AUTH_TCRETURN_BTI: {
1722+
const uint64_t Key = MI->getOperand(2).getImm();
1723+
assert((Key == AArch64PACKey::IA || Key == AArch64PACKey::IB) &&
1724+
"Invalid auth key for tail-call return");
1725+
1726+
const uint64_t Disc = MI->getOperand(3).getImm();
1727+
assert(isUInt<16>(Disc) && "Integer discriminator is too wide");
1728+
1729+
Register AddrDisc = MI->getOperand(4).getReg();
1730+
1731+
Register ScratchReg = MI->getOperand(0).getReg() == AArch64::X16
1732+
? AArch64::X17
1733+
: AArch64::X16;
1734+
1735+
unsigned DiscReg = AddrDisc;
1736+
if (Disc) {
1737+
if (AddrDisc != AArch64::NoRegister) {
1738+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::ORRXrs)
1739+
.addReg(ScratchReg)
1740+
.addReg(AArch64::XZR)
1741+
.addReg(AddrDisc)
1742+
.addImm(0));
1743+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVKXi)
1744+
.addReg(ScratchReg)
1745+
.addReg(ScratchReg)
1746+
.addImm(Disc)
1747+
.addImm(/*shift=*/48));
1748+
} else {
1749+
EmitToStreamer(*OutStreamer, MCInstBuilder(AArch64::MOVZXi)
1750+
.addReg(ScratchReg)
1751+
.addImm(Disc)
1752+
.addImm(/*shift=*/0));
1753+
}
1754+
DiscReg = ScratchReg;
1755+
}
1756+
1757+
const bool IsZero = DiscReg == AArch64::NoRegister;
1758+
const unsigned Opcodes[2][2] = {{AArch64::BRAA, AArch64::BRAAZ},
1759+
{AArch64::BRAB, AArch64::BRABZ}};
1760+
1761+
MCInst TmpInst;
1762+
TmpInst.setOpcode(Opcodes[Key][IsZero]);
1763+
TmpInst.addOperand(MCOperand::createReg(MI->getOperand(0).getReg()));
1764+
if (!IsZero)
1765+
TmpInst.addOperand(MCOperand::createReg(DiscReg));
1766+
EmitToStreamer(*OutStreamer, TmpInst);
1767+
return;
1768+
}
1769+
16381770
case AArch64::TCRETURNri:
16391771
case AArch64::TCRETURNrix16x17:
16401772
case AArch64::TCRETURNrix17:

0 commit comments

Comments
 (0)