Skip to content

[LLVM][AArch64] Refactor lowering of fixed length integer setcc operations. #132434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 42 additions & 91 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2057,6 +2057,15 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
setOperationAction(ISD::READ_REGISTER, MVT::i128, Custom);
setOperationAction(ISD::WRITE_REGISTER, MVT::i128, Custom);
}

if (VT.isInteger()) {
// Let common code emit inverted variants of compares we do support.
setCondCodeAction(ISD::SETNE, VT, Expand);
setCondCodeAction(ISD::SETLE, VT, Expand);
setCondCodeAction(ISD::SETLT, VT, Expand);
setCondCodeAction(ISD::SETULE, VT, Expand);
setCondCodeAction(ISD::SETULT, VT, Expand);
}
}

bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
Expand Down Expand Up @@ -2581,31 +2590,21 @@ unsigned AArch64TargetLowering::ComputeNumSignBitsForTargetNode(
unsigned VTBits = VT.getScalarSizeInBits();
unsigned Opcode = Op.getOpcode();
switch (Opcode) {
case AArch64ISD::CMEQ:
case AArch64ISD::CMGE:
case AArch64ISD::CMGT:
case AArch64ISD::CMHI:
case AArch64ISD::CMHS:
case AArch64ISD::FCMEQ:
case AArch64ISD::FCMGE:
case AArch64ISD::FCMGT:
case AArch64ISD::CMEQz:
case AArch64ISD::CMGEz:
case AArch64ISD::CMGTz:
case AArch64ISD::CMLEz:
case AArch64ISD::CMLTz:
case AArch64ISD::FCMEQz:
case AArch64ISD::FCMGEz:
case AArch64ISD::FCMGTz:
case AArch64ISD::FCMLEz:
case AArch64ISD::FCMLTz:
// Compares return either 0 or all-ones
return VTBits;
case AArch64ISD::VASHR: {
unsigned Tmp =
DAG.ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
return std::min<uint64_t>(Tmp + Op.getConstantOperandVal(1), VTBits);
}
case AArch64ISD::FCMEQ:
case AArch64ISD::FCMGE:
case AArch64ISD::FCMGT:
case AArch64ISD::FCMEQz:
case AArch64ISD::FCMGEz:
case AArch64ISD::FCMGTz:
case AArch64ISD::FCMLEz:
case AArch64ISD::FCMLTz:
// Compares return either 0 or all-ones
return VTBits;
case AArch64ISD::VASHR: {
unsigned Tmp =
DAG.ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
return std::min<uint64_t>(Tmp + Op.getConstantOperandVal(1), VTBits);
}
}

return 1;
Expand Down Expand Up @@ -2812,19 +2811,9 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
MAKE_CASE(AArch64ISD::VASHR)
MAKE_CASE(AArch64ISD::VSLI)
MAKE_CASE(AArch64ISD::VSRI)
MAKE_CASE(AArch64ISD::CMEQ)
MAKE_CASE(AArch64ISD::CMGE)
MAKE_CASE(AArch64ISD::CMGT)
MAKE_CASE(AArch64ISD::CMHI)
MAKE_CASE(AArch64ISD::CMHS)
MAKE_CASE(AArch64ISD::FCMEQ)
MAKE_CASE(AArch64ISD::FCMGE)
MAKE_CASE(AArch64ISD::FCMGT)
MAKE_CASE(AArch64ISD::CMEQz)
MAKE_CASE(AArch64ISD::CMGEz)
MAKE_CASE(AArch64ISD::CMGTz)
MAKE_CASE(AArch64ISD::CMLEz)
MAKE_CASE(AArch64ISD::CMLTz)
MAKE_CASE(AArch64ISD::FCMEQz)
MAKE_CASE(AArch64ISD::FCMGEz)
MAKE_CASE(AArch64ISD::FCMGTz)
Expand Down Expand Up @@ -15814,9 +15803,6 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
SplatBitSize, HasAnyUndefs);

bool IsZero = IsCnst && SplatValue == 0;
bool IsOne =
IsCnst && SrcVT.getScalarSizeInBits() == SplatBitSize && SplatValue == 1;
bool IsMinusOne = IsCnst && SplatValue.isAllOnes();

if (SrcVT.getVectorElementType().isFloatingPoint()) {
switch (CC) {
Expand Down Expand Up @@ -15863,50 +15849,7 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
}
}

