Skip to content

[LLVM][AArch64] Refactor lowering of fixed length integer setcc operations. #132434

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 4, 2025

Conversation

paulwalker-arm
Copy link
Collaborator

The original code is essentially performing isel during legalisation with the AArch64 specific nodes offering no additional value compared to ISD::SETCC.

Whilst not the motivating case, the effect of removing the indirection means global-isel no longer misses out on the custom handling.

If agreeable I hope to follow this with matching refactoring of the floating point based setcc operations.

…tions.

The original code is essentially performing isel during legalisation
with the AArch64 specific nodes offering no additional value compared
to ISD::SETCC.

Whilst not the motivating case, the effect of removing the indirection
means global-isel no longer misses out on the custom handling.

If agreeable I hope to follow this with matching refactoring of the
floating point based setcc operations.
@llvmbot
Copy link
Member

llvmbot commented Mar 21, 2025

@llvm/pr-subscribers-backend-aarch64

Author: Paul Walker (paulwalker-arm)

Changes

The original code is essentially performing isel during legalisation with the AArch64 specific nodes offering no additional value compared to ISD::SETCC.

Whilst not the motivating case, the effect of removing the indirection means global-isel no longer misses out on the custom handling.

If agreeable I hope to follow this with matching refactoring of the floating point based setcc operations.


Patch is 36.61 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/132434.diff

