Skip to content

Commit 0fe0968

Browse files
authored
[AArch64][FEAT_CMPBR] Codegen for Armv9.6-a compare-and-branch (llvm#116465)
This patch adds codegen for all Armv9.6-a compare-and-branch instructions that operate on full w or x registers. The instruction variants operating on half-words (cbh) and bytes (cbb) are added in a subsequent patch. Since CB doesn't use the standard 4-bit Arm condition codes but a reduced set of conditions encoded in 3 bits, some conditions are expressed by modifying operands, namely by incrementing or decrementing immediate operands and by swapping register operands. To invert a CB instruction it is therefore not enough to just modify the condition code, which doesn't play particularly well with how the backend is currently organized. We therefore introduce a number of pseudos which operate on the standard 4-bit condition codes and lower them late during codegen.
1 parent 8363b0a commit 0fe0968

14 files changed

+3134
-2
lines changed

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ class AArch64AsmPrinter : public AsmPrinter {
208208
void emitAttributes(unsigned Flags, uint64_t PAuthABIPlatform,
209209
uint64_t PAuthABIVersion, AArch64TargetStreamer *TS);
210210

211+
// Emit expansion of Compare-and-branch pseudo instructions
212+
void emitCBPseudoExpansion(const MachineInstr *MI);
213+
211214
void EmitToStreamer(MCStreamer &S, const MCInst &Inst);
212215
void EmitToStreamer(const MCInst &Inst) {
213216
EmitToStreamer(*OutStreamer, Inst);
@@ -2589,6 +2592,124 @@ AArch64AsmPrinter::lowerBlockAddressConstant(const BlockAddress &BA) {
25892592
return BAE;
25902593
}
25912594

2595+
void AArch64AsmPrinter::emitCBPseudoExpansion(const MachineInstr *MI) {
2596+
bool IsImm = false;
2597+
bool Is32Bit = false;
2598+
2599+
switch (MI->getOpcode()) {
2600+
default:
2601+
llvm_unreachable("This is not a CB pseudo instruction");
2602+
case AArch64::CBWPrr:
2603+
Is32Bit = true;
2604+
break;
2605+
case AArch64::CBXPrr:
2606+
Is32Bit = false;
2607+
break;
2608+
case AArch64::CBWPri:
2609+
IsImm = true;
2610+
Is32Bit = true;
2611+
break;
2612+
case AArch64::CBXPri:
2613+
IsImm = true;
2614+
break;
2615+
}
2616+
2617+
AArch64CC::CondCode CC =
2618+
static_cast<AArch64CC::CondCode>(MI->getOperand(0).getImm());
2619+
bool NeedsRegSwap = false;
2620+
bool NeedsImmDec = false;
2621+
bool NeedsImmInc = false;
2622+
2623+
// Decide if we need to either swap register operands or increment/decrement
2624+
// immediate operands
2625+
unsigned MCOpC;
2626+
switch (CC) {
2627+
default:
2628+
llvm_unreachable("Invalid CB condition code");
2629+
case AArch64CC::EQ:
2630+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBEQWri : AArch64::CBEQXri)
2631+
: (Is32Bit ? AArch64::CBEQWrr : AArch64::CBEQXrr);
2632+
break;
2633+
case AArch64CC::NE:
2634+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBNEWri : AArch64::CBNEXri)
2635+
: (Is32Bit ? AArch64::CBNEWrr : AArch64::CBNEXrr);
2636+
break;
2637+
case AArch64CC::HS:
2638+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBHIWri : AArch64::CBHIXri)
2639+
: (Is32Bit ? AArch64::CBHSWrr : AArch64::CBHSXrr);
2640+
NeedsImmDec = IsImm;
2641+
break;
2642+
case AArch64CC::LO:
2643+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBLOWri : AArch64::CBLOXri)
2644+
: (Is32Bit ? AArch64::CBHIWrr : AArch64::CBHIXrr);
2645+
NeedsRegSwap = !IsImm;
2646+
break;
2647+
case AArch64CC::HI:
2648+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBHIWri : AArch64::CBHIXri)
2649+
: (Is32Bit ? AArch64::CBHIWrr : AArch64::CBHIXrr);
2650+
break;
2651+
case AArch64CC::LS:
2652+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBLOWri : AArch64::CBLOXri)
2653+
: (Is32Bit ? AArch64::CBHSWrr : AArch64::CBHSXrr);
2654+
NeedsRegSwap = !IsImm;
2655+
NeedsImmInc = IsImm;
2656+
break;
2657+
case AArch64CC::GE:
2658+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBGTWri : AArch64::CBGTXri)
2659+
: (Is32Bit ? AArch64::CBGEWrr : AArch64::CBGEXrr);
2660+
NeedsImmDec = IsImm;
2661+
break;
2662+
case AArch64CC::LT:
2663+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBLTWri : AArch64::CBLTXri)
2664+
: (Is32Bit ? AArch64::CBGTWrr : AArch64::CBGTXrr);
2665+
NeedsRegSwap = !IsImm;
2666+
break;
2667+
case AArch64CC::GT:
2668+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBGTWri : AArch64::CBGTXri)
2669+
: (Is32Bit ? AArch64::CBGTWrr : AArch64::CBGTXrr);
2670+
break;
2671+
case AArch64CC::LE:
2672+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBLTWri : AArch64::CBLTXri)
2673+
: (Is32Bit ? AArch64::CBGEWrr : AArch64::CBGEXrr);
2674+
NeedsRegSwap = !IsImm;
2675+
NeedsImmInc = IsImm;
2676+
break;
2677+
}
2678+
2679+
MCInst Inst;
2680+
Inst.setOpcode(MCOpC);
2681+
2682+
MCOperand Lhs, Rhs, Trgt;
2683+
lowerOperand(MI->getOperand(1), Lhs);
2684+
lowerOperand(MI->getOperand(2), Rhs);
2685+
lowerOperand(MI->getOperand(3), Trgt);
2686+
2687+
// Now swap, increment or decrement
2688+
if (NeedsRegSwap) {
2689+
assert(Lhs.isReg() && "Expected register operand for CB");
2690+
assert(Rhs.isReg() && "Expected register operand for CB");
2691+
Inst.addOperand(Rhs);
2692+
Inst.addOperand(Lhs);
2693+
} else if (NeedsImmDec) {
2694+
Rhs.setImm(Rhs.getImm() - 1);
2695+
Inst.addOperand(Lhs);
2696+
Inst.addOperand(Rhs);
2697+
} else if (NeedsImmInc) {
2698+
Rhs.setImm(Rhs.getImm() + 1);
2699+
Inst.addOperand(Lhs);
2700+
Inst.addOperand(Rhs);
2701+
} else {
2702+
Inst.addOperand(Lhs);
2703+
Inst.addOperand(Rhs);
2704+
}
2705+
2706+
assert((!IsImm || (Rhs.getImm() >= 0 && Rhs.getImm() < 64)) &&
2707+
"CB immediate operand out-of-bounds");
2708+
2709+
Inst.addOperand(Trgt);
2710+
EmitToStreamer(*OutStreamer, Inst);
2711+
}
2712+
25922713
// Simple pseudo-instructions have their lowering (with expansion to real
25932714
// instructions) auto-generated.
25942715
#include "AArch64GenMCPseudoLowering.inc"
@@ -3155,13 +3276,20 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
31553276
return;
31563277

