Skip to content

Commit 55a4f0f

Browse files
committed
[AArch64][FEAT_CMPBR] Codegen for Armv9.6-a compare-and-branch
This patch adds codegen for all Armv9.6-a compare-and-branch instructions that operate on full w or x registers. The instruction variants operating on half-words (cbh) and bytes (cbb) are added in a subsequent patch. Since CB doesn't use the standard 4-bit Arm condition codes but a reduced set of conditions encoded in 3 bits, some conditions are expressed by modifying operands, namely by incrementing or decrementing immediate operands and by swapping register operands. To invert a CB instruction it is therefore not enough to just modify the condition code, which doesn't play particularly well with how the backend is currently organized. We therefore introduce a number of pseudos that operate on the standard 4-bit condition codes and lower them late during codegen.
1 parent e9a20f7 commit 55a4f0f

14 files changed

+2973
-2
lines changed

llvm/lib/Target/AArch64/AArch64AsmPrinter.cpp

Lines changed: 165 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ class AArch64AsmPrinter : public AsmPrinter {
208208
void emitAttributes(unsigned Flags, uint64_t PAuthABIPlatform,
209209
uint64_t PAuthABIVersion, AArch64TargetStreamer *TS);
210210

211+
// Emit expansion of Compare-and-branch pseudo instructions
212+
void emitCBPseudoExpansion(const MachineInstr *MI);
213+
211214
void EmitToStreamer(MCStreamer &S, const MCInst &Inst);
212215
void EmitToStreamer(const MCInst &Inst) {
213216
EmitToStreamer(*OutStreamer, Inst);
@@ -2589,6 +2592,160 @@ AArch64AsmPrinter::lowerBlockAddressConstant(const BlockAddress &BA) {
25892592
return BAE;
25902593
}
25912594

2595+
void AArch64AsmPrinter::emitCBPseudoExpansion(const MachineInstr *MI) {
2596+
bool IsImm = false;
2597+
bool Is32Bit = false;
2598+
2599+
switch (MI->getOpcode()) {
2600+
default:
2601+
llvm_unreachable("This is not a CB pseudo instruction");
2602+
case AArch64::CBWPrr:
2603+
IsImm = false;
2604+
Is32Bit = true;
2605+
break;
2606+
case AArch64::CBXPrr:
2607+
IsImm = false;
2608+
Is32Bit = false;
2609+
break;
2610+
case AArch64::CBWPri:
2611+
IsImm = true;
2612+
Is32Bit = true;
2613+
break;
2614+
case AArch64::CBXPri:
2615+
IsImm = true;
2616+
Is32Bit = false;
2617+
break;
2618+
}
2619+
2620+
AArch64CC::CondCode CC =
2621+
static_cast<AArch64CC::CondCode>(MI->getOperand(0).getImm());
2622+
bool NeedsRegSwap = false;
2623+
bool NeedsImmDec = false;
2624+
bool NeedsImmInc = false;
2625+
2626+
// Decide if we need to either swap register operands or increment/decrement
2627+
// immediate operands
2628+
unsigned MCOpC;
2629+
switch (CC) {
2630+
default:
2631+
llvm_unreachable("Invalid CB condition code");
2632+
case AArch64CC::EQ:
2633+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBEQWri : AArch64::CBEQXri)
2634+
: (Is32Bit ? AArch64::CBEQWrr : AArch64::CBEQXrr);
2635+
NeedsRegSwap = false;
2636+
NeedsImmDec = false;
2637+
NeedsImmInc = false;
2638+
break;
2639+
case AArch64CC::NE:
2640+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBNEWri : AArch64::CBNEXri)
2641+
: (Is32Bit ? AArch64::CBNEWrr : AArch64::CBNEXrr);
2642+
NeedsRegSwap = false;
2643+
NeedsImmDec = false;
2644+
NeedsImmInc = false;
2645+
break;
2646+
case AArch64CC::HS:
2647+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBHIWri : AArch64::CBHIXri)
2648+
: (Is32Bit ? AArch64::CBHSWrr : AArch64::CBHSXrr);
2649+
NeedsRegSwap = false;
2650+
NeedsImmDec = IsImm;
2651+
NeedsImmInc = false;
2652+
break;
2653+
case AArch64CC::LO:
2654+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBLOWri : AArch64::CBLOXri)
2655+
: (Is32Bit ? AArch64::CBHIWrr : AArch64::CBHIXrr);
2656+
NeedsRegSwap = !IsImm;
2657+
NeedsImmDec = false;
2658+
NeedsImmInc = false;
2659+
break;
2660+
case AArch64CC::HI:
2661+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBHIWri : AArch64::CBHIXri)
2662+
: (Is32Bit ? AArch64::CBHIWrr : AArch64::CBHIXrr);
2663+
NeedsRegSwap = false;
2664+
NeedsImmDec = false;
2665+
NeedsImmInc = false;
2666+
break;
2667+
case AArch64CC::LS:
2668+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBLOWri : AArch64::CBLOXri)
2669+
: (Is32Bit ? AArch64::CBHSWrr : AArch64::CBHSXrr);
2670+
NeedsRegSwap = !IsImm;
2671+
NeedsImmDec = false;
2672+
NeedsImmInc = IsImm;
2673+
break;
2674+
case AArch64CC::GE:
2675+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBGTWri : AArch64::CBGTXri)
2676+
: (Is32Bit ? AArch64::CBGEWrr : AArch64::CBGEXrr);
2677+
NeedsRegSwap = false;
2678+
NeedsImmDec = IsImm;
2679+
NeedsImmInc = false;
2680+
break;
2681+
case AArch64CC::LT:
2682+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBLTWri : AArch64::CBLTXri)
2683+
: (Is32Bit ? AArch64::CBGTWrr : AArch64::CBGTXrr);
2684+
NeedsRegSwap = !IsImm;
2685+
NeedsImmDec = false;
2686+
NeedsImmInc = false;
2687+
break;
2688+
case AArch64CC::GT:
2689+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBGTWri : AArch64::CBGTXri)
2690+
: (Is32Bit ? AArch64::CBGTWrr : AArch64::CBGTXrr);
2691+
NeedsRegSwap = false;
2692+
NeedsImmDec = false;
2693+
NeedsImmInc = false;
2694+
break;
2695+
case AArch64CC::LE:
2696+
MCOpC = IsImm ? (Is32Bit ? AArch64::CBLTWri : AArch64::CBLTXri)
2697+
: (Is32Bit ? AArch64::CBGEWrr : AArch64::CBGEXrr);
2698+
NeedsRegSwap = !IsImm;
2699+
NeedsImmDec = false;
2700+
NeedsImmInc = IsImm;
2701+
break;
2702+
}
2703+
2704+
assert(!(NeedsImmDec && NeedsImmInc) &&
2705+
"Cannot require increment and decrement of CB immediate operand at "
2706+
"the same time");
2707+
2708+
MCInst Inst;
2709+
Inst.setOpcode(MCOpC);
2710+
2711+
MCOperand Lhs, Rhs, Trgt;
2712+
lowerOperand(MI->getOperand(1), Lhs);
2713+
lowerOperand(MI->getOperand(2), Rhs);
2714+
lowerOperand(MI->getOperand(3), Trgt);
2715+
2716+
// Now swap, increment or decrement
2717+
if (NeedsRegSwap) {
2718+
assert(
2719+
!IsImm &&
2720+
"Unexpected register swap for CB instruction with immediate operand");
2721+
assert(Lhs.isReg() && "Expected register operand for CB");
2722+
assert(Rhs.isReg() && "Expected register operand for CB");
2723+
Inst.addOperand(Rhs);
2724+
Inst.addOperand(Lhs);
2725+
} else if (NeedsImmDec) {
2726+
assert(IsImm && "Unexpected immediate decrement for CB instruction with "
2727+
"reg-reg operands");
2728+
Rhs.setImm(Rhs.getImm() - 1);
2729+
Inst.addOperand(Lhs);
2730+
Inst.addOperand(Rhs);
2731+
} else if (NeedsImmInc) {
2732+
assert(IsImm && "Unexpected immediate increment for CB instruction with "
2733+
"reg-reg operands");
2734+
Rhs.setImm(Rhs.getImm() + 1);
2735+
Inst.addOperand(Lhs);
2736+
Inst.addOperand(Rhs);
2737+
} else {
2738+
Inst.addOperand(Lhs);
2739+
Inst.addOperand(Rhs);
2740+
}
2741+
2742+
assert((!IsImm || (Rhs.getImm() >= 0 && Rhs.getImm() < 64)) &&
2743+
"CB immediate operand out-of-bounds");
2744+
2745+
Inst.addOperand(Trgt);
2746+
EmitToStreamer(*OutStreamer, Inst);
2747+
}
2748+
25922749
// Simple pseudo-instructions have their lowering (with expansion to real
25932750
// instructions) auto-generated.
25942751
#include "AArch64GenMCPseudoLowering.inc"
@@ -3155,13 +3312,20 @@ void AArch64AsmPrinter::emitInstruction(const MachineInstr *MI) {
31553312
return;
31563313

31573314
case AArch64::BLR:
3158-
case AArch64::BR:
3315+
case AArch64::BR: {
31593316
recordIfImportCall(MI);
31603317
MCInst TmpInst;
31613318
MCInstLowering.Lower(MI, TmpInst);
31623319
EmitToStreamer(*OutStreamer, TmpInst);
31633320
return;
31643321
}
3322+
case AArch64::CBWPri:
3323+
case AArch64::CBXPri:
3324+
case AArch64::CBWPrr:
3325+
case AArch64::CBXPrr:
3326+
emitCBPseudoExpansion(MI);
3327+
return;
3328+
}
31653329