12 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+43-87)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (-10)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrFormats.td (+3-3)
  • (modified) llvm/lib/Target/AArch64/AArch64InstrInfo.td (+24-12)
  • (modified) llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll (+7-8)
  • (modified) llvm/test/CodeGen/AArch64/fptosi-sat-vector.ll (+3-3)
  • (modified) llvm/test/CodeGen/AArch64/neon-bitwise-instructions.ll (+5-13)
  • (modified) llvm/test/CodeGen/AArch64/neon-compare-instructions.ll (+84-217)
  • (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-extract-subvector.ll (+2-3)
  • (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-masked-gather.ll (+2-2)
  • (modified) llvm/test/CodeGen/AArch64/sve-fixed-length-masked-scatter.ll (+1-2)
  • (modified) llvm/test/CodeGen/AArch64/vec-combine-compare-to-bitmask.ll (+12-27)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0db6c614684d7..c8922c4a1d5a5 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2057,6 +2057,15 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
     setOperationAction(ISD::READ_REGISTER, MVT::i128, Custom);
     setOperationAction(ISD::WRITE_REGISTER, MVT::i128, Custom);
   }
+
+  if (VT.isInteger()) {
+    // Let common code emit inverted variants of compares we do support.
+    setCondCodeAction(ISD::SETNE, VT, Expand);
+    setCondCodeAction(ISD::SETLE, VT, Expand);
+    setCondCodeAction(ISD::SETLT, VT, Expand);
+    setCondCodeAction(ISD::SETULE, VT, Expand);
+    setCondCodeAction(ISD::SETULT, VT, Expand);
+  }
 }
 
 bool AArch64TargetLowering::shouldExpandGetActiveLaneMask(EVT ResVT,
@@ -2581,31 +2590,21 @@ unsigned AArch64TargetLowering::ComputeNumSignBitsForTargetNode(
   unsigned VTBits = VT.getScalarSizeInBits();
   unsigned Opcode = Op.getOpcode();
   switch (Opcode) {
-    case AArch64ISD::CMEQ:
-    case AArch64ISD::CMGE:
-    case AArch64ISD::CMGT:
-    case AArch64ISD::CMHI:
-    case AArch64ISD::CMHS:
-    case AArch64ISD::FCMEQ:
-    case AArch64ISD::FCMGE:
-    case AArch64ISD::FCMGT:
-    case AArch64ISD::CMEQz:
-    case AArch64ISD::CMGEz:
-    case AArch64ISD::CMGTz:
-    case AArch64ISD::CMLEz:
-    case AArch64ISD::CMLTz:
-    case AArch64ISD::FCMEQz:
-    case AArch64ISD::FCMGEz:
-    case AArch64ISD::FCMGTz:
-    case AArch64ISD::FCMLEz:
-    case AArch64ISD::FCMLTz:
-      // Compares return either 0 or all-ones
-      return VTBits;
-    case AArch64ISD::VASHR: {
-      unsigned Tmp =
-          DAG.ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
-      return std::min<uint64_t>(Tmp + Op.getConstantOperandVal(1), VTBits);
-    }
+  case AArch64ISD::FCMEQ:
+  case AArch64ISD::FCMGE:
+  case AArch64ISD::FCMGT:
+  case AArch64ISD::FCMEQz:
+  case AArch64ISD::FCMGEz:
+  case AArch64ISD::FCMGTz:
+  case AArch64ISD::FCMLEz:
+  case AArch64ISD::FCMLTz:
+    // Compares return either 0 or all-ones
+    return VTBits;
+  case AArch64ISD::VASHR: {
+    unsigned Tmp =
+        DAG.ComputeNumSignBits(Op.getOperand(0), DemandedElts, Depth + 1);
+    return std::min<uint64_t>(Tmp + Op.getConstantOperandVal(1), VTBits);
+  }
   }
 
   return 1;
@@ -2812,19 +2811,9 @@ const char *AArch64TargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(AArch64ISD::VASHR)
     MAKE_CASE(AArch64ISD::VSLI)
     MAKE_CASE(AArch64ISD::VSRI)
-    MAKE_CASE(AArch64ISD::CMEQ)
-    MAKE_CASE(AArch64ISD::CMGE)
-    MAKE_CASE(AArch64ISD::CMGT)
-    MAKE_CASE(AArch64ISD::CMHI)
-    MAKE_CASE(AArch64ISD::CMHS)
     MAKE_CASE(AArch64ISD::FCMEQ)
     MAKE_CASE(AArch64ISD::FCMGE)
     MAKE_CASE(AArch64ISD::FCMGT)
-    MAKE_CASE(AArch64ISD::CMEQz)
-    MAKE_CASE(AArch64ISD::CMGEz)
-    MAKE_CASE(AArch64ISD::CMGTz)
-    MAKE_CASE(AArch64ISD::CMLEz)
-    MAKE_CASE(AArch64ISD::CMLTz)
     MAKE_CASE(AArch64ISD::FCMEQz)
     MAKE_CASE(AArch64ISD::FCMGEz)
     MAKE_CASE(AArch64ISD::FCMGTz)
@@ -15814,9 +15803,6 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
                                             SplatBitSize, HasAnyUndefs);
 
   bool IsZero = IsCnst && SplatValue == 0;
-  bool IsOne =
-      IsCnst && SrcVT.getScalarSizeInBits() == SplatBitSize && SplatValue == 1;
-  bool IsMinusOne = IsCnst && SplatValue.isAllOnes();
 
   if (SrcVT.getVectorElementType().isFloatingPoint()) {
     switch (CC) {
@@ -15863,50 +15849,7 @@ static SDValue EmitVectorComparison(SDValue LHS, SDValue RHS,
     }
   }
 
-  switch (CC) {
-  default:
-    return SDValue();
-  case AArch64CC::NE: {
-    SDValue Cmeq;
-    if (IsZero)
-      Cmeq = DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
-    else
-      Cmeq = DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
-    return DAG.getNOT(dl, Cmeq, VT);
-  }
-  case AArch64CC::EQ:
-    if (IsZero)
-      return DAG.getNode(AArch64ISD::CMEQz, dl, VT, LHS);
-    return DAG.getNode(AArch64ISD::CMEQ, dl, VT, LHS, RHS);
-  case AArch64CC::GE:
-    if (IsZero)
-      return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS);
-    return DAG.getNode(AArch64ISD::CMGE, dl, VT, LHS, RHS);
-  case AArch64CC::GT:
-    if (IsZero)
-      return DAG.getNode(AArch64ISD::CMGTz, dl, VT, LHS);
-    if (IsMinusOne)
-      return DAG.getNode(AArch64ISD::CMGEz, dl, VT, LHS);
-    return DAG.getNode(AArch64ISD::CMGT, dl, VT, LHS, RHS);
-  case AArch64CC::LE:
-    if (IsZero)
-      return DAG.getNode(AArch64ISD::CMLEz, dl, VT, LHS);
-    return DAG.getNode(AArch64ISD::CMGE, dl, VT, RHS, LHS);
-  case AArch64CC::LS:
-    return DAG.getNode(AArch64ISD::CMHS, dl, VT, RHS, LHS);
-  case AArch64CC::LO:
-    return DAG.getNode(AArch64ISD::CMHI, dl, VT, RHS, LHS);
-  case AArch64CC::LT:
-    if (IsZero)
-      return DAG.getNode(AArch64ISD::CMLTz, dl, VT, LHS);
-    if (IsOne)
-      return DAG.getNode(AArch64ISD::CMLEz, dl, VT, LHS);
-    return DAG.getNode(AArch64ISD::CMGT, dl, VT, RHS, LHS);
-  case AArch64CC::HI:
-    return DAG.getNode(AArch64ISD::CMHI, dl, VT, LHS, RHS);
-  case AArch64CC::HS:
-    return DAG.getNode(AArch64ISD::CMHS, dl, VT, LHS, RHS);
-  }
+  return SDValue();
 }
 
 SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