31573278
case AArch64::BLR:
3158-
case AArch64::BR:
3279+
case AArch64::BR: {
31593280
recordIfImportCall(MI);
31603281
MCInst TmpInst;
31613282
MCInstLowering.Lower(MI, TmpInst);
31623283
EmitToStreamer(*OutStreamer, TmpInst);
31633284
return;
31643285
}
3286+
case AArch64::CBWPri:
3287+
case AArch64::CBXPri:
3288+
case AArch64::CBWPrr:
3289+
case AArch64::CBXPrr:
3290+
emitCBPseudoExpansion(MI);
3291+
return;
3292+
}
31653293

31663294
// Finally, do the automated lowerings for everything else.
31673295
MCInst TmpInst;

llvm/lib/Target/AArch64/AArch64ISelDAGToDAG.cpp

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,8 @@ class AArch64DAGToDAGISel : public SelectionDAGISel {
507507

508508
bool SelectAllActivePredicate(SDValue N);
509509
bool SelectAnyPredicate(SDValue N);
510+
511+
bool SelectCmpBranchUImm6Operand(SDNode *P, SDValue N, SDValue &Imm);
510512
};
511513

512514
class AArch64DAGToDAGISelLegacy : public SelectionDAGISelLegacy {
@@ -7489,3 +7491,52 @@ bool AArch64DAGToDAGISel::SelectSMETileSlice(SDValue N, unsigned MaxSize,
74897491
Offset = CurDAG->getTargetConstant(0, SDLoc(N), MVT::i64);
74907492
return true;
74917493
}
7494+
7495+
// Complex-pattern selector for the immediate operand of a compare-and-branch
// node. P is the parent node, whose operand 1 holds the AArch64 condition
// code, and N is the candidate operand. On success the constant is returned
// in Imm as a target constant of N's type and the function returns true.
//
// The accepted range is deliberately narrower than the encodable [0, 64):
// immediate forms of GE and HS are always decremented when the pseudo is
// lowered, so an immediate of 1 would become 0. The inverse conditions LT and
// LO do not need a decrement themselves, but reversing the branch condition
// later would turn them into GE/HS and require one after all. The same holds
// for LE, LS, GT and HI with incremented immediates. This costs a little
// codegen quality, e.g. the legal
//   cblt w0, #63, A
// is never produced because reversing the branch would yield the illegal
//   cbge w0, #64, B
// The lower-bound restriction is harmless since comparisons against 0 can be
// expressed with tbz/tbnz or by using wzr/xzr.
bool AArch64DAGToDAGISel::SelectCmpBranchUImm6Operand(SDNode *P, SDValue N,
                                                      SDValue &Imm) {
  const auto Cond =
      static_cast<AArch64CC::CondCode>(P->getConstantOperandVal(1));

  auto *Const = dyn_cast<ConstantSDNode>(N);
  if (!Const)
    return false;

  // Half-open interval [Lo, Hi) of immediates accepted for this condition.
  uint64_t Lo = 0;
  uint64_t Hi = 64;
  switch (Cond) {
  case AArch64CC::GE:
  case AArch64CC::HS:
  case AArch64CC::LT:
  case AArch64CC::LO:
    // May be decremented during lowering.
    Lo = 1;
    break;
  case AArch64CC::LE:
  case AArch64CC::LS:
  case AArch64CC::GT:
  case AArch64CC::HI:
    // May be incremented during lowering.
    Hi = 63;
    break;
  default:
    break;
  }

  const APInt &Value = Const->getAPIntValue();
  if (Value.ult(Lo) || Value.uge(Hi))
    return false;

  Imm = CurDAG->getTargetConstant(Const->getZExtValue(), SDLoc(N),
                                  N.getValueType());
  return true;
}

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2993,6 +2993,7 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
29932993
MAKE_CASE(AArch64ISD::CTTZ_ELTS)
29942994
MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
29952995
MAKE_CASE(AArch64ISD::URSHR_I_PRED)
2996+
MAKE_CASE(AArch64ISD::CB)
29962997
}
29972998
#undef MAKE_CASE
29982999
return nullptr;
@@ -10603,6 +10604,17 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const {
1060310604
DAG.getConstant(SignBitPos, dl, MVT::i64), Dest);
1060410605
}
1060510606

