Skip to content

Commit b26f615

Browse files
committed
[AArch64] Add CodeGen support for FEAT_CPA
CPA stands for Checked Pointer Arithmetic and is part of the 2023 MTE architecture extensions for A-profile. The new CPA instructions perform regular pointer arithmetic (such as base register + offset) but check for overflow in the most significant bits of the result, enhancing security by detecting address tampering. In this patch we intend to capture the semantics of pointer arithmetic when it is not folded into loads/stores, then generate the appropriate CPA instructions. In order to preserve pointer arithmetic semantics through the backend, we add the PTRADD SelectionDAG node type. The PTRADD node and respective visitPTRADD() function are adapted from the CHERI/Morello LLVM tree. Mode details about the CPA extension can be found at: - https://community.arm.com/arm-community-blogs/b/architectures-and-processors-blog/posts/arm-a-profile-architecture-developments-2023 - https://developer.arm.com/documentation/ddi0602/2023-09/ This PR follows llvm#79569.
1 parent d46812a commit b26f615

File tree

14 files changed

+1104
-22
lines changed

14 files changed

+1104
-22
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1452,6 +1452,10 @@ enum NodeType {
14521452
// Outputs: [rv], output chain, glue
14531453
PATCHPOINT,
14541454

1455+
// PTRADD represents pointer arithmetic semantics, for those targets which
1456+
// benefit from that information.
1457+
PTRADD,
1458+
14551459
// Vector Predication
14561460
#define BEGIN_REGISTER_VP_SDNODE(VPSDID, ...) VPSDID,
14571461
#include "llvm/IR/VPIntrinsics.def"

llvm/include/llvm/Target/TargetMachine.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,11 @@ class TargetMachine {
434434
function_ref<void(std::unique_ptr<Module> MPart)> ModuleCallback) {
435435
return false;
436436
}
437+
438+
/// True if target has some particular form of dealing with pointer arithmetic
439+
/// semantics. False if pointer arithmetic should not be preserved for passes
440+
/// such as instruction selection, and can fallback to regular arithmetic.
441+
virtual bool shouldPreservePtrArith(const Function &F) const { return false; }
437442
};
438443

439444
/// This class describes a target machine that is implemented with the LLVM

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def SDTOther : SDTypeProfile<1, 0, [SDTCisVT<0, OtherVT>]>; // for 'vt'.
109109
def SDTUNDEF : SDTypeProfile<1, 0, []>; // for 'undef'.
110110
def SDTUnaryOp : SDTypeProfile<1, 1, []>; // for bitconvert.
111111