@@ -15927,9 +15870,11 @@ SDValue AArch64TargetLowering::LowerVSETCC(SDValue Op,
   if (LHS.getValueType().getVectorElementType().isInteger()) {
     assert(LHS.getValueType() == RHS.getValueType());
     AArch64CC::CondCode AArch64CC = changeIntCCToAArch64CC(CC);
-    SDValue Cmp =
-        EmitVectorComparison(LHS, RHS, AArch64CC, false, CmpVT, dl, DAG);
-    return DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());
+    if (SDValue Cmp =
+            EmitVectorComparison(LHS, RHS, AArch64CC, false, CmpVT, dl, DAG))
+      return DAG.getSExtOrTrunc(Cmp, dl, Op.getValueType());
+
+    return Op;
   }
 
   // Lower isnan(x) | isnan(never-nan) to x != x.
@@ -18128,7 +18073,9 @@ static SDValue foldVectorXorShiftIntoCmp(SDNode *N, SelectionDAG &DAG,
   if (!ShiftAmt || ShiftAmt->getZExtValue() != ShiftEltTy.getSizeInBits() - 1)
     return SDValue();
 
-  return DAG.getNode(AArch64ISD::CMGEz, SDLoc(N), VT, Shift.getOperand(0));
+  SDLoc DL(N);
+  SDValue Zero = DAG.getConstant(0, DL, Shift.getValueType());
+  return DAG.getSetCC(DL, VT, Shift.getOperand(0), Zero, ISD::SETGE);
 }
 
 // Given a vecreduce_add node, detect the below pattern and convert it to the
@@ -18739,7 +18686,8 @@ static SDValue performMulVectorCmpZeroCombine(SDNode *N, SelectionDAG &DAG) {
 
   SDLoc DL(N);
   SDValue In = DAG.getNode(AArch64ISD::NVCAST, DL, HalfVT, Srl.getOperand(0));
-  SDValue CM = DAG.getNode(AArch64ISD::CMLTz, DL, HalfVT, In);
+  SDValue Zero = DAG.getConstant(0, DL, In.getValueType());
+  SDValue CM = DAG.getSetCC(DL, HalfVT, Zero, In, ISD::SETGT);
   return DAG.getNode(AArch64ISD::NVCAST, DL, VT, CM);
 }
 