10607+
// Try to emit Armv9.6 CB instructions. We prefer tb{n}z/cb{n}z due to their
10608+
// larger branch displacement but do prefer CB over cmp + br.
10609+
if (Subtarget->hasCMPBR() &&
10610+
AArch64CC::isValidCBCond(changeIntCCToAArch64CC(CC)) &&
10611+
ProduceNonFlagSettingCondBr) {
10612+
SDValue Cond =
10613+
DAG.getTargetConstant(changeIntCCToAArch64CC(CC), dl, MVT::i32);
10614+
return DAG.getNode(AArch64ISD::CB, dl, MVT::Other, Chain, Cond, LHS, RHS,
10615+
Dest);
10616+
}
10617+
1060610618
SDValue CCVal;
1060710619
SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
1060810620
return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,9 @@ enum NodeType : unsigned {
529529
// SME ZA loads and stores
530530
SME_ZA_LDR,
531531
SME_ZA_STR,
532+
533+
// Compare-and-branch
534+
CB,
532535
};
533536

534537
} // end namespace AArch64ISD

llvm/lib/Target/AArch64/AArch64InstrFormats.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,16 @@ def uimm6_32b : Operand<i32>, ImmLeaf<i32, [{ return Imm >= 0 && Imm < 64; }]> {
400400
let ParserMatchClass = UImm6Operand;
401401
}
402402