31663330
// Finally, do the automated lowerings for everything else.
31673331
MCInst TmpInst;

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2983,6 +2983,8 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
29832983
MAKE_CASE(AArch64ISD::CTTZ_ELTS)
29842984
MAKE_CASE(AArch64ISD::CALL_ARM64EC_TO_X64)
29852985
MAKE_CASE(AArch64ISD::URSHR_I_PRED)
2986+
MAKE_CASE(AArch64ISD::CBRR)
2987+
MAKE_CASE(AArch64ISD::CBRI)
29862988
}
29872989
#undef MAKE_CASE
29882990
return nullptr;
@@ -10593,6 +10595,56 @@ SDValue AArch64TargetLowering::LowerBR_CC(SDValue Op, SelectionDAG &DAG) const {
1059310595
DAG.getConstant(SignBitPos, dl, MVT::i64), Dest);
1059410596
}
1059510597

10598+
// Try to emit Armv9.6 CB instructions. We prefer tb{n}z/cb{n}z due to their
10599+
// larger branch displacement but do prefer CB over cmp + br.
10600+
if (Subtarget->hasCMPBR() &&
10601+
AArch64CC::isValidCBCond(changeIntCCToAArch64CC(CC)) &&
10602+
ProduceNonFlagSettingCondBr) {
10603+
AArch64CC::CondCode ACC = changeIntCCToAArch64CC(CC);
10604+
unsigned Opc = AArch64ISD::CBRR;
10605+
if (auto *Imm = dyn_cast<ConstantSDNode>(RHS)) {
10606+
// Check conservatively if the immediate fits the valid range [0, 64).
10607+
// Immediate variants for GE and HS definitely need to be decremented
10608+
// when lowering the pseudos later, so an immediate of 1 would become 0.
10609+
// For the inverse conditions LT and LO we don't know for sure if they
10610+
// will need a decrement but should the decision be made to reverse the
10611+
// branch condition, we again end up with the need to decrement.
10612+
// The same argument holds for LE, LS, GT and HI and possibly
10613+
// incremented immediates. This can lead to slightly less optimal
10614+
// codegen, e.g. we never codegen the legal case
10615+
// cblt w0, #63, A
10616+
// because we could end up with the illegal case
10617+
// cbge w0, #64, B
10618+
// should the decision to reverse the branch direction be made. For the
10619+
// lower bound cases this is no problem since we can express comparisons
10620+
// against 0 with either tbz/tbnz or using wzr/xzr.
10621+
uint64_t LowerBound = 0, UpperBound = 64;
10622+
switch (ACC) {
10623+
case AArch64CC::GE:
10624+
case AArch64CC::HS:
10625+
case AArch64CC::LT:
10626+
case AArch64CC::LO:
10627+
LowerBound = 1;
10628+
break;
10629+
case AArch64CC::LE:
10630+
case AArch64CC::LS:
10631+
case AArch64CC::GT:
10632+
case AArch64CC::HI:
10633+
UpperBound = 63;
10634+
break;
10635+
default:
10636+
break;
10637+
}
10638+
10639+
if (Imm->getAPIntValue().uge(LowerBound) &&
10640+
Imm->getAPIntValue().ult(UpperBound))
10641+
Opc = AArch64ISD::CBRI;
10642+
}
10643+
10644+
SDValue Cond = DAG.getTargetConstant(ACC, dl, MVT::i32);
10645+
return DAG.getNode(Opc, dl, MVT::Other, Chain, Cond, LHS, RHS, Dest);
10646+
}
10647+
1059610648
SDValue CCVal;
1059710649
SDValue Cmp = getAArch64Cmp(LHS, RHS, CC, CCVal, DAG, dl);
1059810650
return DAG.getNode(AArch64ISD::BRCOND, dl, MVT::Other, Chain, Dest, CCVal,

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,10 @@ enum NodeType : unsigned {
529529
// SME ZA loads and stores
530530
SME_ZA_LDR,
531531
SME_ZA_STR,
532+
533+
// Compare-and-branch
534+
CBRR,
535+
CBRI,
532536
};
533537

534538
} // end namespace AArch64ISD