112-
def SDTPtrAddOp : SDTypeProfile<1, 2, [ // ptradd
112+
def SDTPtrAddOp : SDTypeProfile<1, 2, [ // ptradd
113113
SDTCisSameAs<0, 1>, SDTCisInt<2>, SDTCisPtrTy<1>
114114
]>;
115115
def SDTIntBinOp : SDTypeProfile<1, 2, [ // add, and, or, xor, udiv, etc.
@@ -390,7 +390,7 @@ def tblockaddress: SDNode<"ISD::TargetBlockAddress", SDTPtrLeaf, [],
390390

391391
def add : SDNode<"ISD::ADD" , SDTIntBinOp ,
392392
[SDNPCommutative, SDNPAssociative]>;
393-
def ptradd : SDNode<"ISD::ADD" , SDTPtrAddOp, []>;
393+
def ptradd : SDNode<"ISD::PTRADD" , SDTPtrAddOp, []>;
394394
def sub : SDNode<"ISD::SUB" , SDTIntBinOp>;
395395
def mul : SDNode<"ISD::MUL" , SDTIntBinOp,
396396
[SDNPCommutative, SDNPAssociative]>;

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,9 @@ namespace {
416416
SDValue visitMERGE_VALUES(SDNode *N);
417417
SDValue visitADD(SDNode *N);
418418
SDValue visitADDLike(SDNode *N);
419-
SDValue visitADDLikeCommutative(SDValue N0, SDValue N1, SDNode *LocReference);
419+
SDValue visitADDLikeCommutative(SDValue N0, SDValue N1,
420+
SDNode *LocReference);
421+
SDValue visitPTRADD(SDNode *N);
420422
SDValue visitSUB(SDNode *N);
421423
SDValue visitADDSAT(SDNode *N);
422424
SDValue visitSUBSAT(SDNode *N);
@@ -1082,7 +1084,7 @@ bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
10821084
// (load/store (add, (add, x, y), offset2)) ->
10831085
// (load/store (add, (add, x, offset2), y)).
10841086

1085-
if (N0.getOpcode() != ISD::ADD)
1087+
if (N0.getOpcode() != ISD::ADD && N0.getOpcode() != ISD::PTRADD)
10861088
return false;
10871089

10881090
// Check for vscale addressing modes.
@@ -1833,6 +1835,7 @@ SDValue DAGCombiner::visit(SDNode *N) {
18331835
case ISD::TokenFactor: return visitTokenFactor(N);
18341836
case ISD::MERGE_VALUES: return visitMERGE_VALUES(N);
18351837
case ISD::ADD: return visitADD(N);
1838+
case ISD::PTRADD: return visitPTRADD(N);
18361839
case ISD::SUB: return visitSUB(N);
18371840
case ISD::SADDSAT:
18381841
case ISD::UADDSAT: return visitADDSAT(N);
@@ -2349,7 +2352,7 @@ static bool canFoldInAddressingMode(SDNode *N, SDNode *Use, SelectionDAG &DAG,
23492352
}
23502353

23512354
TargetLowering::AddrMode AM;
2352-
if (N->getOpcode() == ISD::ADD) {
2355+
if (N->getOpcode() == ISD::ADD || N->getOpcode() == ISD::PTRADD) {
23532356
AM.HasBaseReg = true;
23542357
ConstantSDNode *Offset = dyn_cast<ConstantSDNode>(N->getOperand(1));
23552358
if (Offset)
@@ -2578,6 +2581,98 @@ SDValue DAGCombiner::foldSubToAvg(SDNode *N, const SDLoc &DL) {
25782581
return SDValue();
25792582
}
25802583

2584+
/// Try to fold a pointer arithmetic node.
2585+
/// This needs to be done separately from normal addition, because pointer
2586+
/// addition is not commutative.
2587+
/// This function was adapted from DAGCombiner::visitPTRADD() from the Morello
2588+
/// project, which is based on CHERI.
2589+
SDValue DAGCombiner::visitPTRADD(SDNode *N) {
2590+
SDValue N0 = N->getOperand(0);
2591+
SDValue N1 = N->getOperand(1);
2592+
EVT PtrVT = N0.getValueType();
2593+
EVT IntVT = N1.getValueType();
2594+
SDLoc DL(N);
2595+
2596+
// fold (ptradd undef, y) -> undef
2597+
if (N0.isUndef())
2598+
return N0;
2599+
2600+
// fold (ptradd x, undef) -> undef
2601+
if (N1.isUndef())
2602+
return DAG.getUNDEF(PtrVT);
2603+
2604+
// fold (ptradd x, 0) -> x
2605+
if (isNullConstant(N1))
2606+
return N0;
2607+
2608+
if (N0.getOpcode() == ISD::PTRADD &&
2609+
!reassociationCanBreakAddressingModePattern(ISD::PTRADD, DL, N, N0, N1)) {
2610+
SDValue X = N0.getOperand(0);
2611+
SDValue Y = N0.getOperand(1);
2612+
SDValue Z = N1;
2613+
bool N0OneUse = N0.hasOneUse();
2614+
bool YIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Y);
2615+
bool ZIsConstant = DAG.isConstantIntBuildVectorOrConstantInt(Z);
2616+
2617+
// (ptradd (ptradd x, y), z) -> (ptradd (ptradd x, z), y) if:
2618+
// * (ptradd x, y) has one use; and
2619+
// * y is a constant; and
2620+
// * z is not a constant.
2621+
// Serves to expose constant y for subsequent folding.
2622+
if (N0OneUse && YIsConstant && !ZIsConstant) {
2623+
SDValue Add = DAG.getNode(ISD::PTRADD, DL, IntVT, {X, Z});
2624+
2625+
// Calling visit() can replace the Add node with ISD::DELETED_NODE if
2626+
// there aren't any users, so keep a handle around whilst we visit it.
2627+
HandleSDNode ADDHandle(Add);
2628+
2629+
SDValue VisitedAdd = visit(Add.getNode());
2630+
if (VisitedAdd) {
2631+
// If visit() returns the same node, it means the SDNode was RAUW'd, and
2632+
// therefore we have to load the new value to perform the checks whether
2633+
// the reassociation fold is profitable.
2634+
if (VisitedAdd.getNode() == Add.getNode())
2635+
Add = ADDHandle.getValue();
2636+
else
2637+
Add = VisitedAdd;
2638+
}
2639+
2640+
return DAG.getMemBasePlusOffset(Add, Y, DL, SDNodeFlags());
2641+
}
2642+
2643+
bool ZOneUse = Z.hasOneUse();
2644+
2645+
// (ptradd (ptradd x, y), z) -> (ptradd x, (add y, z)) if:
2646+
// * x is a null pointer; or
2647+
// * y is a constant and z has one use; or
2648+
// * y is a constant and (ptradd x, y) has one use; or
2649+
// * (ptradd x, y) and z have one use and z is not a constant.
2650+
if (isNullConstant(X) || (YIsConstant && ZOneUse) ||
2651+
(YIsConstant && N0OneUse) || (N0OneUse && ZOneUse && !ZIsConstant)) {
2652+
SDValue Add = DAG.getNode(ISD::ADD, DL, IntVT, {Y, Z});
2653+
2654+
// Calling visit() can replace the Add node with ISD::DELETED_NODE if
2655+
// there aren't any users, so keep a handle around whilst we visit it.
2656+
HandleSDNode ADDHandle(Add);
2657+
2658+
SDValue VisitedAdd = visit(Add.getNode());
2659+
if (VisitedAdd) {
2660+
// If visit() returns the same node, it means the SDNode was RAUW'd, and
2661+
// therefore we have to load the new value to perform the checks whether
2662+
// the reassociation fold is profitable.
2663+
if (VisitedAdd.getNode() == Add.getNode())
2664+
Add = ADDHandle.getValue();
2665+
else
2666+
Add = VisitedAdd;
2667+
}
2668+
2669+
return DAG.getMemBasePlusOffset(X, Add, DL, SDNodeFlags());
2670+
}
2671+
}
2672+
2673+
return SDValue();
2674+
}
2675+
25812676
/// Try to fold a 'not' shifted sign-bit with add/sub with constant operand into
25822677
/// a shift and add with a different constant.
25832678
static SDValue foldAddSubOfSignBit(SDNode *N, const SDLoc &DL,

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4069,8 +4069,14 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
40694069
else
40704070
Index = DAG.getNode(ISD::MUL, dl, Index.getValueType(), Index,
40714071
DAG.getConstant(EntrySize, dl, Index.getValueType()));
4072-
SDValue Addr = DAG.getNode(ISD::ADD, dl, Index.getValueType(),
4073-
Index, Table);
4072+
SDValue Addr;
4073+
if (!DAG.getTarget().shouldPreservePtrArith(
4074+
DAG.getMachineFunction().getFunction())) {
4075+
Addr = DAG.getNode(ISD::ADD, dl, Index.getValueType(), Index, Table);
4076+
} else {
4077+
// PTRADD always takes the pointer first, so the operands are commuted
4078+
Addr = DAG.getNode(ISD::PTRADD, dl, Index.getValueType(), Table, Index);
4079+
}
40744080

40754081
EVT MemVT = EVT::getIntegerVT(*DAG.getContext(), EntrySize * 8);
40764082
SDValue LD = DAG.getExtLoad(
@@ -4081,8 +4087,15 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
40814087
// For PIC, the sequence is:
40824088
// BRIND(load(Jumptable + index) + RelocBase)
40834089
// RelocBase can be JumpTable, GOT or some sort of global base.
4084-
Addr = DAG.getNode(ISD::ADD, dl, PTy, Addr,
4085-
TLI.getPICJumpTableRelocBase(Table, DAG));
4090+
if (!DAG.getTarget().shouldPreservePtrArith(
4091+
DAG.getMachineFunction().getFunction())) {
4092+
Addr = DAG.getNode(ISD::ADD, dl, PTy, Addr,
4093+
TLI.getPICJumpTableRelocBase(Table, DAG));
4094+
} else {
4095+
// PTRADD always takes the pointer first, so the operands are commuted
4096+
Addr = DAG.getNode(ISD::PTRADD, dl, PTy,
4097+
TLI.getPICJumpTableRelocBase(Table, DAG), Addr);
4098+
}
40864099
}
40874100

40884101
Tmp1 = TLI.expandIndirectJTBranch(dl, LD.getValue(1), Addr, JTI, DAG);

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5387,7 +5387,8 @@ bool SelectionDAG::isADDLike(SDValue Op, bool NoWrap) const {
53875387

53885388
bool SelectionDAG::isBaseWithConstantOffset(SDValue Op) const {
53895389
return Op.getNumOperands() == 2 && isa<ConstantSDNode>(Op.getOperand(1)) &&
5390-
(Op.getOpcode() == ISD::ADD || isADDLike(Op));
5390+
(Op.getOpcode() == ISD::ADD || Op.getOpcode() == ISD::PTRADD ||
5391+
isADDLike(Op));
53915392
}
53925393

53935394
bool SelectionDAG::isKnownNeverNaN(SDValue Op, bool SNaN, unsigned Depth) const {
@@ -7785,7 +7786,12 @@ SDValue SelectionDAG::getMemBasePlusOffset(SDValue Ptr, SDValue Offset,
77857786
const SDNodeFlags Flags) {
77867787
assert(Offset.getValueType().isInteger());
77877788
EVT BasePtrVT = Ptr.getValueType();
7788-
return getNode(ISD::ADD, DL, BasePtrVT, Ptr, Offset, Flags);
7789+
if (!this->getTarget().shouldPreservePtrArith(
7790+
this->getMachineFunction().getFunction())) {
7791+
return getNode(ISD::ADD, DL, BasePtrVT, Ptr, Offset, Flags);
7792+
} else {
7793+
return getNode(ISD::PTRADD, DL, BasePtrVT, Ptr, Offset, Flags);
7794+
}
77897795
}
77907796

77917797
/// Returns true if memcpy source is constant data.

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 33 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4293,6 +4293,12 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
42934293
SDLoc dl = getCurSDLoc();
42944294
auto &TLI = DAG.getTargetLoweringInfo();
42954295
GEPNoWrapFlags NW = cast<GEPOperator>(I).getNoWrapFlags();
4296+
unsigned int AddOpcode = ISD::PTRADD;
4297+
4298+
if (!DAG.getTarget().shouldPreservePtrArith(
4299+
DAG.getMachineFunction().getFunction())) {
4300+
AddOpcode = ISD::ADD;
4301+
}
42964302

42974303
// Normalize Vector GEP - all scalar operands should be converted to the
42984304
// splat vector.
@@ -4324,7 +4330,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
43244330
(int64_t(Offset) >= 0 && NW.hasNoUnsignedSignedWrap()))
43254331
Flags.setNoUnsignedWrap(true);
43264332

4327-
N = DAG.getNode(ISD::ADD, dl, N.getValueType(), N,
4333+
N = DAG.getNode(AddOpcode, dl, N.getValueType(), N,
43284334
DAG.getConstant(Offset, dl, N.getValueType()), Flags);
43294335
}
43304336
} else {
@@ -4368,7 +4374,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
43684374

43694375
OffsVal = DAG.getSExtOrTrunc(OffsVal, dl, N.getValueType());
43704376

4371-
N = DAG.getNode(ISD::ADD, dl, N.getValueType(), N, OffsVal, Flags);
4377+
N = DAG.getNode(AddOpcode, dl, N.getValueType(), N, OffsVal, Flags);
43724378
continue;
43734379
}
43744380

@@ -4411,8 +4417,7 @@ void SelectionDAGBuilder::visitGetElementPtr(const User &I) {
44114417
}
44124418
}
44134419

4414-
N = DAG.getNode(ISD::ADD, dl,
4415-
N.getValueType(), N, IdxN);
4420+
N = DAG.getNode(AddOpcode, dl, N.getValueType(), N, IdxN);
44164421
}
44174422
}
44184423

@@ -4473,8 +4478,15 @@ void SelectionDAGBuilder::visitAlloca(const AllocaInst &I) {
44734478
// an address inside an alloca.
44744479
SDNodeFlags Flags;
44754480
Flags.setNoUnsignedWrap(true);
4476-
AllocSize = DAG.getNode(ISD::ADD, dl, AllocSize.getValueType(), AllocSize,
4477-
DAG.getConstant(StackAlignMask, dl, IntPtr), Flags);
4481+
if (DAG.getTarget().shouldPreservePtrArith(
4482+
DAG.getMachineFunction().getFunction())) {
4483+
AllocSize = DAG.getNode(ISD::PTRADD, dl, AllocSize.getValueType(),
4484+
DAG.getConstant(StackAlignMask, dl, IntPtr),
4485+
AllocSize, Flags);
4486+
} else {
4487+
AllocSize = DAG.getNode(ISD::ADD, dl, AllocSize.getValueType(), AllocSize,
4488+
DAG.getConstant(StackAlignMask, dl, IntPtr), Flags);
4489+
}
44784490

44794491
// Mask out the low bits for alignment purposes.
44804492
AllocSize = DAG.getNode(ISD::AND, dl, AllocSize.getValueType(), AllocSize,
@@ -9071,8 +9083,13 @@ bool SelectionDAGBuilder::visitMemPCpyCall(const CallInst &I) {
90719083
Size = DAG.getSExtOrTrunc(Size, sdl, Dst.getValueType());
90729084

90739085
// Adjust return pointer to point just past the last dst byte.
9074-
SDValue DstPlusSize = DAG.getNode(ISD::ADD, sdl, Dst.getValueType(),
9075-
Dst, Size);
9086+
unsigned int AddOpcode = ISD::PTRADD;
9087+
if (!DAG.getTarget().shouldPreservePtrArith(
9088+
DAG.getMachineFunction().getFunction())) {
9089+
AddOpcode = ISD::ADD;
9090+
}
9091+
SDValue DstPlusSize =
9092+
DAG.getNode(AddOpcode, sdl, Dst.getValueType(), Dst, Size);
90769093
setValue(&I, DstPlusSize);
90779094
return true;
90789095
}
@@ -11169,9 +11186,14 @@ TargetLowering::LowerCallTo(TargetLowering::CallLoweringInfo &CLI) const {
1116911186
MachineFunction &MF = CLI.DAG.getMachineFunction();
1117011187
Align HiddenSRetAlign = MF.getFrameInfo().getObjectAlign(DemoteStackIdx);
1117111188
for (unsigned i = 0; i < NumValues; ++i) {
11172-
SDValue Add = CLI.DAG.getNode(ISD::ADD, CLI.DL, PtrVT, DemoteStackSlot,
11173-
CLI.DAG.getConstant(Offsets[i], CLI.DL,
11174-
PtrVT), Flags);
11189+
unsigned int AddOpcode = ISD::PTRADD;
11190+
if (!CLI.DAG.getTarget().shouldPreservePtrArith(
11191+
CLI.DAG.getMachineFunction().getFunction())) {
11192+
AddOpcode = ISD::ADD;
11193+
}
11194+
SDValue Add = CLI.DAG.getNode(
11195+
AddOpcode, CLI.DL, PtrVT, DemoteStackSlot,
11196+
CLI.DAG.getConstant(Offsets[i], CLI.DL, PtrVT), Flags);
1117511197
SDValue L = CLI.DAG.getLoad(
1117611198
RetTys[i], CLI.DL, CLI.Chain, Add,
1117711199
MachinePointerInfo::getFixedStack(CLI.DAG.getMachineFunction(),

llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,7 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
256256

257257
// Binary operators
258258
case ISD::ADD: return "add";
259+
case ISD::PTRADD: return "ptradd";
259260
case ISD::SUB: return "sub";
260261
case ISD::MUL: return "mul";
261262
case ISD::MULHU: return "mulhu";

llvm/lib/Target/AArch64/AArch64InstrInfo.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10223,6 +10223,26 @@ let Predicates = [HasCPA] in {
1022310223
// Scalar multiply-add/subtract
1022410224
def MADDPT : MulAccumCPA<0, "maddpt">;
1022510225
def MSUBPT : MulAccumCPA<1, "msubpt">;
10226+
10227+
// Rules to use CPA instructions in pointer arithmetic patterns which are not
10228+
// folded into loads/stores. The AddedComplexity serves to help supersede
10229+
// other simpler (non-CPA) patterns and make sure CPA is used instead.
10230+
let AddedComplexity = 20 in {
10231+
def : Pat<(ptradd GPR64sp:$Rn, GPR64sp:$Rm),
10232+
(ADDPT_shift GPR64sp:$Rn, GPR64sp:$Rm, (i32 0))>;
10233+
def : Pat<(ptradd GPR64sp:$Rn, (shl GPR64sp:$Rm, (i64 imm0_7:$imm))),
10234+
(ADDPT_shift GPR64sp:$Rn, GPR64sp:$Rm,
10235+
(i32 (trunc_imm imm0_7:$imm)))>;
10236+
def : Pat<(ptradd GPR64sp:$Rn, (ineg GPR64sp:$Rm)),
10237+
(SUBPT_shift GPR64sp:$Rn, GPR64sp:$Rm, (i32 0))>;
10238+
def : Pat<(ptradd GPR64sp:$Rn, (ineg (shl GPR64sp:$Rm, (i64 imm0_7:$imm)))),
10239+
(SUBPT_shift GPR64sp:$Rn, GPR64sp:$Rm,
10240+
(i32 (trunc_imm imm0_7:$imm)))>;
10241+
def : Pat<(ptradd GPR64:$Ra, (mul GPR64:$Rn, GPR64:$Rm)),
10242+
(MADDPT GPR64:$Rn, GPR64:$Rm, GPR64:$Ra)>;
10243+
def : Pat<(ptradd GPR64:$Ra, (mul GPR64:$Rn, (ineg GPR64:$Rm))),
10244+
(MSUBPT GPR64:$Rn, GPR64:$Rm, GPR64:$Ra)>;
10245+
}
1022610246
}
1022710247

1022810248
def round_v4fp32_to_v4bf16 :

llvm/lib/Target/AArch64/AArch64TargetMachine.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -920,3 +920,7 @@ bool AArch64TargetMachine::parseMachineFunctionInfo(
920920
MF.getInfo<AArch64FunctionInfo>()->initializeBaseYamlFields(YamlMFI);
921921
return false;
922922
}
923+
924+
bool AArch64TargetMachine::shouldPreservePtrArith(const Function &F) const {
925+
return getSubtargetImpl(F)->hasCPA();
926+
}

llvm/lib/Target/AArch64/AArch64TargetMachine.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ class AArch64TargetMachine : public LLVMTargetMachine {
6969
return true;
7070
}
7171

72+
/// In AArch64, true if FEAT_CPA is present. Allows pointer arithmetic
73+
/// semantics to be preserved for instruction selection.
74+
bool shouldPreservePtrArith(const Function &F) const override;
75+
7276
private:
7377
bool isLittle;
7478
};

llvm/lib/Target/AArch64/GISel/AArch64InstructionSelector.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2091,6 +2091,10 @@ bool AArch64InstructionSelector::preISelLower(MachineInstr &I) {
20912091
return Changed;
20922092
}
20932093
case TargetOpcode::G_PTR_ADD:
2094+
// If Checked Pointer Arithmetic (FEAT_CPA) is present, preserve the pointer
2095+
// arithmetic semantics instead of falling back to regular arithmetic.
2096+
if (TM.shouldPreservePtrArith(MF.getFunction()))
2097+
return false;
20942098
return convertPtrAddToAdd(I, MRI);
20952099
case TargetOpcode::G_LOAD: {
20962100
// For scalar loads of pointers, we try to convert the dest type from p0

0 commit comments

Comments
 (0)