403+
// i32 complex pattern for the immediate operand of a compare-and-branch,
// matched on `imm` nodes by the C++ selector SelectCmpBranchUImm6Operand.
// WantsParent is set so the selector is also handed the parent node; the
// selector reads the condition code from that parent to decide which
// immediates to accept.
def CmpBranchUImm6Operand_32b
404+
: ComplexPattern<i32, 1, "SelectCmpBranchUImm6Operand", [imm]> {
405+
let WantsParent = true;
406+
}
407+
408+
// i64 complex pattern for the immediate operand of a compare-and-branch,
// matched on `imm` nodes by the C++ selector SelectCmpBranchUImm6Operand.
// WantsParent is set so the selector is also handed the parent node; the
// selector reads the condition code from that parent to decide which
// immediates to accept.
def CmpBranchUImm6Operand_64b
409+
: ComplexPattern<i64, 1, "SelectCmpBranchUImm6Operand", [imm]> {
410+
let WantsParent = true;
411+
}
412+
403413
def UImm6Plus1Operand : AsmOperandClass {
404414
let Name = "UImm6P1";
405415
let DiagnosticType = "InvalidImm1_64";
@@ -13225,6 +13235,21 @@ multiclass CmpBranchRegisterAlias<string mnemonic, string insn> {
1322513235
def : InstAlias<mnemonic # "\t$Rt, $Rm, $target",
1322613236
(!cast<Instruction>(insn # "Xrr") GPR64:$Rm, GPR64:$Rt, am_brcmpcond:$target), 0>;
1322713237
}
13238+
13239+
// Register-register compare-and-branch pseudo instruction.
// It has no outputs; its inputs are a condition code ($Cond), two registers of
// class `regtype` ($Rt, $Rm) and a branch target ($Target). The pattern list
// is empty, so no selection pattern is attached by this class itself.
// Marked as a branch and a terminator and scheduled as WriteBr.
class CmpBranchRegisterPseudo<RegisterClass regtype>
13240+
: Pseudo<(outs), (ins ccode:$Cond, regtype:$Rt, regtype:$Rm, am_brcmpcond:$Target), []>,
13241+
Sched<[WriteBr]> {
13242+
let isBranch = 1;
13243+
let isTerminator = 1;
13244+
}
13245+
13246+
// Register-immediate compare-and-branch pseudo instruction.
// It has no outputs; its inputs are a condition code ($Cond), a register of
// class `regtype` ($Rt), an immediate of kind `imtype` ($Imm) and a branch
// target ($Target). The pattern list is empty, so no selection pattern is
// attached by this class itself.
// Marked as a branch and a terminator and scheduled as WriteBr.
class CmpBranchImmediatePseudo<RegisterClass regtype, ImmLeaf imtype>
13247+
: Pseudo<(outs), (ins ccode:$Cond, regtype:$Rt, imtype:$Imm, am_brcmpcond:$Target), []>,
13248+
Sched<[WriteBr]> {
13249+
let isBranch = 1;
13250+
let isTerminator = 1;
13251+
}
13252+
1322813253
//----------------------------------------------------------------------------
1322913254
// Allow the size specifier tokens to be upper case, not just lower.
1323013255
def : TokenAlias<".4B", ".4b">; // Add dot product

0 commit comments

Comments
 (0)