llvm/lib/Target/AArch64/AArch64InstrFormats.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13232,6 +13232,21 @@ multiclass CmpBranchRegisterAlias<string mnemonic, string insn> {
1323213232
def : InstAlias<mnemonic # "\t$Rt, $Rm, $target",
1323313233
(!cast<Instruction>(insn # "Xrr") GPR64:$Rm, GPR64:$Rt, am_brcmpcond:$target), 0>;
1323413234
}
13235+
13236+
// Pseudo for a register-register compare-and-branch. It has no outputs and
// takes a condition code ($Cond), two registers of class `regtype` ($Rt, $Rm)
// and a branch target ($Target). It carries no selection pattern, is marked as
// a branch and a terminator, and uses the WriteBr scheduling class.
class CmpBranchRegisterPseudo<RegisterClass regtype>
13237+
: Pseudo<(outs), (ins ccode:$Cond, regtype:$Rt, regtype:$Rm, am_brcmpcond:$Target), []>,
13238+
Sched<[WriteBr]> {
13239+
let isBranch = 1;
13240+
let isTerminator = 1;
13241+
}
13242+
13243+
// Pseudo for a register-immediate compare-and-branch. It has no outputs and
// takes a condition code ($Cond), a register of class `regtype` ($Rt), an
// immediate of type `imtype` ($Imm) and a branch target ($Target). It carries
// no selection pattern, is marked as a branch and a terminator, and uses the
// WriteBr scheduling class.
class CmpBranchImmediatePseudo<RegisterClass regtype, ImmLeaf imtype>
13244+
: Pseudo<(outs), (ins ccode:$Cond, regtype:$Rt, imtype:$Imm, am_brcmpcond:$Target), []>,
13245+
Sched<[WriteBr]> {
13246+
let isBranch = 1;
13247+
let isTerminator = 1;
13248+
}
13249+
1323513250
//----------------------------------------------------------------------------
1323613251
// Allow the size specifier tokens to be upper case, not just lower.
1323713252
def : TokenAlias<".4B", ".4b">; // Add dot product

0 commit comments

Comments
 (0)