@@ -25268,6 +25216,14 @@ static SDValue performSETCCCombine(SDNode *N,
   if (SDValue V = performOrXorChainCombine(N, DAG))
     return V;
 
+  EVT CmpVT = LHS.getValueType();
+
+  APInt SplatLHSVal;
+  if (CmpVT.isInteger() && Cond == ISD::SETGT &&
+      ISD::isConstantSplatVector(LHS.getNode(), SplatLHSVal) &&
+      SplatLHSVal.isOne())
+    return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, CmpVT), RHS, ISD::SETGE);
+
   return SDValue();
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index bc0c3a832bb28..ba275e18fa126 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -241,21 +241,11 @@ enum NodeType : unsigned {
   VSRI,
 
   // Vector comparisons
-  CMEQ,
-  CMGE,
-  CMGT,
-  CMHI,
-  CMHS,
   FCMEQ,
   FCMGE,
   FCMGT,
 
   // Vector zero comparisons
-  CMEQz,
-  CMGEz,
-  CMGTz,
-  CMLEz,
-  CMLTz,
   FCMEQz,
   FCMGEz,
   FCMGTz,
diff --git a/llvm/lib/Target/AArch64/AArch64InstrFormats.td b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
index 255cd0ec5840c..6d8b84ea4239c 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrFormats.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrFormats.td
@@ -7086,7 +7086,7 @@ multiclass SIMD_FP8_CVTL<bits<2>sz, string asm, ValueType dty, SDPatternOperator
 class BaseSIMDCmpTwoVector<bit Q, bit U, bits<2> size, bits<2> size2,
                            bits<5> opcode, RegisterOperand regtype, string asm,
                            string kind, string zero, ValueType dty,
-                           ValueType sty, SDNode OpNode>
+                           ValueType sty, SDPatternOperator OpNode>
   : I<(outs regtype:$Rd), (ins regtype:$Rn), asm,
       "{\t$Rd" # kind # ", $Rn" # kind # ", #" # zero #
       "|" # kind # "\t$Rd, $Rn, #" # zero # "}", "",
@@ -7110,7 +7110,7 @@ class BaseSIMDCmpTwoVector<bit Q, bit U, bits<2> size, bits<2> size2,
 
 // Comparisons support all element sizes, except 1xD.
 multiclass SIMDCmpTwoVector<bit U, bits<5> opc, string asm,
-                            SDNode OpNode> {
+                            SDPatternOperator OpNode> {
   def v8i8rz  : BaseSIMDCmpTwoVector<0, U, 0b00, 0b00, opc, V64,
                                      asm, ".8b", "0",
                                      v8i8, v8i8, OpNode>;
@@ -7981,7 +7981,7 @@ multiclass SIMDCmpTwoScalarD<bit U, bits<5> opc, string asm,
                              SDPatternOperator OpNode> {
   def v1i64rz  : BaseSIMDCmpTwoScalar<U, 0b11, 0b00, opc, FPR64, asm, "0">;
 
-  def : Pat<(v1i64 (OpNode FPR64:$Rn)),
+  def : Pat<(v1i64 (OpNode v1i64:$Rn)),
             (!cast<Instruction>(NAME # v1i64rz) FPR64:$Rn)>;
 }
 
diff --git a/llvm/lib/Target/AArch64/AArch64InstrInfo.td b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
index 6c61e3a613f6f..a676e07f23dc3 100644
--- a/llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ b/llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -846,23 +846,35 @@ def AArch64vsri : SDNode<"AArch64ISD::VSRI", SDT_AArch64vshiftinsert>;
 
 def AArch64bsp: SDNode<"AArch64ISD::BSP", SDT_AArch64trivec>;
 
-def AArch64cmeq: SDNode<"AArch64ISD::CMEQ", SDT_AArch64binvec>;
-def AArch64cmge: SDNode<"AArch64ISD::CMGE", SDT_AArch64binvec>;
-def AArch64cmgt: SDNode<"AArch64ISD::CMGT", SDT_AArch64binvec>;
-def AArch64cmhi: SDNode<"AArch64ISD::CMHI", SDT_AArch64binvec>;
-def AArch64cmhs: SDNode<"AArch64ISD::CMHS", SDT_AArch64binvec>;
+def AArch64cmeq : PatFrag<(ops node:$lhs, node:$rhs),
+                          (setcc node:$lhs, node:$rhs, SETEQ)>;
+def AArch64cmge : PatFrag<(ops node:$lhs, node:$rhs),
+                          (setcc node:$lhs, node:$rhs, SETGE)>;
+def AArch64cmgt : PatFrag<(ops node:$lhs, node:$rhs),
+                          (setcc node:$lhs, node:$rhs, SETGT)>;
+def AArch64cmhi : PatFrag<(ops node:$lhs, node:$rhs),
+                          (setcc node:$lhs, node:$rhs, SETUGT)>;
+def AArch64cmhs : PatFrag<(ops node:$lhs, node:$rhs),
+                          (setcc node:$lhs, node:$rhs, SETUGE)>;
 
 def AArch64fcmeq: SDNode<"AArch64ISD::FCMEQ", SDT_AArch64fcmp>;
 def AArch64fcmge: SDNode<"AArch64ISD::FCMGE", SDT_AArch64fcmp>;
 def AArch64fcmgt: SDNode<"AArch64ISD::FCMGT", SDT_AArch64fcmp>;
 
-def AArch64cmeqz: SDNode<"AArch64ISD::CMEQz", SDT_AArch64unvec>;
-def AArch64cmgez: SDNode<"AArch64ISD::CMGEz", SDT_AArch64unvec>;
-def AArch64cmgtz: SDNode<"AArch64ISD::CMGTz", SDT_AArch64unvec>;
-def AArch64cmlez: SDNode<"AArch64ISD::CMLEz", SDT_AArch64unvec>;
-def AArch64cmltz: SDNode<"AArch64ISD::CMLTz", SDT_AArch64unvec>;
+def AArch64cmeqz : PatFrag<(ops node:$lhs),
+                           (setcc node:$lhs, immAllZerosV, SETEQ)>;
+def AArch64cmgez : PatFrags<(ops node:$lhs),
+                            [(setcc node:$lhs, immAllZerosV, SETGE),
+                             (setcc node:$lhs, immAllOnesV, SETGT)]>;
+def AArch64cmgtz : PatFrag<(ops node:$lhs),
+                           (setcc node:$lhs, immAllZerosV, SETGT)>;
+def AArch64cmlez : PatFrag<(ops node:$lhs),
+                           (setcc immAllZerosV, node:$lhs, SETGE)>;
+def AArch64cmltz : PatFrag<(ops node:$lhs),
+                           (setcc immAllZerosV, node:$lhs, SETGT)>;
+
 def AArch64cmtst : PatFrag<(ops node:$LHS, node:$RHS),
-                        (vnot (AArch64cmeqz (and node:$LHS, node:$RHS)))>;
+                           (vnot (AArch64cmeqz (and node:$LHS, node:$RHS)))>;
 
 def AArch64fcmeqz: SDNode<"AArch64ISD::FCMEQz", SDT_AArch64fcmpz>;
 def AArch64fcmgez: SDNode<"AArch64ISD::FCMGEz", SDT_AArch64fcmpz>;
@@ -5671,7 +5683,7 @@ defm CMHI    : SIMDThreeSameVector<1, 0b00110, "cmhi", AArch64cmhi>;
 defm CMHS    : SIMDThreeSameVector<1, 0b00111, "cmhs", AArch64cmhs>;
 defm CMTST   : SIMDThreeSameVector<0, 0b10001, "cmtst", AArch64cmtst>;
 foreach VT = [ v8i8, v16i8, v4i16, v8i16, v2i32, v4i32, v2i64 ] in {
-def : Pat<(vnot (AArch64cmeqz VT:$Rn)), (!cast<Instruction>("CMTST"#VT) VT:$Rn, VT:$Rn)>;
+def : Pat<(VT (vnot (AArch64cmeqz VT:$Rn))), (!cast<Instruction>("CMTST"#VT) VT:$Rn, VT:$Rn)>;
 }
 defm FABD    : SIMDThreeSameVectorFP<1,1,0b010,"fabd", int_aarch64_neon_fabd>;
 let Predicates = [HasNEON] in {
diff --git a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
index f0c9dccb21d84..c7a423f2e4f8d 100644
--- a/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
+++ b/llvm/test/CodeGen/AArch64/aarch64-dup-ext.ll
@@ -352,17 +352,16 @@ define void @typei1_orig(i64 %a, ptr %p, ptr %q) {
 ;
 ; CHECK-GI-LABEL: typei1_orig:
 ; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    ldr q1, [x2]
+; CHECK-GI-NEXT:    ldr q0, [x2]
 ; CHECK-GI-NEXT:    cmp x0, #0
-; CHECK-GI-NEXT:    movi v0.2d, #0xffffffffffffffff
 ; CHECK-GI-NEXT:    cset w8, gt
-; CHECK-GI-NEXT:    neg v1.8h, v1.8h
-; CHECK-GI-NEXT:    dup v2.8h, w8
-; CHECK-GI-NEXT:    mvn v0.16b, v0.16b
-; CHECK-GI-NEXT:    mul v1.8h, v1.8h, v2.8h
-; CHECK-GI-NEXT:    cmeq v1.8h, v1.8h, #0
+; CHECK-GI-NEXT:    neg v0.8h, v0.8h
+; CHECK-GI-NEXT:    dup v1.8h, w8
+; CHECK-GI-NEXT:    mul v0.8h, v0.8h, v1.8h
+; CHECK-GI-NEXT:    movi v1.2d, #0xffffffffffffffff
+; CHECK-GI-NEXT:    cmtst v0.8h, v0.8h, v0.8h
 ; CHECK-GI-NEXT:    mvn v1.16b, v1.16b
-; CHECK-GI-NEXT:    uzp1 v0.16b, v1.16b, v0.16b
+; CHECK-GI-NEXT:    uzp1 v0.16b, v0.16b, v1.16b
 ; CHECK-GI-NEXT:    shl v0.16b, v0.16b, #7
 ; CHECK-GI-NEXT:    sshr v0.16b, v0.16b, #7
 ; CHECK-GI-NEXT:    str q0, [x1]
diff --git a/llvm/test/CodeGen/AArch64/fptosi-sat-vector.ll b/llvm/test/CodeGen/AArch64/fptosi-sat-vector.ll
index 81770a4ebdd4d..c834ca772b6ac 100644
--- a/llvm/test/CodeGen/AArch64/fptosi-sat-vector.ll
+++ b/llvm/test/CodeGen/AArch64/fptosi-sat-vector.ll
@@ -2382,11 +2382,11 @@ define <2 x i1> @test_signed_v2f64_v2i1(<2 x double> %f) {
 ; CHECK-GI-LABEL: test_signed_v2f64_v2i1:
 ; CHECK-GI:       // %bb.0:
 ; CHECK-GI-NEXT:    fcvtzs v0.2d, v0.2d
-; CHECK-GI-NEXT:    movi v2.2d, #0xffffffffffffffff
 ; CHECK-GI-NEXT:    cmlt v1.2d, v0.2d, #0
 ; CHECK-GI-NEXT:    and v0.16b, v0.16b, v1.16b
-; CHECK-GI-NEXT:    cmgt v1.2d, v0.2d, v2.2d
-; CHECK-GI-NEXT:    bif v0.16b, v2.16b, v1.16b
+; CHECK-GI-NEXT:    movi v1.2d, #0xffffffffffffffff
+; CHECK-GI-NEXT:    cmge v2.2d, v0.2d, #0
+; CHECK-GI-NEXT:    bif v0.16b, v1.16b, v2.16b
 ; CHECK-GI-NEXT:    xtn v0.2s, v0.2d
 ; CHECK-GI-NEXT:    ret
     %x = call <2 x i1> @llvm.fptosi.sat.v2f64.v2i1(<2 x double> %f)
diff --git a/llvm/test/CodeGen/AArch64/neon-bitwise-instructions.ll b/llvm/test/CodeGen/AArch64/neon-bitwise-instructions.ll
index f6dbf5251fc27..fb65a748c865f 100644
--- a/llvm/test/CodeGen/AArch64/neon-bitwise-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/neon-bitwise-instructions.ll
@@ -1499,8 +1499,7 @@ define <8 x i8> @vselect_cmpz_ne(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c) {
 ;
 ; CHECK-GI-LABEL: vselect_cmpz_ne:
 ; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    cmeq v0.8b, v0.8b, #0
-; CHECK-GI-NEXT:    mvn v0.8b, v0.8b
+; CHECK-GI-NEXT:    cmtst v0.8b, v0.8b, v0.8b
 ; CHECK-GI-NEXT:    bsl v0.8b, v1.8b, v2.8b
 ; CHECK-GI-NEXT:    ret
   %cmp = icmp ne <8 x i8> %a, zeroinitializer
@@ -1533,17 +1532,10 @@ define <8 x i8> @vselect_tst(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c) {
 }
 
 define <8 x i8> @sext_tst(<8 x i8> %a, <8 x i8> %b, <8 x i8> %c) {
-; CHECK-SD-LABEL: sext_tst:
-; CHECK-SD:       // %bb.0:
-; CHECK-SD-NEXT:    cmtst v0.8b, v0.8b, v1.8b
-; CHECK-SD-NEXT:    ret
-;
-; CHECK-GI-LABEL: sext_tst:
-; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    and v0.8b, v0.8b, v1.8b
-; CHECK-GI-NEXT:    cmeq v0.8b, v0.8b, #0
-; CHECK-GI-NEXT:    mvn v0.8b, v0.8b
-; CHECK-GI-NEXT:    ret
+; CHECK-LABEL: sext_tst:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmtst v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    ret
   %tmp3 = and <8 x i8> %a, %b
   %tmp4 = icmp ne <8 x i8> %tmp3, zeroinitializer
   %d = sext <8 x i1> %tmp4 to <8 x i8>
diff --git a/llvm/test/CodeGen/AArch64/neon-compare-instructions.ll b/llvm/test/CodeGen/AArch64/neon-compare-instructions.ll
index 8f7d5dd5588b9..2c2cb72112879 100644
--- a/llvm/test/CodeGen/AArch64/neon-compare-instructions.ll
+++ b/llvm/test/CodeGen/AArch64/neon-compare-instructions.ll
@@ -738,17 +738,10 @@ define <2 x i64> @cmls2xi64(<2 x i64> %A, <2 x i64> %B) {
 }
 
 define <8 x i8> @cmtst8xi8(<8 x i8> %A, <8 x i8> %B) {
-; CHECK-SD-LABEL: cmtst8xi8:
-; CHECK-SD:       // %bb.0:
-; CHECK-SD-NEXT:    cmtst v0.8b, v0.8b, v1.8b
-; CHECK-SD-NEXT:    ret
-;
-; CHECK-GI-LABEL: cmtst8xi8:
-; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    and v0.8b, v0.8b, v1.8b
-; CHECK-GI-NEXT:    cmeq v0.8b, v0.8b, #0
-; CHECK-GI-NEXT:    mvn v0.8b, v0.8b
-; CHECK-GI-NEXT:    ret
+; CHECK-LABEL: cmtst8xi8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmtst v0.8b, v0.8b, v1.8b
+; CHECK-NEXT:    ret
   %tmp3 = and <8 x i8> %A, %B
   %tmp4 = icmp ne <8 x i8> %tmp3, zeroinitializer
   %tmp5 = sext <8 x i1> %tmp4 to <8 x i8>
@@ -756,17 +749,10 @@ define <8 x i8> @cmtst8xi8(<8 x i8> %A, <8 x i8> %B) {
 }
 
 define <16 x i8> @cmtst16xi8(<16 x i8> %A, <16 x i8> %B) {
-; CHECK-SD-LABEL: cmtst16xi8:
-; CHECK-SD:       // %bb.0:
-; CHECK-SD-NEXT:    cmtst v0.16b, v0.16b, v1.16b
-; CHECK-SD-NEXT:    ret
-;
-; CHECK-GI-LABEL: cmtst16xi8:
-; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    and v0.16b, v0.16b, v1.16b
-; CHECK-GI-NEXT:    cmeq v0.16b, v0.16b, #0
-; CHECK-GI-NEXT:    mvn v0.16b, v0.16b
-; CHECK-GI-NEXT:    ret
+; CHECK-LABEL: cmtst16xi8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmtst v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    ret
   %tmp3 = and <16 x i8> %A, %B
   %tmp4 = icmp ne <16 x i8> %tmp3, zeroinitializer
   %tmp5 = sext <16 x i1> %tmp4 to <16 x i8>
@@ -774,17 +760,10 @@ define <16 x i8> @cmtst16xi8(<16 x i8> %A, <16 x i8> %B) {
 }
 
 define <4 x i16> @cmtst4xi16(<4 x i16> %A, <4 x i16> %B) {
-; CHECK-SD-LABEL: cmtst4xi16:
-; CHECK-SD:       // %bb.0:
-; CHECK-SD-NEXT:    cmtst v0.4h, v0.4h, v1.4h
-; CHECK-SD-NEXT:    ret
-;
-; CHECK-GI-LABEL: cmtst4xi16:
-; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    and v0.8b, v0.8b, v1.8b
-; CHECK-GI-NEXT:    cmeq v0.4h, v0.4h, #0
-; CHECK-GI-NEXT:    mvn v0.8b, v0.8b
-; CHECK-GI-NEXT:    ret
+; CHECK-LABEL: cmtst4xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmtst v0.4h, v0.4h, v1.4h
+; CHECK-NEXT:    ret
   %tmp3 = and <4 x i16> %A, %B
   %tmp4 = icmp ne <4 x i16> %tmp3, zeroinitializer
   %tmp5 = sext <4 x i1> %tmp4 to <4 x i16>
@@ -792,17 +771,10 @@ define <4 x i16> @cmtst4xi16(<4 x i16> %A, <4 x i16> %B) {
 }
 
 define <8 x i16> @cmtst8xi16(<8 x i16> %A, <8 x i16> %B) {
-; CHECK-SD-LABEL: cmtst8xi16:
-; CHECK-SD:       // %bb.0:
-; CHECK-SD-NEXT:    cmtst v0.8h, v0.8h, v1.8h
-; CHECK-SD-NEXT:    ret
-;
-; CHECK-GI-LABEL: cmtst8xi16:
-; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    and v0.16b, v0.16b, v1.16b
-; CHECK-GI-NEXT:    cmeq v0.8h, v0.8h, #0
-; CHECK-GI-NEXT:    mvn v0.16b, v0.16b
-; CHECK-GI-NEXT:    ret
+; CHECK-LABEL: cmtst8xi16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmtst v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    ret
   %tmp3 = and <8 x i16> %A, %B
   %tmp4 = icmp ne <8 x i16> %tmp3, zeroinitializer
   %tmp5 = sext <8 x i1> %tmp4 to <8 x i16>
@@ -810,17 +782,10 @@ define <8 x i16> @cmtst8xi16(<8 x i16> %A, <8 x i16> %B) {
 }
 
 define <2 x i32> @cmtst2xi32(<2 x i32> %A, <2 x i32> %B) {
-; CHECK-SD-LABEL: cmtst2xi32:
-; CHECK-SD:       // %bb.0:
-; CHECK-SD-NEXT:    cmtst v0.2s, v0.2s, v1.2s
-; CHECK-SD-NEXT:    ret
-;
-; CHECK-GI-LABEL: cmtst2xi32:
-; CHECK-GI:       // %bb.0:
-; CHECK-GI-NEXT:    and v0.8b, v0.8b, v1.8b
-; CHECK-GI-NEXT:    cmeq v0.2s...
[truncated]

Comment on lines 25221 to 25225
APInt SplatLHSVal;
if (CmpVT.isInteger() && Cond == ISD::SETGT &&
ISD::isConstantSplatVector(LHS.getNode(), SplatLHSVal) &&
SplatLHSVal.isOne())
return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, CmpVT), RHS, ISD::SETGE);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added this combine because it proved awkward to detect splat(1) across all the NEON types within a PatFrag. I think this is again because we're performing isel of constants during legalisation when it would be better to use splat_vector universally. Please let me know if you've a better solution, otherwise I'm happy to circle back after making the DAG more amenable.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably worth adding a comment explaining why you can't currently use a TableGen pattern for this. But seems fine.

@paulwalker-arm
Copy link
Collaborator Author

ping

Comment on lines 25221 to 25225
APInt SplatLHSVal;
if (CmpVT.isInteger() && Cond == ISD::SETGT &&
ISD::isConstantSplatVector(LHS.getNode(), SplatLHSVal) &&
SplatLHSVal.isOne())
return DAG.getSetCC(DL, VT, DAG.getConstant(0, DL, CmpVT), RHS, ISD::SETGE);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably worth adding a comment explaining why you can't currently use a TableGen pattern for this. But seems fine.

Copy link
Collaborator

@efriedma-quic efriedma-quic left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@paulwalker-arm paulwalker-arm merged commit b0b97e3 into llvm:main Apr 4, 2025
11 checks passed
@paulwalker-arm paulwalker-arm deleted the sve-vls-masked-loads branch April 4, 2025 11:13
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants