Skip to content

Commit 1714447

Browse files
committed
Address review feedback.
- AArch64ExpandPseudos: generalize & use createCall - AArch64InstrInfo: describe x16/x17 usage in pseudo - AArch64InstrInfo: group SDNodes with others - test call for ELF as well - GlobalISel: rename PointerAuthInfo, reorder fields - various nits, auto, const
1 parent a742d68 commit 1714447

File tree

10 files changed

+245
-168
lines changed

10 files changed

+245
-168
lines changed

llvm/include/llvm/CodeGen/GlobalISel/CallLowering.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,9 +99,9 @@ class CallLowering {
9999
ArgInfo() = default;
100100
};
101101

102-
struct PointerAuthInfo {
103-
Register Discriminator;
102+
struct PtrAuthInfo {
104103
uint64_t Key;
104+
Register Discriminator;
105105
};
106106

107107
struct CallLoweringInfo {
@@ -130,7 +130,8 @@ class CallLowering {
130130

131131
MDNode *KnownCallees = nullptr;
132132

133-
std::optional<PointerAuthInfo> PAI;
133+
/// The auth-call information in the "ptrauth" bundle, if present.
134+
std::optional<PtrAuthInfo> PAI;
134135

135136
/// True if the call must be tail call optimized.
136137
bool IsMustTailCall = false;
@@ -594,8 +595,7 @@ class CallLowering {
594595
bool lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &Call,
595596
ArrayRef<Register> ResRegs,
596597
ArrayRef<ArrayRef<Register>> ArgRegs, Register SwiftErrorVReg,
597-
std::optional<PointerAuthInfo> PAI,
598-
Register ConvergenceCtrlToken,
598+
std::optional<PtrAuthInfo> PAI, Register ConvergenceCtrlToken,
599599
std::function<unsigned()> GetCalleeReg) const;
600600

601601
/// For targets which want to use big-endian can enable it with

llvm/lib/CodeGen/GlobalISel/CallLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ bool CallLowering::lowerCall(MachineIRBuilder &MIRBuilder, const CallBase &CB,
9292
ArrayRef<Register> ResRegs,
9393
ArrayRef<ArrayRef<Register>> ArgRegs,
9494
Register SwiftErrorVReg,
95-
std::optional<PointerAuthInfo> PAI,
95+
std::optional<PtrAuthInfo> PAI,
9696
Register ConvergenceCtrlToken,
9797
std::function<unsigned()> GetCalleeReg) const {
9898
CallLoweringInfo Info;

llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2642,18 +2642,18 @@ bool IRTranslator::translateCallBase(const CallBase &CB,
26422642
}
26432643
}
26442644

2645-
std::optional<CallLowering::PointerAuthInfo> PAI;
2645+
std::optional<CallLowering::PtrAuthInfo> PAI;
26462646
if (CB.countOperandBundlesOfType(LLVMContext::OB_ptrauth)) {
26472647
// Functions should never be ptrauth-called directly.
26482648
assert(!CB.getCalledFunction() && "invalid direct ptrauth call");
26492649

26502650
auto PAB = CB.getOperandBundle("ptrauth");
2651-
Value *Key = PAB->Inputs[0];
2652-
Value *Discriminator = PAB->Inputs[1];
2651+
const Value *Key = PAB->Inputs[0];
2652+
const Value *Discriminator = PAB->Inputs[1];
26532653

26542654
Register DiscReg = getOrCreateVReg(*Discriminator);
2655-
PAI = CallLowering::PointerAuthInfo{DiscReg,
2656-
cast<ConstantInt>(Key)->getZExtValue()};
2655+
PAI = CallLowering::PtrAuthInfo{cast<ConstantInt>(Key)->getZExtValue(),
2656+
DiscReg};
26572657
}
26582658

26592659
Register ConvergenceCtrlToken = 0;

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9225,12 +9225,12 @@ void SelectionDAGBuilder::visitCall(const CallInst &I) {
92259225
void SelectionDAGBuilder::LowerCallSiteWithPtrAuthBundle(
92269226
const CallBase &CB, const BasicBlock *EHPadBB) {
92279227
auto PAB = CB.getOperandBundle("ptrauth");
9228-
auto *CalleeV = CB.getCalledOperand();
9228+
const Value *CalleeV = CB.getCalledOperand();
92299229

92309230
// Gather the call ptrauth data from the operand bundle:
92319231
// [ i32 <key>, i64 <discriminator> ]
9232-
auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
9233-
Value *Discriminator = PAB->Inputs[1];
9232+
const auto *Key = cast<ConstantInt>(PAB->Inputs[0]);
9233+
const Value *Discriminator = PAB->Inputs[1];
92349234

92359235
assert(Key->getType()->isIntegerTy(32) && "Invalid ptrauth key");
92369236
assert(Discriminator->getType()->isIntegerTy(64) &&

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1550,7 +1550,6 @@ unsigned AArch64AsmPrinter::emitPtrauthDiscriminator(uint16_t Disc,
15501550

15511551
void AArch64AsmPrinter::emitPtrauthBranch(const MachineInstr *MI) {
15521552
unsigned InstsEmitted = 0;
1553-
15541553
unsigned BrTarget = MI->getOperand(0).getReg();
15551554

15561555
auto Key = (AArch64PACKey::ID)MI->getOperand(1).getImm();

llvm/lib/Target/AArch64/AArch64ExpandPseudoInsts.cpp

Lines changed: 36 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -774,26 +774,24 @@ bool AArch64ExpandPseudo::expandSVESpillFill(MachineBasicBlock &MBB,
774774
return true;
775775
}
776776

777-
// Create a call to CallTarget, copying over all the operands from *MBBI,
778-
// starting at the regmask.
779-
static MachineInstr *createCall(MachineBasicBlock &MBB,
780-
MachineBasicBlock::iterator MBBI,
781-
const AArch64InstrInfo *TII,
782-
MachineOperand &CallTarget,
783-
unsigned RegMaskStartIdx) {
784-
unsigned Opc = CallTarget.isGlobal() ? AArch64::BL : AArch64::BLR;
785-
MachineInstr *Call =
786-
BuildMI(MBB, MBBI, MBBI->getDebugLoc(), TII->get(Opc)).getInstr();
787-
788-
assert((CallTarget.isGlobal() || CallTarget.isReg()) &&
789-
"invalid operand for regular call");
790-
Call->addOperand(CallTarget);
777+
// Create a call with the passed opcode and explicit operands, copying over all
778+
// the implicit operands from *MBBI, starting at the regmask.
779+
static MachineInstr *createCallWithOps(MachineBasicBlock &MBB,
780+
MachineBasicBlock::iterator MBBI,
781+
const AArch64InstrInfo *TII,
782+
unsigned Opcode,
783+
ArrayRef<MachineOperand> ExplicitOps,
784+
unsigned RegMaskStartIdx) {
785+
// Build the MI, with explicit operands first (including the call target).
786+
MachineInstr *Call = BuildMI(MBB, MBBI, MBBI->getDebugLoc(), TII->get(Opcode))
787+
.add(ExplicitOps)
788+
.getInstr();
791789

792790
// Register arguments are added during ISel, but cannot be added as explicit
793791
// operands of the branch as it expects to be B <target> which is only one
794792
// operand. Instead they are implicit operands used by the branch.
795793
while (!MBBI->getOperand(RegMaskStartIdx).isRegMask()) {
796-
auto MOP = MBBI->getOperand(RegMaskStartIdx);
794+
MachineOperand &MOP = MBBI->getOperand(RegMaskStartIdx);
797795
assert(MOP.isReg() && "can only add register operands");
798796
Call->addOperand(MachineOperand::CreateReg(
799797
MOP.getReg(), /*Def=*/false, /*Implicit=*/true, /*isKill=*/false,
@@ -807,6 +805,20 @@ static MachineInstr *createCall(MachineBasicBlock &MBB,
807805
return Call;
808806
}
809807

808+
// Create a call to CallTarget, copying over all the operands from *MBBI,
809+
// starting at the regmask.
810+
static MachineInstr *createCall(MachineBasicBlock &MBB,
811+
MachineBasicBlock::iterator MBBI,
812+
const AArch64InstrInfo *TII,
813+
MachineOperand &CallTarget,
814+
unsigned RegMaskStartIdx) {
815+
unsigned Opc = CallTarget.isGlobal() ? AArch64::BL : AArch64::BLR;
816+
817+
assert((CallTarget.isGlobal() || CallTarget.isReg()) &&
818+
"invalid operand for regular call");
819+
return createCallWithOps(MBB, MBBI, TII, Opc, CallTarget, RegMaskStartIdx);
820+
}
821+
810822
bool AArch64ExpandPseudo::expandCALL_RVMARKER(
811823
MachineBasicBlock &MBB, MachineBasicBlock::iterator MBBI) {
812824
// Expand CALL_RVMARKER pseudo to:
@@ -822,33 +834,19 @@ bool AArch64ExpandPseudo::expandCALL_RVMARKER(
822834

823835
if (MI.getOpcode() == AArch64::BLRA_RVMARKER) {
824836
// Pointer auth call.
837+
MachineOperand &CallTarget = MI.getOperand(1);
825838
MachineOperand &Key = MI.getOperand(2);
826-
assert((Key.getImm() == 0 || Key.getImm() == 1) &&
827-
"invalid key for ptrauth call");
828839
MachineOperand &IntDisc = MI.getOperand(3);
829840
MachineOperand &AddrDisc = MI.getOperand(4);
830841

831-
OriginalCall = BuildMI(MBB, MBBI, MI.getDebugLoc(), TII->get(AArch64::BLRA))
832-
.getInstr();
833-
OriginalCall->addOperand(MI.getOperand(1));
834-
OriginalCall->addOperand(Key);
835-
OriginalCall->addOperand(IntDisc);
836-
OriginalCall->addOperand(AddrDisc);
837-
838-
unsigned RegMaskStartIdx = 5;
839-
// Skip register arguments. Those are added during ISel, but are not
840-
// needed for the concrete branch.
841-
while (!MI.getOperand(RegMaskStartIdx).isRegMask()) {
842-
auto MOP = MI.getOperand(RegMaskStartIdx);
843-
assert(MOP.isReg() && "can only add register operands");
844-
OriginalCall->addOperand(MachineOperand::CreateReg(
845-
MOP.getReg(), /*Def=*/false, /*Implicit=*/true, /*isKill=*/false,
846-
/*isDead=*/false, /*isUndef=*/MOP.isUndef()));
847-
RegMaskStartIdx++;
848-
}
849-
for (const MachineOperand &MO :
850-
llvm::drop_begin(MI.operands(), RegMaskStartIdx))
851-
OriginalCall->addOperand(MO);
842+
assert((Key.getImm() == AArch64PACKey::IA ||
843+
Key.getImm() == AArch64PACKey::IB) &&
844+
"Invalid auth call key");
845+
846+
MachineOperand Ops[] = {CallTarget, Key, IntDisc, AddrDisc};
847+
848+
OriginalCall = createCallWithOps(MBB, MBBI, TII, AArch64::BLRA, Ops,
849+
/*RegMaskStartIdx=*/5);
852850
} else {
853851
assert(MI.getOpcode() == AArch64::BLR_RVMARKER && "unknown rvmarker MI");
854852
OriginalCall = createCall(MBB, MBBI, TII, MI.getOperand(1),

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ extractPtrauthBlendDiscriminators(SDValue Disc, SelectionDAG *DAG) {
349349
// If the constant discriminator (either the blend RHS, or the entire
350350
// discriminator value) isn't a 16-bit constant, bail out, and let the
351351
// discriminator be computed separately.
352-
auto *ConstDiscN = dyn_cast<ConstantSDNode>(ConstDisc);
352+
const auto *ConstDiscN = dyn_cast<ConstantSDNode>(ConstDisc);
353353
if (!ConstDiscN || !isUInt<16>(ConstDiscN->getZExtValue()))
354354
return std::make_tuple(DAG->getTargetConstant(0, DL, MVT::i64), Disc);
355355

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,30 @@ def AArch64call_arm64ec_to_x64 : SDNode<"AArch64ISD::CALL_ARM64EC_TO_X64",
662662
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
663663
SDNPVariadic]>;
664664

665+
def AArch64authcall : SDNode<"AArch64ISD::AUTH_CALL",
666+
SDTypeProfile<0, -1, [SDTCisPtrTy<0>,
667+
SDTCisVT<1, i32>,
668+
SDTCisVT<2, i64>,
669+
SDTCisVT<3, i64>]>,
670+
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
671+
SDNPVariadic]>;
672+
673+
def AArch64authtcret: SDNode<"AArch64ISD::AUTH_TC_RETURN",
674+
SDTypeProfile<0, 5, [SDTCisPtrTy<0>,
675+
SDTCisVT<2, i32>,
676+
SDTCisVT<3, i64>,
677+
SDTCisVT<4, i64>]>,
678+
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
679+
680+
def AArch64authcall_rvmarker : SDNode<"AArch64ISD::AUTH_CALL_RVMARKER",
681+
SDTypeProfile<0, -1, [SDTCisPtrTy<0>,
682+
SDTCisPtrTy<1>,
683+
SDTCisVT<2, i32>,
684+
SDTCisVT<3, i64>,
685+
SDTCisVT<4, i64>]>,
686+
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
687+
SDNPVariadic]>;
688+
665689
def AArch64brcond : SDNode<"AArch64ISD::BRCOND", SDT_AArch64Brcond,
666690
[SDNPHasChain]>;
667691
def AArch64cbz : SDNode<"AArch64ISD::CBZ", SDT_AArch64cbz,
@@ -1564,30 +1588,6 @@ let Predicates = [HasComplxNum, HasNEON] in {
15641588
(v4f32 (bitconvert (v2i64 (AArch64duplane64 (v2i64 V128:$Rm), VectorIndexD:$idx))))>;
15651589
}
15661590

1567-
def AArch64authcall : SDNode<"AArch64ISD::AUTH_CALL",
1568-
SDTypeProfile<0, -1, [SDTCisPtrTy<0>,
1569-
SDTCisVT<1, i32>,
1570-
SDTCisVT<2, i64>,
1571-
SDTCisVT<3, i64>]>,
1572-
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
1573-
SDNPVariadic]>;
1574-
1575-
def AArch64authtcret: SDNode<"AArch64ISD::AUTH_TC_RETURN",
1576-
SDTypeProfile<0, 5, [SDTCisPtrTy<0>,
1577-
SDTCisVT<2, i32>,
1578-
SDTCisVT<3, i64>,
1579-
SDTCisVT<4, i64>]>,
1580-
[SDNPHasChain, SDNPOptInGlue, SDNPVariadic]>;
1581-
1582-
def AArch64authcall_rvmarker : SDNode<"AArch64ISD::AUTH_CALL_RVMARKER",
1583-
SDTypeProfile<0, -1, [SDTCisPtrTy<0>,
1584-
SDTCisPtrTy<1>,
1585-
SDTCisVT<2, i32>,
1586-
SDTCisVT<3, i64>,
1587-
SDTCisVT<4, i64>]>,
1588-
[SDNPHasChain, SDNPOptInGlue, SDNPOutGlue,
1589-
SDNPVariadic]>;
1590-
15911591
// v8.3a Pointer Authentication
15921592
// These instructions inhabit part of the hint space and so can be used for
15931593
// armv8 targets. Keeping the old HINT mnemonic when compiling without PA is
@@ -1716,9 +1716,12 @@ let Predicates = [HasPAuth] in {
17161716
def BLRABZ : AuthOneOperand<0b001, 1, "blrabz">;
17171717
}
17181718

1719-
// BLRA pseudo, generalized version of BLRAA/BLRAB/Z.
1720-
// This directly manipulates x16/x17, which are the only registers the OS
1721-
// guarantees are safe to use for sensitive operations.
1719+
// BLRA pseudo, a generalized version of BLRAA/BLRAB/Z.
1720+
// This directly manipulates x16/x17 to materialize the discriminator.
1721+
// x16/x17 are generally used as the safe registers for sensitive ptrauth
1722+
// operations (such as raw address manipulation or discriminator
1723+
// materialization here), in part because they're handled in a safer way by
1724+
// the kernel, notably on Darwin.
17221725
def BLRA : Pseudo<(outs), (ins GPR64noip:$Rn, i32imm:$Key, i64imm:$Disc,
17231726
GPR64noip:$AddrDisc),
17241727
[(AArch64authcall GPR64noip:$Rn, timm:$Key, timm:$Disc,

llvm/lib/Target/AArch64/GISel/AArch64CallLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1015,7 +1015,7 @@ bool AArch64CallLowering::isEligibleForTailCallOptimization(
10151015

10161016
static unsigned getCallOpcode(const MachineFunction &CallerF, bool IsIndirect,
10171017
bool IsTailCall,
1018-
std::optional<CallLowering::PointerAuthInfo> &PAI,
1018+
std::optional<CallLowering::PtrAuthInfo> &PAI,
10191019
MachineRegisterInfo &MRI) {
10201020
const AArch64FunctionInfo *FuncInfo = CallerF.getInfo<AArch64FunctionInfo>();
10211021

0 commit comments

Comments
 (0)