switch (CC) {
default:
return SDValue();
case AArch64CC::NE: {
SDValue Cmeq;
if (IsZero)
Cmeq = DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
else
Cmeq = DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
return DAG.getNOT(dl, Cmeq, VT);
}
case AArch64CC::EQ:
if (IsZero)
return DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
return DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
case AArch64CC::GE:
if (IsZero)
return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS);
return DAG.getNode(AArch64ISD::CMGE, dl, VT, LHS, RHS);
case AArch64CC::GT:
if (IsZero)
return DAG.getNode(AArch64ISD::CMGTz, dl, VT, LHS);
if (IsMinusOne)
return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS);
return DAG.getNode(AArch64ISD::CMGT, dl, VT, LHS, RHS);
case AArch64CC::LE:
if (IsZero)
return DAG.getNode(AArch64ISD::CMLEz, dl, VT, LHS);
return DAG.getNode(AArch64ISD::CMGE, dl, VT, RHS, LHS);
case AArch64CC::LS:
return DAG.getNode(AArch64ISD::CMHS, dl, VT, RHS, LHS);
case AArch64CC::LO:
return DAG.getNode(AArch64ISD::CMHI, dl, VT, RHS, LHS);
case AArch64CC::LT:
if (IsZero)
return DAG.getNode(AArch64ISD::CMLTz, dl, VT, LHS);
if (IsOne)
return DAG.getNode(AArch64ISD::CMLEz, dl, VT, LHS);
return DAG.getNode(AArch64ISD::CMGT, dl, VT, RHS, LHS);
case AArch64CC::HI:
return DAG.getNode(AArch64ISD::CMHI, dl, VT, LHS, RHS);
case AArch64CC::HS:
return DAG.getNode(AArch64ISD::CMHS, dl, VT, LHS, RHS);
}
return SDValue();
}

SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
Expand All @@ -15924,13 +15867,8 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
EVT CmpVT = LHS.getValueType().changeVectorElementTypeToInteger();
SDLoc dl(Op);

if (LHS.getValueType().getVectorElementType().isInteger()) {
assert(LHS.getValueType() == RHS.getValueType());
AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC);
SDValue Cmp =
EmitVectorComparison(LHS, RHS, AArch64CC, false, CmpVT, dl, DAG);
return DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());
}
if (LHS.getValueType().getVectorElementType().isInteger())
return Op;

// Lower isnan(x) | isnan(never-nan) to x != x.
// Lower !isnan(x) & !isnan(never-nan) to x == x.
Expand Down Expand Up @@ -18128,7 +18066,9 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
if (!ShiftAmt || ShiftAmt->getZExtValue() != ShiftEltTy.getSizeInBits() - 1)
return SDValue();

return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0));
SDLoc DL(N);
SDValue Zero = DAG.getConstant(0, DL, Shift.getValueType());
return DAG.getSetCC(DL, VT, Shift.getOperand(0), Zero, ISD::SETGE);
}

// Given a vecreduce_add node, detect the below pattern and convert it to the
Expand Down Expand Up @@ -18739,7 +18679,8 @@ static SDValue performMulVectorCmpZeroCombine(SDNode *N, SelectionDAG &DAG) {

SDLoc DL(N);
SDValue In = DAG.getNode(AArch64ISD::NVCAST, DL, HalfVT, Srl.getOperand(0));
SDValue CM = DAG.getNode(AArch64ISD::CMLTz, DL, HalfVT, In);
SDValue Zero = DAG.getConstant(0, DL, In.getValueType());
SDValue CM = DAG.getSetCC(DL, HalfVT, Zero, In, ISD::SETGT);
return DAG.getNode(AArch64ISD::NVCAST, DL, VT, CM);
}

Expand Down Expand Up @@ -25268,6 +25209,16 @@ static SDValue performSETCCCombine(SDNode *N,
if (SDValue V = performOrXorChainCombine(N, DAG))
return V;

EVT CmpVT = LHS.getValueType();

// NOTE: This exists as a combine only because it proved too awkward to match
// splat(1) across all the NEON types during isel.
APInt SplatLHSVal;
if (CmpVT.isInteger() && Cond == ISD::SETGT &&
ISD::isConstantSplatVector(LHS.getNode(), SplatLHSVal) &&
SplatLHSVal.isOne())
return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, CmpVT), RHS, ISD::SETGE);

return SDValue();
}

Expand Down
10 changes: 0 additions & 10 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,21 +241,11 @@ enum NodeType : unsigned {
VSRI,

// Vector comparisons
CMEQ,
CMGE,
CMGT,
CMHI,
CMHS,
FCMEQ,
FCMGE,
FCMGT,

// Vector zero comparisons
CMEQz,
CMGEz,
CMGTz,
CMLEz,
CMLTz,
FCMEQz,
FCMGEz,
FCMGTz,
Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/AArch64/AArch64InstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -7086,7 +7086,7 @@ multiclass SIMD_FP8_CVTL<bits<2>sz, string asm, ValueType dty, SDPatternOperator
class BaseSIMDCmpTwoVector<bit Q, bit U, bits<2> size, bits<2> size2,
bits<5> opcode, RegisterOperand regtype, string asm,
string kind, string zero, ValueType dty,
ValueType sty, SDNode OpNode>
ValueType sty, SDPatternOperator OpNode>
: I<(outs regtype:$Rd), (ins regtype:$Rn), asm,
"{\t$Rd" # kind # ", $Rn" # kind # ", #" # zero #
"|" # kind # "\t$Rd, $Rn, #" # zero # "}", "",
Expand All @@ -7110,7 +7110,7 @@ class BaseSIMDCmpTwoVector<bit Q, bit U, bits<2> size, bits<2> size2,

// Comparisons support all element sizes, except 1xD.
multiclass SIMDCmpTwoVector<bit U, bits<5> opc, string asm,
SDNode OpNode> {
SDPatternOperator OpNode> {
def v8i8rz : BaseSIMDCmpTwoVector<0, U, 0b00, 0b00, opc, V64,
asm, ".8b", "0",
v8i8, v8i8, OpNode>;
Expand Down Expand Up @@ -7981,7 +7981,7 @@ multiclass SIMDCmpTwoScalarD<bit U, bits<5> opc, string asm,
SDPatternOperator OpNode> {
def v1i64rz : BaseSIMDCmpTwoScalar<U, 0b11, 0b00, opc, FPR64, asm, "0">;

def : Pat<(v1i64 (OpNode FPR64:$Rn)),
def : Pat<(v1i64 (OpNode v1i64:$Rn)),
(!cast<Instruction>(NAME # v1i64rz) FPR64:$Rn)>;
}

Expand Down
36 changes: 24 additions & 12 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -846,23 +846,35 @@ def AArch64vsri : SDNode<"AArch64ISD::VSRI", SDT_AArch64vshiftinsert>;

def AArch64bsp: SDNode<"AArch64ISD::BSP", SDT_AArch64trivec>;

def AArch64cmeq: SDNode<"AArch64ISD::CMEQ", SDT_AArch64binvec>;
def AArch64cmge: SDNode<"AArch64ISD::CMGE", SDT_AArch64binvec>;
def AArch64cmgt: SDNode<"AArch64ISD::CMGT", SDT_AArch64binvec>;
def AArch64cmhi: SDNode<"AArch64ISD::CMHI", SDT_AArch64binvec>;
def AArch64cmhs: SDNode<"AArch64ISD::CMHS", SDT_AArch64binvec>;
// Two-operand integer compare fragments, each matching the generic `setcc`
// node with one fixed condition code. They are passed as the pattern operator
// to the NEON compare instruction definitions (e.g. CMHI/CMHS).
//
// Only EQ, signed GE/GT and unsigned GT/GE have a fragment. The lowering code
// (addTypeForNEON) marks SETNE/SETLE/SETLT/SETULE/SETULT as Expand for integer
// types so that common code emits them as inverted variants of these compares.

// lhs == rhs
def AArch64cmeq : PatFrag<(ops node:$lhs, node:$rhs),
(setcc node:$lhs, node:$rhs, SETEQ)>;
// lhs >= rhs (signed)
def AArch64cmge : PatFrag<(ops node:$lhs, node:$rhs),
(setcc node:$lhs, node:$rhs, SETGE)>;
// lhs > rhs (signed)
def AArch64cmgt : PatFrag<(ops node:$lhs, node:$rhs),
(setcc node:$lhs, node:$rhs, SETGT)>;
// lhs > rhs (unsigned, "higher")
def AArch64cmhi : PatFrag<(ops node:$lhs, node:$rhs),
(setcc node:$lhs, node:$rhs, SETUGT)>;
// lhs >= rhs (unsigned, "higher or same")
def AArch64cmhs : PatFrag<(ops node:$lhs, node:$rhs),
(setcc node:$lhs, node:$rhs, SETUGE)>;

def AArch64fcmeq: SDNode<"AArch64ISD::FCMEQ", SDT_AArch64fcmp>;
def AArch64fcmge: SDNode<"AArch64ISD::FCMGE", SDT_AArch64fcmp>;
def AArch64fcmgt: SDNode<"AArch64ISD::FCMGT", SDT_AArch64fcmp>;

def AArch64cmeqz: SDNode<"AArch64ISD::CMEQz", SDT_AArch64unvec>;
def AArch64cmgez: SDNode<"AArch64ISD::CMGEz", SDT_AArch64unvec>;
def AArch64cmgtz: SDNode<"AArch64ISD::CMGTz", SDT_AArch64unvec>;
def AArch64cmlez: SDNode<"AArch64ISD::CMLEz", SDT_AArch64unvec>;
def AArch64cmltz: SDNode<"AArch64ISD::CMLTz", SDT_AArch64unvec>;
// Single-operand "compare against zero" fragments. Each matches a `setcc`
// whose other operand is a constant splat (immAllZerosV or immAllOnesV), so
// only the non-constant operand is exposed as $lhs.

// lhs == 0
def AArch64cmeqz : PatFrag<(ops node:$lhs),
(setcc node:$lhs, immAllZerosV, SETEQ)>;
// lhs >= 0 (signed). Two equivalent spellings are accepted:
// `lhs >= 0` and `lhs > -1` (all-ones splat).
def AArch64cmgez : PatFrags<(ops node:$lhs),
[(setcc node:$lhs, immAllZerosV, SETGE),
(setcc node:$lhs, immAllOnesV, SETGT)]>;
// lhs > 0 (signed)
def AArch64cmgtz : PatFrag<(ops node:$lhs),
(setcc node:$lhs, immAllZerosV, SETGT)>;
// lhs <= 0 (signed), matched in its operand-swapped form `0 >= lhs`.
// NOTE(review): presumably swapped because SETLE is not matched directly for
// these types -- confirm against the cond-code actions in the lowering code.
def AArch64cmlez : PatFrag<(ops node:$lhs),
(setcc immAllZerosV, node:$lhs, SETGE)>;
// lhs < 0 (signed), matched in its operand-swapped form `0 > lhs`.
def AArch64cmltz : PatFrag<(ops node:$lhs),
(setcc immAllZerosV, node:$lhs, SETGT)>;

def AArch64cmtst : PatFrag<(ops node:$LHS, node:$RHS),
(vnot (AArch64cmeqz (and node:$LHS, node:$RHS)))>;
(vnot (AArch64cmeqz (and node:$LHS, node:$RHS)))>;

def AArch64fcmeqz: SDNode<"AArch64ISD::FCMEQz", SDT_AArch64fcmpz>;
def AArch64fcmgez: SDNode<"AArch64ISD::FCMGEz", SDT_AArch64fcmpz>;
Expand Down Expand Up @@ -5671,7 +5683,7 @@ defm CMHI : SIMDThreeSameVector<1, 0b00110, "cmhi", AArch64cmhi>;
defm CMHS : SIMDThreeSameVector<1, 0b00111, "cmhs", AArch64cmhs>;
defm CMTST : SIMDThreeSameVector<0, 0b10001, "cmtst", AArch64cmtst>;
foreach VT = [ v8i8, v16i8, v4i16, v8i16, v2i32, v4i32, v2i64 ] in {
def : Pat<(vnot (AArch64cmeqz VT:$Rn)), (!cast<Instruction>("CMTST"#VT) VT:$Rn, VT:$Rn)>;
def : Pat<(VT (vnot (AArch64cmeqz VT:$Rn))), (!cast<Instruction>("CMTST"#VT) VT:$Rn, VT:$Rn)>;
}
defm FABD : SIMDThreeSameVectorFP<1,1,0b010,"fabd", int_aarch64_neon_fabd>;
let Predicates = [HasNEON] in {
Expand Down
15 changes: 7 additions & 8 deletions llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
Original file line number Diff line number Diff line change
Expand Up @@ -352,17 +352,16 @@ define void @typei1_orig(i64 %a, ptr %p, ptr %q) {
;
; CHECK-GI-LABEL: typei1_orig:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: ldr q1, [x2]
; CHECK-GI-NEXT: ldr q0, [x2]
; CHECK-GI-NEXT: cmp x0, #0
; CHECK-GI-NEXT: movi v0.2d, #0xffffffffffffffff
; CHECK-GI-NEXT: cset w8, gt
; CHECK-GI-NEXT: neg v1.8h, v1.8h
; CHECK-GI-NEXT: dup v2.8h, w8
; CHECK-GI-NEXT: mvn v0.16b, v0.16b
; CHECK-GI-NEXT: mul v1.8h, v1.8h, v2.8h
; CHECK-GI-NEXT: cmeq v1.8h, v1.8h, #0
; CHECK-GI-NEXT: neg v0.8h, v0.8h
; CHECK-GI-NEXT: dup v1.8h, w8
; CHECK-GI-NEXT: mul v0.8h, v0.8h, v1.8h
; CHECK-GI-NEXT: movi v1.2d, #0xffffffffffffffff
; CHECK-GI-NEXT: cmtst v0.8h, v0.8h, v0.8h
; CHECK-GI-NEXT: mvn v1.16b, v1.16b
; CHECK-GI-NEXT: uzp1 v0.16b, v1.16b, v0.16b
; CHECK-GI-NEXT: uzp1 v0.16b, v0.16b, v1.16b
; CHECK-GI-NEXT: shl v0.16b, v0.16b, #7
; CHECK-GI-NEXT: sshr v0.16b, v0.16b, #7
; CHECK-GI-NEXT: str q0, [x1]
Expand Down
6 changes: 3 additions & 3 deletions llvm/test/CodeGen/AArch64/fptosi-sat-vector.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2382,11 +2382,11 @@ define <2 x i1> @test_signed_v2f64_v2i1(<2 x double> %f) {
; CHECK-GI-LABEL: test_signed_v2f64_v2i1:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: fcvtzs v0.2d, v0.2d
; CHECK-GI-NEXT: movi v2.2d, #0xffffffffffffffff
; CHECK-GI-NEXT: cmlt v1.2d, v0.2d, #0
; CHECK-GI-NEXT: and v0.16b, v0.16b, v1.16b
; CHECK-GI-NEXT: cmgt v1.2d, v0.2d, v2.2d
; CHECK-GI-NEXT: bif v0.16b, v2.16b, v1.16b
; CHECK-GI-NEXT: movi v1.2d, #0xffffffffffffffff
; CHECK-GI-NEXT: cmge v2.2d, v0.2d, #0
; CHECK-GI-NEXT: bif v0.16b, v1.16b, v2.16b
; CHECK-GI-NEXT: xtn v0.2s, v0.2d
; CHECK-GI-NEXT: ret
%x = call <2 x i1> @llvm.fptosi.sat.v2f64.v2i1(<2 x double> %f)
Expand Down
18 changes: 5 additions & 13 deletions llvm/test/CodeGen/AArch64/neon-bitwise-instructions.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1499,8 +1499,7 @@ define <8 x i8> @vselect_cmpz_ne(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c) {
;
; CHECK-GI-LABEL: vselect_cmpz_ne:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: cmeq v0.8b, v0.8b, #0
; CHECK-GI-NEXT: mvn v0.8b, v0.8b
; CHECK-GI-NEXT: cmtst v0.8b, v0.8b, v0.8b
; CHECK-GI-NEXT: bsl v0.8b, v1.8b, v2.8b
; CHECK-GI-NEXT: ret
%cmp = icmp ne <8 x i8> %a, zeroinitializer
Expand Down Expand Up @@ -1533,17 +1532,10 @@ define <8 x i8> @vselect_tst(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c) {
}

define <8 x i8> @sext_tst(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c) {
; CHECK-SD-LABEL: sext_tst:
; CHECK-SD: // %bb.0:
; CHECK-SD-NEXT: cmtst v0.8b, v0.8b, v1.8b
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: sext_tst:
; CHECK-GI: // %bb.0:
; CHECK-GI-NEXT: and v0.8b, v0.8b, v1.8b
; CHECK-GI-NEXT: cmeq v0.8b, v0.8b, #0
; CHECK-GI-NEXT: mvn v0.8b, v0.8b
; CHECK-GI-NEXT: ret
; CHECK-LABEL: sext_tst:
; CHECK: // %bb.0:
; CHECK-NEXT: cmtst v0.8b, v0.8b, v1.8b
; CHECK-NEXT: ret
%tmp3 = and <8 x i8> %a, %b
%tmp4 = icmp ne <8 x i8> %tmp3, zeroinitializer
%d = sext <8 x i1> %tmp4 to <8 x i8>
Expand Down
Loading