Skip to content

Reapply "[DAGCombiner] Add support for scalarising extracts of a vector setcc (#117566)" #118823

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 9, 2024

Conversation

david-arm
Copy link
Contributor

@david-arm david-arm commented Dec 5, 2024

[Reverts d57892a]

For IR like this:

%icmp = icmp ult <4 x i32> %a, splat (i32 5)
%res = extractelement <4 x i1> %icmp, i32 1

where there is only one use of %icmp we can take a similar approach
to what we already do for binary ops such add, sub, etc. and convert
this into

%ext = extractelement <4 x i32> %a, i32 1
%res = icmp ult i32 %ext, 5

For AArch64 targets at least the scalar boolean result will almost
certainly need to be in a GPR anyway, since it will probably be
used by branches for control flow. I've tried to reuse existing code
in scalarizeExtractedBinop to also work for setcc.

NOTE: The optimisations don't apply for tests such as
extract_icmp_v4i32_splat_rhs in the file

CodeGen/AArch64/extract-vector-cmp.ll

because scalarizeExtractedBinOp only works if one of the input
operands is a constant.

For IR like this:

  %icmp = icmp ult <4 x i32> %a, splat (i32 5)
  %res = extractelement <4 x i1> %icmp, i32 1

where there is only one use of %icmp we can take a similar approach
to what we already do for binary ops such add, sub, etc. and convert
this into

  %ext = extractelement <4 x i32> %a, i32 1
  %res = icmp ult i32 %ext, 5

For AArch64 targets at least the scalar boolean result will almost
certainly need to be in a GPR anyway, since it will probably be
used by branches for control flow. I've tried to reuse existing code
in scalarizeExtractedBinop to also work for setcc.

NOTE: The optimisations don't apply for tests such as
extract_icmp_v4i32_splat_rhs in the file

CodeGen/AArch64/extract-vector-cmp.ll

because scalarizeExtractedBinOp only works if one of the input
operands is a constant.
@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2024

@llvm/pr-subscribers-backend-risc-v

@llvm/pr-subscribers-backend-x86

Author: David Sherwood (david-arm)

Changes

For IR like this:

%icmp = icmp ult <4 x i32> %a, splat (i32 5)
%res = extractelement <4 x i1> %icmp, i32 1

where there is only one use of %icmp we can take a similar approach
to what we already do for binary ops such add, sub, etc. and convert
this into

%ext = extractelement <4 x i32> %a, i32 1
%res = icmp ult i32 %ext, 5

For AArch64 targets at least the scalar boolean result will almost
certainly need to be in a GPR anyway, since it will probably be
used by branches for control flow. I've tried to reuse existing code
in scalarizeExtractedBinop to also work for setcc.

NOTE: The optimisations don't apply for tests such as
extract_icmp_v4i32_splat_rhs in the file

CodeGen/AArch64/extract-vector-cmp.ll

because scalarizeExtractedBinOp only works if one of the input
operands is a constant.


Patch is 20.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118823.diff

10 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+27-16)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+17)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+4)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+1-1)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+1-1)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+1-1)
  • (modified) llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll (+24-22)
  • (added) llvm/test/CodeGen/AArch64/extract-vector-cmp.ll (+204)
  • (modified) llvm/test/CodeGen/X86/vselect.ll (+26)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 48018ac29bd089..1aab0fce4e1e53 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -22751,16 +22751,22 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
 
 /// Transform a vector binary operation into a scalar binary operation by moving
 /// the math/logic after an extract element of a vector.
-static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
-                                       const SDLoc &DL, bool LegalOperations) {
+static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
+                                       const SDLoc &DL, bool LegalTypes) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   SDValue Vec = ExtElt->getOperand(0);
   SDValue Index = ExtElt->getOperand(1);
   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
-  if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
+  unsigned Opc = Vec.getOpcode();
+  if (!IndexC || !Vec.hasOneUse() || (!TLI.isBinOp(Opc) && Opc != ISD::SETCC) ||
       Vec->getNumValues() != 1)
     return SDValue();
 
+  EVT ResVT = ExtElt->getValueType(0);
+  if (Opc == ISD::SETCC &&
+      (ResVT != Vec.getValueType().getVectorElementType() || LegalTypes))
+    return SDValue();
+
   // Targets may want to avoid this to prevent an expensive register transfer.
   if (!TLI.shouldScalarizeBinop(Vec))
     return SDValue();
@@ -22771,19 +22777,24 @@ static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
   SDValue Op0 = Vec.getOperand(0);
   SDValue Op1 = Vec.getOperand(1);
   APInt SplatVal;
-  if (isAnyConstantBuildVector(Op0, true) ||
-      ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
-      isAnyConstantBuildVector(Op1, true) ||
-      ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
-    // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
-    // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
-    EVT VT = ExtElt->getValueType(0);
-    SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
-    SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
-    return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
-  }
+  if (!isAnyConstantBuildVector(Op0, true) &&
+      !ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
+      !isAnyConstantBuildVector(Op1, true) &&
+      !ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
+    return SDValue();
 
-  return SDValue();
+  // extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
+  // extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
+  if (Opc == ISD::SETCC) {
+    EVT OpVT = Op0.getValueType().getVectorElementType();
+    Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
+    Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
+    return DAG.getSetCC(DL, ResVT, Op0, Op1,
+                        cast<CondCodeSDNode>(Vec->getOperand(2))->get());
+  }
+  Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op0, Index);
+  Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op1, Index);
+  return DAG.getNode(Opc, DL, ResVT, Op0, Op1);
 }
 
 // Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
@@ -23016,7 +23027,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
     }
   }
 
-  if (SDValue BO = scalarizeExtractedBinop(N, DAG, DL, LegalOperations))
+  if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL, LegalTypes))
     return BO;
 
   if (VecVT.isScalableVector())
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 986d69e6c7a9e0..f8021c49a138ac 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2835,6 +2835,7 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::SELECT_CC:    SplitRes_SELECT_CC(N, Lo, Hi); break;
   case ISD::UNDEF:        SplitRes_UNDEF(N, Lo, Hi); break;
   case ISD::FREEZE:       SplitRes_FREEZE(N, Lo, Hi); break;
+  case ISD::SETCC:        ExpandIntRes_SETCC(N, Lo, Hi); break;
 
   case ISD::BITCAST:            ExpandRes_BITCAST(N, Lo, Hi); break;
   case ISD::BUILD_PAIR:         ExpandRes_BUILD_PAIR(N, Lo, Hi); break;
@@ -3316,6 +3317,22 @@ static std::pair<ISD::CondCode, ISD::NodeType> getExpandedMinMaxOps(int Op) {
   }
 }
 
+void DAGTypeLegalizer::ExpandIntRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
+  SDLoc DL(N);
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  EVT NewVT = getSetCCResultType(LHS.getValueType());
+
+  // Taking the same approach as ScalarizeVecRes_SETCC
+  SDValue Res = DAG.getNode(ISD::SETCC, DL, NewVT, LHS, RHS, N->getOperand(2));
+
+  ISD::NodeType ExtendCode =
+      TargetLowering::getExtendForContent(TLI.getBooleanContents(NewVT));
+  Res = DAG.getExtOrTrunc(Res, DL, N->getValueType(0), ExtendCode);
+  SplitInteger(Res, Lo, Hi);
+}
+
 void DAGTypeLegalizer::ExpandIntRes_MINMAX(SDNode *N,
                                            SDValue &Lo, SDValue &Hi) {
   SDLoc DL(N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 1703149aca7463..571a710cc92a34 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -487,6 +487,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void ExpandIntRes_MINMAX            (SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandIntRes_CMP               (SDNode *N, SDValue &Lo, SDValue &Hi);
+  void ExpandIntRes_SETCC             (SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandIntRes_SADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_UADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index cb0b9e965277aa..d51b36f7e49946 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1348,6 +1348,10 @@ class AArch64TargetLowering : public TargetLowering {
   unsigned getMinimumJumpTableEntries() const override;
 
   bool softPromoteHalfType() const override { return true; }
+
+  bool shouldScalarizeBinop(SDValue VecOp) const override {
+    return VecOp.getOpcode() == ISD::SETCC;
+  }
 };
 
 namespace AArch64 {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c2b2daad1b8987..cfa9620c419de7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2093,7 +2093,7 @@ bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index c765d2b1ab95bc..7712570869ff6c 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -429,7 +429,7 @@ bool WebAssemblyTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c18a4ac9acb1e4..8d693ac64321ff 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -3306,7 +3306,7 @@ bool X86TargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll b/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
index 5a5dee0b53d439..4cb1d5b2fb345d 100644
--- a/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
+++ b/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
@@ -5,7 +5,7 @@
 
 declare void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8>, <vscale x 16 x ptr>, i32 immarg, <vscale x 16 x i1>)
 
-define fastcc i8 @allocno_reload_assign() {
+define fastcc i8 @allocno_reload_assign(ptr %p) {
 ; CHECK-LABEL: allocno_reload_assign:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fmov d0, xzr
@@ -14,8 +14,8 @@ define fastcc i8 @allocno_reload_assign() {
 ; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    uzp1 p0.s, p0.s, p0.s
 ; CHECK-NEXT:    uzp1 p0.h, p0.h, p0.h
-; CHECK-NEXT:    uzp1 p0.b, p0.b, p0.b
-; CHECK-NEXT:    mov z0.b, p0/z, #1 // =0x1
+; CHECK-NEXT:    uzp1 p8.b, p0.b, p0.b
+; CHECK-NEXT:    mov z0.b, p8/z, #1 // =0x1
 ; CHECK-NEXT:    fmov w8, s0
 ; CHECK-NEXT:    mov z0.b, #0 // =0x0
 ; CHECK-NEXT:    uunpklo z1.h, z0.b
@@ -30,34 +30,35 @@ define fastcc i8 @allocno_reload_assign() {
 ; CHECK-NEXT:    punpklo p1.h, p0.b
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    punpklo p2.h, p1.b
-; CHECK-NEXT:    punpkhi p3.h, p1.b
+; CHECK-NEXT:    punpkhi p4.h, p1.b
 ; CHECK-NEXT:    uunpklo z0.d, z2.s
 ; CHECK-NEXT:    uunpkhi z1.d, z2.s
-; CHECK-NEXT:    punpklo p5.h, p0.b
+; CHECK-NEXT:    punpklo p6.h, p0.b
 ; CHECK-NEXT:    uunpklo z2.d, z3.s
 ; CHECK-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEXT:    punpkhi p7.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    uunpklo z4.d, z5.s
 ; CHECK-NEXT:    uunpkhi z5.d, z5.s
 ; CHECK-NEXT:    uunpklo z6.d, z7.s
 ; CHECK-NEXT:    uunpkhi z7.d, z7.s
-; CHECK-NEXT:    punpklo p0.h, p2.b
-; CHECK-NEXT:    punpkhi p1.h, p2.b
-; CHECK-NEXT:    punpklo p2.h, p3.b
-; CHECK-NEXT:    punpkhi p3.h, p3.b
-; CHECK-NEXT:    punpklo p4.h, p5.b
-; CHECK-NEXT:    punpkhi p5.h, p5.b
-; CHECK-NEXT:    punpklo p6.h, p7.b
-; CHECK-NEXT:    punpkhi p7.h, p7.b
+; CHECK-NEXT:    punpklo p1.h, p2.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    punpklo p3.h, p4.b
+; CHECK-NEXT:    punpkhi p4.h, p4.b
+; CHECK-NEXT:    punpklo p5.h, p6.b
+; CHECK-NEXT:    punpkhi p6.h, p6.b
+; CHECK-NEXT:    punpklo p7.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:  .LBB0_1: // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    st1b { z0.d }, p0, [z16.d]
-; CHECK-NEXT:    st1b { z1.d }, p1, [z16.d]
-; CHECK-NEXT:    st1b { z2.d }, p2, [z16.d]
-; CHECK-NEXT:    st1b { z3.d }, p3, [z16.d]
-; CHECK-NEXT:    st1b { z4.d }, p4, [z16.d]
-; CHECK-NEXT:    st1b { z5.d }, p5, [z16.d]
-; CHECK-NEXT:    st1b { z6.d }, p6, [z16.d]
-; CHECK-NEXT:    st1b { z7.d }, p7, [z16.d]
+; CHECK-NEXT:    st1b { z0.d }, p1, [z16.d]
+; CHECK-NEXT:    st1b { z1.d }, p2, [z16.d]
+; CHECK-NEXT:    st1b { z2.d }, p3, [z16.d]
+; CHECK-NEXT:    st1b { z3.d }, p4, [z16.d]
+; CHECK-NEXT:    st1b { z4.d }, p5, [z16.d]
+; CHECK-NEXT:    st1b { z5.d }, p6, [z16.d]
+; CHECK-NEXT:    st1b { z6.d }, p7, [z16.d]
+; CHECK-NEXT:    st1b { z7.d }, p0, [z16.d]
+; CHECK-NEXT:    str p8, [x0]
 ; CHECK-NEXT:    b .LBB0_1
   br label %1
 
@@ -66,6 +67,7 @@ define fastcc i8 @allocno_reload_assign() {
   %constexpr1 = shufflevector <vscale x 16 x i1> %constexpr, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
   %constexpr2 = xor <vscale x 16 x i1> %constexpr1, shufflevector (<vscale x 16 x i1> insertelement (<vscale x 16 x i1> poison, i1 true, i64 0), <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer)
   call void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8> zeroinitializer, <vscale x 16 x ptr> zeroinitializer, i32 0, <vscale x 16 x i1> %constexpr2)
+  store <vscale x 16 x i1> %constexpr, ptr %p, align 16
   br label %1
 }
 
diff --git a/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll b/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
new file mode 100644
index 00000000000000..12bd2db2297d77
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
@@ -0,0 +1,204 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+
+define i1 @extract_icmp_v4i32_const_splat_rhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #5
+; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat (i32 5)
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_const_splat_lhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_splat_lhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #7
+; CHECK-NEXT:    cset w0, hi
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> splat(i32 7), %a
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_const_vec_rhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_vec_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #234
+; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, <i32 5, i32 234, i32 -1, i32 7>
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_fcmp_v4f32_const_splat_rhs(<4 x float> %a) {
+; CHECK-LABEL: extract_fcmp_v4f32_const_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov s0, v0.s[1]
+; CHECK-NEXT:    fmov s1, #4.00000000
+; CHECK-NEXT:    fcmp s0, s1
+; CHECK-NEXT:    cset w0, lt
+; CHECK-NEXT:    ret
+  %fcmp = fcmp ult <4 x float> %a, splat(float 4.0e+0)
+  %ext = extractelement <4 x i1> %fcmp, i32 1
+  ret i1 %ext
+}
+
+; Tests the code in ExpandIntRes_SETCC
+define i128 @extract_icmp_v1i128(ptr %p) {
+; CHECK-LABEL: extract_icmp_v1i128:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldp x9, x8, [x0]
+; CHECK-NEXT:    mov x1, xzr
+; CHECK-NEXT:    orr x8, x9, x8
+; CHECK-NEXT:    cmp x8, #0
+; CHECK-NEXT:    cset w0, eq
+; CHECK-NEXT:    ret
+  %load = load <1 x i128>, ptr %p, align 16
+  %cmp = icmp eq <1 x i128> %load, zeroinitializer
+  %sext = sext <1 x i1> %cmp to <1 x i128>
+  %res = extractelement <1 x i128> %sext, i32 0
+  ret i128 %res
+}
+
+define void @vector_loop_with_icmp(ptr nocapture noundef writeonly %dest) {
+; CHECK-LABEL: vector_loop_with_icmp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z0.d, #0, #1
+; CHECK-NEXT:    mov w8, #2 // =0x2
+; CHECK-NEXT:    mov w9, #16 // =0x10
+; CHECK-NEXT:    dup v1.2d, x8
+; CHECK-NEXT:    add x8, x0, #4
+; CHECK-NEXT:    mov w10, #1 // =0x1
+; CHECK-NEXT:    b .LBB5_2
+; CHECK-NEXT:  .LBB5_1: // %pred.store.continue6
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    add v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    subs x9, x9, #2
+; CHECK-NEXT:    add x8, x8, #8
+; CHECK-NEXT:    b.eq .LBB5_6
+; CHECK-NEXT:  .LBB5_2: // %vector.body
+; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    fmov x11, d0
+; CHECK-NEXT:    cmp x11, #14
+; CHECK-NEXT:    b.hi .LBB5_4
+; CHECK-NEXT:  // %bb.3: // %pred.store.if
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    stur w10, [x8, #-4]
+; CHECK-NEXT:  .LBB5_4: // %pred.store.continue
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    mov x11, v0.d[1]
+; CHECK-NEXT:    cmp x11, #14
+; CHECK-NEXT:    b.hi .LBB5_1
+; CHECK-NEXT:  // %bb.5: // %pred.store.if5
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    str w10, [x8]
+; CHECK-NEXT:    b .LBB5_1
+; CHECK-NEXT:  .LBB5_6: // %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %pred.store.continue6 ]
+  %vec.ind = phi <2 x i64> [ <i64 0, i64 1>, %entry ], [ %vec.ind.next, %pred.store.continue6 ]
+  %vec.cmp = icmp ult <2 x i64> %vec.ind, <i64 15, i64 15>
+  %c0 = extractelement <2 x i1> %vec.cmp, i64 0
+  br i1 %c0, label %pred.store.if, label %pred.store.continue
+
+pred.store.if:
+  %arrayidx = getelementptr inbounds i32, ptr %dest, i64 %index
+  store i32 1, ptr %arrayidx, align 4
+  br label %pred.store.continue
+
+pred.store.continue:
+  %c1 = extractelement <2 x i1> %vec.cmp, i64 1
+  br i1 %c1, label %pred.store.if5, label %pred.store.continue6
+
+pred.store.if5:
+  %indexp1 = or disjoint i64 %index, 1
+  %arrayidx2 = getelementptr inbounds i32, ptr %dest, i64 %indexp1
+  store i32 1, ptr %arrayidx2, align 4
+  br label %pred.store.continue6
+
+pred.store.continue6:
+  %index.next = add i64 %index, 2
+  %vec.ind.next = add <2 x i64> %vec.ind, <i64 2, i64 2>
+  %index.cmp = icmp eq i64 %index.next, 16
+  br i1 %index.cmp, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:
+  ret void
+}
+
+
+; Negative tests
+
+define i1 @extract_icmp_v4i32_splat_rhs(<4 x i32> %a, i32 %b) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    dup v1.4s, w0
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    umov w8, v0.h[1]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %ins = insertelement <4 x i32> poison, i32 %b, i32 0
+  %splat = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer
+  %icmp = icmp ult <4 x i32> %a, %splat
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_splat_rhs_mul_use(<4 x i32> %a, ptr %p) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs_mul_use:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v1.4s, #235
+; CHECK-NEXT:    adrp x9, .LCPI7_0
+; CHECK-NEXT:    mov x8, x0
+; CHECK-NEXT:    ldr q2, [x9, :lo12:.LCPI7_0]
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v1.4h, v0.4s
+; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT:    addv s0, v0.4s
+; CHECK-NEXT:    umov w9, v1.h[1]
+; CHECK-NEXT:    fmov w10, s0
+; CHECK-NEXT:    and w0, w9, #0x1
+; CHECK-NEXT:    strb w10, [x8]
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat(i32 235)
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  store <4 x i1> %icmp, ptr %p, align 4
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_splat_rhs_unknown_idx(<4 x i32> %a, i32 %c) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs_unknown_idx:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    movi v1.4s, #127
+; CHECK-NEXT:    add x8, sp, #8
+; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-NEXT:    bfi x8, x0, #1, #2
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    str d0, [sp, #8]
+; CHECK-NEXT:    ldrh w8, [x8]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat(i32 127)
+  %ext = extractelement <4 x i1> %icmp, i32 %c
+  ret i1 %ext
+}
+
diff --git a/llvm/test/CodeGen/X86/vselect.ll b/llvm/test/CodeGen/X86/vselect.ll
index 9acd995d612c31..be6ee8f6899584 100644
--- a/llvm/test/CodeGen/X86/vselect.ll
+++ b/llvm/test/CodeGen/X86/vselect.ll
@@ -796,3 +796,29 @@ define i64 @vselect_any_extend_vector_inreg_crash(ptr %x) {
   ret i64 %4
 }
 
+; Tests the scalarizeBinOp code in DAGCombiner
+define void @scalarize_binop(<1 x i1> %a) {
+; SSE-LABEL: scalarize_binop:
+; SSE:       # %bb.0: # %bb0
+; SSE-NEXT:    .p2align 4
+; SSE-NEXT:  .LBB35_1: # %bb1
+; SSE-NEXT:    # =>This Inner Loop Header: Depth=1
+; SSE-NEXT:    jmp .LBB35_1
+;
+; AVX-LABEL: scalarize_binop:
+; AVX:       # %bb.0: # %bb0
+; AVX-NEXT:    .p2align 4
+; AVX-NEXT:  .LBB35_1: # %bb1
+; AVX-NEXT:    # =>This Inner Loop Header: Depth=1
+; AVX-NEXT:    jmp .LBB35_1
+bb0:
+  br label %bb1
+
+bb1:
+  %b = select <1 x i1> %a, <1 x i1> zeroinitializer, <1 x i1> splat (i1 true)
+  br label %bb2
+
+bb2:
+  %c...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2024

@llvm/pr-subscribers-backend-webassembly

Author: David Sherwood (david-arm)

Changes

For IR like this:

%icmp = icmp ult <4 x i32> %a, splat (i32 5)
%res = extractelement <4 x i1> %icmp, i32 1

where there is only one use of %icmp we can take a similar approach
to what we already do for binary ops such add, sub, etc. and convert
this into

%ext = extractelement <4 x i32> %a, i32 1
%res = icmp ult i32 %ext, 5

For AArch64 targets at least the scalar boolean result will almost
certainly need to be in a GPR anyway, since it will probably be
used by branches for control flow. I've tried to reuse existing code
in scalarizeExtractedBinop to also work for setcc.

NOTE: The optimisations don't apply for tests such as
extract_icmp_v4i32_splat_rhs in the file

CodeGen/AArch64/extract-vector-cmp.ll

because scalarizeExtractedBinOp only works if one of the input
operands is a constant.


Patch is 20.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118823.diff

10 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+27-16)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+17)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+4)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+1-1)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+1-1)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+1-1)
  • (modified) llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll (+24-22)
  • (added) llvm/test/CodeGen/AArch64/extract-vector-cmp.ll (+204)
  • (modified) llvm/test/CodeGen/X86/vselect.ll (+26)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 48018ac29bd089..1aab0fce4e1e53 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -22751,16 +22751,22 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
 
 /// Transform a vector binary operation into a scalar binary operation by moving
 /// the math/logic after an extract element of a vector.
-static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
-                                       const SDLoc &DL, bool LegalOperations) {
+static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
+                                       const SDLoc &DL, bool LegalTypes) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   SDValue Vec = ExtElt->getOperand(0);
   SDValue Index = ExtElt->getOperand(1);
   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
-  if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
+  unsigned Opc = Vec.getOpcode();
+  if (!IndexC || !Vec.hasOneUse() || (!TLI.isBinOp(Opc) && Opc != ISD::SETCC) ||
       Vec->getNumValues() != 1)
     return SDValue();
 
+  EVT ResVT = ExtElt->getValueType(0);
+  if (Opc == ISD::SETCC &&
+      (ResVT != Vec.getValueType().getVectorElementType() || LegalTypes))
+    return SDValue();
+
   // Targets may want to avoid this to prevent an expensive register transfer.
   if (!TLI.shouldScalarizeBinop(Vec))
     return SDValue();
@@ -22771,19 +22777,24 @@ static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
   SDValue Op0 = Vec.getOperand(0);
   SDValue Op1 = Vec.getOperand(1);
   APInt SplatVal;
-  if (isAnyConstantBuildVector(Op0, true) ||
-      ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
-      isAnyConstantBuildVector(Op1, true) ||
-      ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
-    // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
-    // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
-    EVT VT = ExtElt->getValueType(0);
-    SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
-    SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
-    return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
-  }
+  if (!isAnyConstantBuildVector(Op0, true) &&
+      !ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
+      !isAnyConstantBuildVector(Op1, true) &&
+      !ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
+    return SDValue();
 
-  return SDValue();
+  // extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
+  // extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
+  if (Opc == ISD::SETCC) {
+    EVT OpVT = Op0.getValueType().getVectorElementType();
+    Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
+    Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
+    return DAG.getSetCC(DL, ResVT, Op0, Op1,
+                        cast<CondCodeSDNode>(Vec->getOperand(2))->get());
+  }
+  Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op0, Index);
+  Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op1, Index);
+  return DAG.getNode(Opc, DL, ResVT, Op0, Op1);
 }
 
 // Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
@@ -23016,7 +23027,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
     }
   }
 
-  if (SDValue BO = scalarizeExtractedBinop(N, DAG, DL, LegalOperations))
+  if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL, LegalTypes))
     return BO;
 
   if (VecVT.isScalableVector())
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 986d69e6c7a9e0..f8021c49a138ac 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2835,6 +2835,7 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::SELECT_CC:    SplitRes_SELECT_CC(N, Lo, Hi); break;
   case ISD::UNDEF:        SplitRes_UNDEF(N, Lo, Hi); break;
   case ISD::FREEZE:       SplitRes_FREEZE(N, Lo, Hi); break;
+  case ISD::SETCC:        ExpandIntRes_SETCC(N, Lo, Hi); break;
 
   case ISD::BITCAST:            ExpandRes_BITCAST(N, Lo, Hi); break;
   case ISD::BUILD_PAIR:         ExpandRes_BUILD_PAIR(N, Lo, Hi); break;
@@ -3316,6 +3317,22 @@ static std::pair<ISD::CondCode, ISD::NodeType> getExpandedMinMaxOps(int Op) {
   }
 }
 
+void DAGTypeLegalizer::ExpandIntRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
+  SDLoc DL(N);
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  EVT NewVT = getSetCCResultType(LHS.getValueType());
+
+  // Taking the same approach as ScalarizeVecRes_SETCC
+  SDValue Res = DAG.getNode(ISD::SETCC, DL, NewVT, LHS, RHS, N->getOperand(2));
+
+  ISD::NodeType ExtendCode =
+      TargetLowering::getExtendForContent(TLI.getBooleanContents(NewVT));
+  Res = DAG.getExtOrTrunc(Res, DL, N->getValueType(0), ExtendCode);
+  SplitInteger(Res, Lo, Hi);
+}
+
 void DAGTypeLegalizer::ExpandIntRes_MINMAX(SDNode *N,
                                            SDValue &Lo, SDValue &Hi) {
   SDLoc DL(N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 1703149aca7463..571a710cc92a34 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -487,6 +487,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void ExpandIntRes_MINMAX            (SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandIntRes_CMP               (SDNode *N, SDValue &Lo, SDValue &Hi);
+  void ExpandIntRes_SETCC             (SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandIntRes_SADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_UADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index cb0b9e965277aa..d51b36f7e49946 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1348,6 +1348,10 @@ class AArch64TargetLowering : public TargetLowering {
   unsigned getMinimumJumpTableEntries() const override;
 
   bool softPromoteHalfType() const override { return true; }
+
+  bool shouldScalarizeBinop(SDValue VecOp) const override {
+    return VecOp.getOpcode() == ISD::SETCC;
+  }
 };
 
 namespace AArch64 {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c2b2daad1b8987..cfa9620c419de7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2093,7 +2093,7 @@ bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index c765d2b1ab95bc..7712570869ff6c 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -429,7 +429,7 @@ bool WebAssemblyTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c18a4ac9acb1e4..8d693ac64321ff 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -3306,7 +3306,7 @@ bool X86TargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll b/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
index 5a5dee0b53d439..4cb1d5b2fb345d 100644
--- a/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
+++ b/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
@@ -5,7 +5,7 @@
 
 declare void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8>, <vscale x 16 x ptr>, i32 immarg, <vscale x 16 x i1>)
 
-define fastcc i8 @allocno_reload_assign() {
+define fastcc i8 @allocno_reload_assign(ptr %p) {
 ; CHECK-LABEL: allocno_reload_assign:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fmov d0, xzr
@@ -14,8 +14,8 @@ define fastcc i8 @allocno_reload_assign() {
 ; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    uzp1 p0.s, p0.s, p0.s
 ; CHECK-NEXT:    uzp1 p0.h, p0.h, p0.h
-; CHECK-NEXT:    uzp1 p0.b, p0.b, p0.b
-; CHECK-NEXT:    mov z0.b, p0/z, #1 // =0x1
+; CHECK-NEXT:    uzp1 p8.b, p0.b, p0.b
+; CHECK-NEXT:    mov z0.b, p8/z, #1 // =0x1
 ; CHECK-NEXT:    fmov w8, s0
 ; CHECK-NEXT:    mov z0.b, #0 // =0x0
 ; CHECK-NEXT:    uunpklo z1.h, z0.b
@@ -30,34 +30,35 @@ define fastcc i8 @allocno_reload_assign() {
 ; CHECK-NEXT:    punpklo p1.h, p0.b
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    punpklo p2.h, p1.b
-; CHECK-NEXT:    punpkhi p3.h, p1.b
+; CHECK-NEXT:    punpkhi p4.h, p1.b
 ; CHECK-NEXT:    uunpklo z0.d, z2.s
 ; CHECK-NEXT:    uunpkhi z1.d, z2.s
-; CHECK-NEXT:    punpklo p5.h, p0.b
+; CHECK-NEXT:    punpklo p6.h, p0.b
 ; CHECK-NEXT:    uunpklo z2.d, z3.s
 ; CHECK-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEXT:    punpkhi p7.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    uunpklo z4.d, z5.s
 ; CHECK-NEXT:    uunpkhi z5.d, z5.s
 ; CHECK-NEXT:    uunpklo z6.d, z7.s
 ; CHECK-NEXT:    uunpkhi z7.d, z7.s
-; CHECK-NEXT:    punpklo p0.h, p2.b
-; CHECK-NEXT:    punpkhi p1.h, p2.b
-; CHECK-NEXT:    punpklo p2.h, p3.b
-; CHECK-NEXT:    punpkhi p3.h, p3.b
-; CHECK-NEXT:    punpklo p4.h, p5.b
-; CHECK-NEXT:    punpkhi p5.h, p5.b
-; CHECK-NEXT:    punpklo p6.h, p7.b
-; CHECK-NEXT:    punpkhi p7.h, p7.b
+; CHECK-NEXT:    punpklo p1.h, p2.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    punpklo p3.h, p4.b
+; CHECK-NEXT:    punpkhi p4.h, p4.b
+; CHECK-NEXT:    punpklo p5.h, p6.b
+; CHECK-NEXT:    punpkhi p6.h, p6.b
+; CHECK-NEXT:    punpklo p7.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:  .LBB0_1: // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    st1b { z0.d }, p0, [z16.d]
-; CHECK-NEXT:    st1b { z1.d }, p1, [z16.d]
-; CHECK-NEXT:    st1b { z2.d }, p2, [z16.d]
-; CHECK-NEXT:    st1b { z3.d }, p3, [z16.d]
-; CHECK-NEXT:    st1b { z4.d }, p4, [z16.d]
-; CHECK-NEXT:    st1b { z5.d }, p5, [z16.d]
-; CHECK-NEXT:    st1b { z6.d }, p6, [z16.d]
-; CHECK-NEXT:    st1b { z7.d }, p7, [z16.d]
+; CHECK-NEXT:    st1b { z0.d }, p1, [z16.d]
+; CHECK-NEXT:    st1b { z1.d }, p2, [z16.d]
+; CHECK-NEXT:    st1b { z2.d }, p3, [z16.d]
+; CHECK-NEXT:    st1b { z3.d }, p4, [z16.d]
+; CHECK-NEXT:    st1b { z4.d }, p5, [z16.d]
+; CHECK-NEXT:    st1b { z5.d }, p6, [z16.d]
+; CHECK-NEXT:    st1b { z6.d }, p7, [z16.d]
+; CHECK-NEXT:    st1b { z7.d }, p0, [z16.d]
+; CHECK-NEXT:    str p8, [x0]
 ; CHECK-NEXT:    b .LBB0_1
   br label %1
 
@@ -66,6 +67,7 @@ define fastcc i8 @allocno_reload_assign() {
   %constexpr1 = shufflevector <vscale x 16 x i1> %constexpr, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
   %constexpr2 = xor <vscale x 16 x i1> %constexpr1, shufflevector (<vscale x 16 x i1> insertelement (<vscale x 16 x i1> poison, i1 true, i64 0), <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer)
   call void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8> zeroinitializer, <vscale x 16 x ptr> zeroinitializer, i32 0, <vscale x 16 x i1> %constexpr2)
+  store <vscale x 16 x i1> %constexpr, ptr %p, align 16
   br label %1
 }
 
diff --git a/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll b/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
new file mode 100644
index 00000000000000..12bd2db2297d77
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
@@ -0,0 +1,204 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+
+define i1 @extract_icmp_v4i32_const_splat_rhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #5
+; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat (i32 5)
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_const_splat_lhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_splat_lhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #7
+; CHECK-NEXT:    cset w0, hi
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> splat(i32 7), %a
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_const_vec_rhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_vec_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #234
+; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, <i32 5, i32 234, i32 -1, i32 7>
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_fcmp_v4f32_const_splat_rhs(<4 x float> %a) {
+; CHECK-LABEL: extract_fcmp_v4f32_const_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov s0, v0.s[1]
+; CHECK-NEXT:    fmov s1, #4.00000000
+; CHECK-NEXT:    fcmp s0, s1
+; CHECK-NEXT:    cset w0, lt
+; CHECK-NEXT:    ret
+  %fcmp = fcmp ult <4 x float> %a, splat(float 4.0e+0)
+  %ext = extractelement <4 x i1> %fcmp, i32 1
+  ret i1 %ext
+}
+
+; Tests the code in ExpandIntRes_SETCC
+define i128 @extract_icmp_v1i128(ptr %p) {
+; CHECK-LABEL: extract_icmp_v1i128:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldp x9, x8, [x0]
+; CHECK-NEXT:    mov x1, xzr
+; CHECK-NEXT:    orr x8, x9, x8
+; CHECK-NEXT:    cmp x8, #0
+; CHECK-NEXT:    cset w0, eq
+; CHECK-NEXT:    ret
+  %load = load <1 x i128>, ptr %p, align 16
+  %cmp = icmp eq <1 x i128> %load, zeroinitializer
+  %sext = sext <1 x i1> %cmp to <1 x i128>
+  %res = extractelement <1 x i128> %sext, i32 0
+  ret i128 %res
+}
+
+define void @vector_loop_with_icmp(ptr nocapture noundef writeonly %dest) {
+; CHECK-LABEL: vector_loop_with_icmp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z0.d, #0, #1
+; CHECK-NEXT:    mov w8, #2 // =0x2
+; CHECK-NEXT:    mov w9, #16 // =0x10
+; CHECK-NEXT:    dup v1.2d, x8
+; CHECK-NEXT:    add x8, x0, #4
+; CHECK-NEXT:    mov w10, #1 // =0x1
+; CHECK-NEXT:    b .LBB5_2
+; CHECK-NEXT:  .LBB5_1: // %pred.store.continue6
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    add v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    subs x9, x9, #2
+; CHECK-NEXT:    add x8, x8, #8
+; CHECK-NEXT:    b.eq .LBB5_6
+; CHECK-NEXT:  .LBB5_2: // %vector.body
+; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    fmov x11, d0
+; CHECK-NEXT:    cmp x11, #14
+; CHECK-NEXT:    b.hi .LBB5_4
+; CHECK-NEXT:  // %bb.3: // %pred.store.if
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    stur w10, [x8, #-4]
+; CHECK-NEXT:  .LBB5_4: // %pred.store.continue
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    mov x11, v0.d[1]
+; CHECK-NEXT:    cmp x11, #14
+; CHECK-NEXT:    b.hi .LBB5_1
+; CHECK-NEXT:  // %bb.5: // %pred.store.if5
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    str w10, [x8]
+; CHECK-NEXT:    b .LBB5_1
+; CHECK-NEXT:  .LBB5_6: // %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %pred.store.continue6 ]
+  %vec.ind = phi <2 x i64> [ <i64 0, i64 1>, %entry ], [ %vec.ind.next, %pred.store.continue6 ]
+  %vec.cmp = icmp ult <2 x i64> %vec.ind, <i64 15, i64 15>
+  %c0 = extractelement <2 x i1> %vec.cmp, i64 0
+  br i1 %c0, label %pred.store.if, label %pred.store.continue
+
+pred.store.if:
+  %arrayidx = getelementptr inbounds i32, ptr %dest, i64 %index
+  store i32 1, ptr %arrayidx, align 4
+  br label %pred.store.continue
+
+pred.store.continue:
+  %c1 = extractelement <2 x i1> %vec.cmp, i64 1
+  br i1 %c1, label %pred.store.if5, label %pred.store.continue6
+
+pred.store.if5:
+  %indexp1 = or disjoint i64 %index, 1
+  %arrayidx2 = getelementptr inbounds i32, ptr %dest, i64 %indexp1
+  store i32 1, ptr %arrayidx2, align 4
+  br label %pred.store.continue6
+
+pred.store.continue6:
+  %index.next = add i64 %index, 2
+  %vec.ind.next = add <2 x i64> %vec.ind, <i64 2, i64 2>
+  %index.cmp = icmp eq i64 %index.next, 16
+  br i1 %index.cmp, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:
+  ret void
+}
+
+
+; Negative tests
+
+define i1 @extract_icmp_v4i32_splat_rhs(<4 x i32> %a, i32 %b) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    dup v1.4s, w0
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    umov w8, v0.h[1]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %ins = insertelement <4 x i32> poison, i32 %b, i32 0
+  %splat = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer
+  %icmp = icmp ult <4 x i32> %a, %splat
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_splat_rhs_mul_use(<4 x i32> %a, ptr %p) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs_mul_use:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v1.4s, #235
+; CHECK-NEXT:    adrp x9, .LCPI7_0
+; CHECK-NEXT:    mov x8, x0
+; CHECK-NEXT:    ldr q2, [x9, :lo12:.LCPI7_0]
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v1.4h, v0.4s
+; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT:    addv s0, v0.4s
+; CHECK-NEXT:    umov w9, v1.h[1]
+; CHECK-NEXT:    fmov w10, s0
+; CHECK-NEXT:    and w0, w9, #0x1
+; CHECK-NEXT:    strb w10, [x8]
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat(i32 235)
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  store <4 x i1> %icmp, ptr %p, align 4
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_splat_rhs_unknown_idx(<4 x i32> %a, i32 %c) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs_unknown_idx:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    movi v1.4s, #127
+; CHECK-NEXT:    add x8, sp, #8
+; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-NEXT:    bfi x8, x0, #1, #2
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    str d0, [sp, #8]
+; CHECK-NEXT:    ldrh w8, [x8]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat(i32 127)
+  %ext = extractelement <4 x i1> %icmp, i32 %c
+  ret i1 %ext
+}
+
diff --git a/llvm/test/CodeGen/X86/vselect.ll b/llvm/test/CodeGen/X86/vselect.ll
index 9acd995d612c31..be6ee8f6899584 100644
--- a/llvm/test/CodeGen/X86/vselect.ll
+++ b/llvm/test/CodeGen/X86/vselect.ll
@@ -796,3 +796,29 @@ define i64 @vselect_any_extend_vector_inreg_crash(ptr %x) {
   ret i64 %4
 }
 
+; Tests the scalarizeBinOp code in DAGCombiner
+define void @scalarize_binop(<1 x i1> %a) {
+; SSE-LABEL: scalarize_binop:
+; SSE:       # %bb.0: # %bb0
+; SSE-NEXT:    .p2align 4
+; SSE-NEXT:  .LBB35_1: # %bb1
+; SSE-NEXT:    # =>This Inner Loop Header: Depth=1
+; SSE-NEXT:    jmp .LBB35_1
+;
+; AVX-LABEL: scalarize_binop:
+; AVX:       # %bb.0: # %bb0
+; AVX-NEXT:    .p2align 4
+; AVX-NEXT:  .LBB35_1: # %bb1
+; AVX-NEXT:    # =>This Inner Loop Header: Depth=1
+; AVX-NEXT:    jmp .LBB35_1
+bb0:
+  br label %bb1
+
+bb1:
+  %b = select <1 x i1> %a, <1 x i1> zeroinitializer, <1 x i1> splat (i1 true)
+  br label %bb2
+
+bb2:
+  %c...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2024

@llvm/pr-subscribers-backend-aarch64

Author: David Sherwood (david-arm)

Changes

For IR like this:

%icmp = icmp ult <4 x i32> %a, splat (i32 5)
%res = extractelement <4 x i1> %icmp, i32 1

where there is only one use of %icmp we can take a similar approach
to what we already do for binary ops such add, sub, etc. and convert
this into

%ext = extractelement <4 x i32> %a, i32 1
%res = icmp ult i32 %ext, 5

For AArch64 targets at least the scalar boolean result will almost
certainly need to be in a GPR anyway, since it will probably be
used by branches for control flow. I've tried to reuse existing code
in scalarizeExtractedBinop to also work for setcc.

NOTE: The optimisations don't apply for tests such as
extract_icmp_v4i32_splat_rhs in the file

CodeGen/AArch64/extract-vector-cmp.ll

because scalarizeExtractedBinOp only works if one of the input
operands is a constant.


Patch is 20.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118823.diff

10 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+27-16)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+17)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+4)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+1-1)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+1-1)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+1-1)
  • (modified) llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll (+24-22)
  • (added) llvm/test/CodeGen/AArch64/extract-vector-cmp.ll (+204)
  • (modified) llvm/test/CodeGen/X86/vselect.ll (+26)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 48018ac29bd089..1aab0fce4e1e53 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -22751,16 +22751,22 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
 
 /// Transform a vector binary operation into a scalar binary operation by moving
 /// the math/logic after an extract element of a vector.
-static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
-                                       const SDLoc &DL, bool LegalOperations) {
+static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
+                                       const SDLoc &DL, bool LegalTypes) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   SDValue Vec = ExtElt->getOperand(0);
   SDValue Index = ExtElt->getOperand(1);
   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
-  if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
+  unsigned Opc = Vec.getOpcode();
+  if (!IndexC || !Vec.hasOneUse() || (!TLI.isBinOp(Opc) && Opc != ISD::SETCC) ||
       Vec->getNumValues() != 1)
     return SDValue();
 
+  EVT ResVT = ExtElt->getValueType(0);
+  if (Opc == ISD::SETCC &&
+      (ResVT != Vec.getValueType().getVectorElementType() || LegalTypes))
+    return SDValue();
+
   // Targets may want to avoid this to prevent an expensive register transfer.
   if (!TLI.shouldScalarizeBinop(Vec))
     return SDValue();
@@ -22771,19 +22777,24 @@ static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
   SDValue Op0 = Vec.getOperand(0);
   SDValue Op1 = Vec.getOperand(1);
   APInt SplatVal;
-  if (isAnyConstantBuildVector(Op0, true) ||
-      ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
-      isAnyConstantBuildVector(Op1, true) ||
-      ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
-    // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
-    // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
-    EVT VT = ExtElt->getValueType(0);
-    SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
-    SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
-    return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
-  }
+  if (!isAnyConstantBuildVector(Op0, true) &&
+      !ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
+      !isAnyConstantBuildVector(Op1, true) &&
+      !ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
+    return SDValue();
 
-  return SDValue();
+  // extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
+  // extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
+  if (Opc == ISD::SETCC) {
+    EVT OpVT = Op0.getValueType().getVectorElementType();
+    Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
+    Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
+    return DAG.getSetCC(DL, ResVT, Op0, Op1,
+                        cast<CondCodeSDNode>(Vec->getOperand(2))->get());
+  }
+  Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op0, Index);
+  Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op1, Index);
+  return DAG.getNode(Opc, DL, ResVT, Op0, Op1);
 }
 
 // Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
@@ -23016,7 +23027,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
     }
   }
 
-  if (SDValue BO = scalarizeExtractedBinop(N, DAG, DL, LegalOperations))
+  if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL, LegalTypes))
     return BO;
 
   if (VecVT.isScalableVector())
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 986d69e6c7a9e0..f8021c49a138ac 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2835,6 +2835,7 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::SELECT_CC:    SplitRes_SELECT_CC(N, Lo, Hi); break;
   case ISD::UNDEF:        SplitRes_UNDEF(N, Lo, Hi); break;
   case ISD::FREEZE:       SplitRes_FREEZE(N, Lo, Hi); break;
+  case ISD::SETCC:        ExpandIntRes_SETCC(N, Lo, Hi); break;
 
   case ISD::BITCAST:            ExpandRes_BITCAST(N, Lo, Hi); break;
   case ISD::BUILD_PAIR:         ExpandRes_BUILD_PAIR(N, Lo, Hi); break;
@@ -3316,6 +3317,22 @@ static std::pair<ISD::CondCode, ISD::NodeType> getExpandedMinMaxOps(int Op) {
   }
 }
 
+void DAGTypeLegalizer::ExpandIntRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
+  SDLoc DL(N);
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  EVT NewVT = getSetCCResultType(LHS.getValueType());
+
+  // Taking the same approach as ScalarizeVecRes_SETCC
+  SDValue Res = DAG.getNode(ISD::SETCC, DL, NewVT, LHS, RHS, N->getOperand(2));
+
+  ISD::NodeType ExtendCode =
+      TargetLowering::getExtendForContent(TLI.getBooleanContents(NewVT));
+  Res = DAG.getExtOrTrunc(Res, DL, N->getValueType(0), ExtendCode);
+  SplitInteger(Res, Lo, Hi);
+}
+
 void DAGTypeLegalizer::ExpandIntRes_MINMAX(SDNode *N,
                                            SDValue &Lo, SDValue &Hi) {
   SDLoc DL(N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 1703149aca7463..571a710cc92a34 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -487,6 +487,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void ExpandIntRes_MINMAX            (SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandIntRes_CMP               (SDNode *N, SDValue &Lo, SDValue &Hi);
+  void ExpandIntRes_SETCC             (SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandIntRes_SADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_UADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index cb0b9e965277aa..d51b36f7e49946 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1348,6 +1348,10 @@ class AArch64TargetLowering : public TargetLowering {
   unsigned getMinimumJumpTableEntries() const override;
 
   bool softPromoteHalfType() const override { return true; }
+
+  bool shouldScalarizeBinop(SDValue VecOp) const override {
+    return VecOp.getOpcode() == ISD::SETCC;
+  }
 };
 
 namespace AArch64 {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c2b2daad1b8987..cfa9620c419de7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2093,7 +2093,7 @@ bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index c765d2b1ab95bc..7712570869ff6c 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -429,7 +429,7 @@ bool WebAssemblyTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c18a4ac9acb1e4..8d693ac64321ff 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -3306,7 +3306,7 @@ bool X86TargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll b/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
index 5a5dee0b53d439..4cb1d5b2fb345d 100644
--- a/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
+++ b/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
@@ -5,7 +5,7 @@
 
 declare void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8>, <vscale x 16 x ptr>, i32 immarg, <vscale x 16 x i1>)
 
-define fastcc i8 @allocno_reload_assign() {
+define fastcc i8 @allocno_reload_assign(ptr %p) {
 ; CHECK-LABEL: allocno_reload_assign:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fmov d0, xzr
@@ -14,8 +14,8 @@ define fastcc i8 @allocno_reload_assign() {
 ; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    uzp1 p0.s, p0.s, p0.s
 ; CHECK-NEXT:    uzp1 p0.h, p0.h, p0.h
-; CHECK-NEXT:    uzp1 p0.b, p0.b, p0.b
-; CHECK-NEXT:    mov z0.b, p0/z, #1 // =0x1
+; CHECK-NEXT:    uzp1 p8.b, p0.b, p0.b
+; CHECK-NEXT:    mov z0.b, p8/z, #1 // =0x1
 ; CHECK-NEXT:    fmov w8, s0
 ; CHECK-NEXT:    mov z0.b, #0 // =0x0
 ; CHECK-NEXT:    uunpklo z1.h, z0.b
@@ -30,34 +30,35 @@ define fastcc i8 @allocno_reload_assign() {
 ; CHECK-NEXT:    punpklo p1.h, p0.b
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    punpklo p2.h, p1.b
-; CHECK-NEXT:    punpkhi p3.h, p1.b
+; CHECK-NEXT:    punpkhi p4.h, p1.b
 ; CHECK-NEXT:    uunpklo z0.d, z2.s
 ; CHECK-NEXT:    uunpkhi z1.d, z2.s
-; CHECK-NEXT:    punpklo p5.h, p0.b
+; CHECK-NEXT:    punpklo p6.h, p0.b
 ; CHECK-NEXT:    uunpklo z2.d, z3.s
 ; CHECK-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEXT:    punpkhi p7.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    uunpklo z4.d, z5.s
 ; CHECK-NEXT:    uunpkhi z5.d, z5.s
 ; CHECK-NEXT:    uunpklo z6.d, z7.s
 ; CHECK-NEXT:    uunpkhi z7.d, z7.s
-; CHECK-NEXT:    punpklo p0.h, p2.b
-; CHECK-NEXT:    punpkhi p1.h, p2.b
-; CHECK-NEXT:    punpklo p2.h, p3.b
-; CHECK-NEXT:    punpkhi p3.h, p3.b
-; CHECK-NEXT:    punpklo p4.h, p5.b
-; CHECK-NEXT:    punpkhi p5.h, p5.b
-; CHECK-NEXT:    punpklo p6.h, p7.b
-; CHECK-NEXT:    punpkhi p7.h, p7.b
+; CHECK-NEXT:    punpklo p1.h, p2.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    punpklo p3.h, p4.b
+; CHECK-NEXT:    punpkhi p4.h, p4.b
+; CHECK-NEXT:    punpklo p5.h, p6.b
+; CHECK-NEXT:    punpkhi p6.h, p6.b
+; CHECK-NEXT:    punpklo p7.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:  .LBB0_1: // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    st1b { z0.d }, p0, [z16.d]
-; CHECK-NEXT:    st1b { z1.d }, p1, [z16.d]
-; CHECK-NEXT:    st1b { z2.d }, p2, [z16.d]
-; CHECK-NEXT:    st1b { z3.d }, p3, [z16.d]
-; CHECK-NEXT:    st1b { z4.d }, p4, [z16.d]
-; CHECK-NEXT:    st1b { z5.d }, p5, [z16.d]
-; CHECK-NEXT:    st1b { z6.d }, p6, [z16.d]
-; CHECK-NEXT:    st1b { z7.d }, p7, [z16.d]
+; CHECK-NEXT:    st1b { z0.d }, p1, [z16.d]
+; CHECK-NEXT:    st1b { z1.d }, p2, [z16.d]
+; CHECK-NEXT:    st1b { z2.d }, p3, [z16.d]
+; CHECK-NEXT:    st1b { z3.d }, p4, [z16.d]
+; CHECK-NEXT:    st1b { z4.d }, p5, [z16.d]
+; CHECK-NEXT:    st1b { z5.d }, p6, [z16.d]
+; CHECK-NEXT:    st1b { z6.d }, p7, [z16.d]
+; CHECK-NEXT:    st1b { z7.d }, p0, [z16.d]
+; CHECK-NEXT:    str p8, [x0]
 ; CHECK-NEXT:    b .LBB0_1
   br label %1
 
@@ -66,6 +67,7 @@ define fastcc i8 @allocno_reload_assign() {
   %constexpr1 = shufflevector <vscale x 16 x i1> %constexpr, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
   %constexpr2 = xor <vscale x 16 x i1> %constexpr1, shufflevector (<vscale x 16 x i1> insertelement (<vscale x 16 x i1> poison, i1 true, i64 0), <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer)
   call void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8> zeroinitializer, <vscale x 16 x ptr> zeroinitializer, i32 0, <vscale x 16 x i1> %constexpr2)
+  store <vscale x 16 x i1> %constexpr, ptr %p, align 16
   br label %1
 }
 
diff --git a/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll b/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
new file mode 100644
index 00000000000000..12bd2db2297d77
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
@@ -0,0 +1,204 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+
+define i1 @extract_icmp_v4i32_const_splat_rhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #5
+; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat (i32 5)
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_const_splat_lhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_splat_lhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #7
+; CHECK-NEXT:    cset w0, hi
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> splat(i32 7), %a
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_const_vec_rhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_vec_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #234
+; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, <i32 5, i32 234, i32 -1, i32 7>
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_fcmp_v4f32_const_splat_rhs(<4 x float> %a) {
+; CHECK-LABEL: extract_fcmp_v4f32_const_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov s0, v0.s[1]
+; CHECK-NEXT:    fmov s1, #4.00000000
+; CHECK-NEXT:    fcmp s0, s1
+; CHECK-NEXT:    cset w0, lt
+; CHECK-NEXT:    ret
+  %fcmp = fcmp ult <4 x float> %a, splat(float 4.0e+0)
+  %ext = extractelement <4 x i1> %fcmp, i32 1
+  ret i1 %ext
+}
+
+; Tests the code in ExpandIntRes_SETCC
+define i128 @extract_icmp_v1i128(ptr %p) {
+; CHECK-LABEL: extract_icmp_v1i128:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldp x9, x8, [x0]
+; CHECK-NEXT:    mov x1, xzr
+; CHECK-NEXT:    orr x8, x9, x8
+; CHECK-NEXT:    cmp x8, #0
+; CHECK-NEXT:    cset w0, eq
+; CHECK-NEXT:    ret
+  %load = load <1 x i128>, ptr %p, align 16
+  %cmp = icmp eq <1 x i128> %load, zeroinitializer
+  %sext = sext <1 x i1> %cmp to <1 x i128>
+  %res = extractelement <1 x i128> %sext, i32 0
+  ret i128 %res
+}
+
+define void @vector_loop_with_icmp(ptr nocapture noundef writeonly %dest) {
+; CHECK-LABEL: vector_loop_with_icmp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z0.d, #0, #1
+; CHECK-NEXT:    mov w8, #2 // =0x2
+; CHECK-NEXT:    mov w9, #16 // =0x10
+; CHECK-NEXT:    dup v1.2d, x8
+; CHECK-NEXT:    add x8, x0, #4
+; CHECK-NEXT:    mov w10, #1 // =0x1
+; CHECK-NEXT:    b .LBB5_2
+; CHECK-NEXT:  .LBB5_1: // %pred.store.continue6
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    add v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    subs x9, x9, #2
+; CHECK-NEXT:    add x8, x8, #8
+; CHECK-NEXT:    b.eq .LBB5_6
+; CHECK-NEXT:  .LBB5_2: // %vector.body
+; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    fmov x11, d0
+; CHECK-NEXT:    cmp x11, #14
+; CHECK-NEXT:    b.hi .LBB5_4
+; CHECK-NEXT:  // %bb.3: // %pred.store.if
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    stur w10, [x8, #-4]
+; CHECK-NEXT:  .LBB5_4: // %pred.store.continue
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    mov x11, v0.d[1]
+; CHECK-NEXT:    cmp x11, #14
+; CHECK-NEXT:    b.hi .LBB5_1
+; CHECK-NEXT:  // %bb.5: // %pred.store.if5
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    str w10, [x8]
+; CHECK-NEXT:    b .LBB5_1
+; CHECK-NEXT:  .LBB5_6: // %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %pred.store.continue6 ]
+  %vec.ind = phi <2 x i64> [ <i64 0, i64 1>, %entry ], [ %vec.ind.next, %pred.store.continue6 ]
+  %vec.cmp = icmp ult <2 x i64> %vec.ind, <i64 15, i64 15>
+  %c0 = extractelement <2 x i1> %vec.cmp, i64 0
+  br i1 %c0, label %pred.store.if, label %pred.store.continue
+
+pred.store.if:
+  %arrayidx = getelementptr inbounds i32, ptr %dest, i64 %index
+  store i32 1, ptr %arrayidx, align 4
+  br label %pred.store.continue
+
+pred.store.continue:
+  %c1 = extractelement <2 x i1> %vec.cmp, i64 1
+  br i1 %c1, label %pred.store.if5, label %pred.store.continue6
+
+pred.store.if5:
+  %indexp1 = or disjoint i64 %index, 1
+  %arrayidx2 = getelementptr inbounds i32, ptr %dest, i64 %indexp1
+  store i32 1, ptr %arrayidx2, align 4
+  br label %pred.store.continue6
+
+pred.store.continue6:
+  %index.next = add i64 %index, 2
+  %vec.ind.next = add <2 x i64> %vec.ind, <i64 2, i64 2>
+  %index.cmp = icmp eq i64 %index.next, 16
+  br i1 %index.cmp, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:
+  ret void
+}
+
+
+; Negative tests
+
+define i1 @extract_icmp_v4i32_splat_rhs(<4 x i32> %a, i32 %b) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    dup v1.4s, w0
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    umov w8, v0.h[1]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %ins = insertelement <4 x i32> poison, i32 %b, i32 0
+  %splat = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer
+  %icmp = icmp ult <4 x i32> %a, %splat
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_splat_rhs_mul_use(<4 x i32> %a, ptr %p) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs_mul_use:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v1.4s, #235
+; CHECK-NEXT:    adrp x9, .LCPI7_0
+; CHECK-NEXT:    mov x8, x0
+; CHECK-NEXT:    ldr q2, [x9, :lo12:.LCPI7_0]
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v1.4h, v0.4s
+; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT:    addv s0, v0.4s
+; CHECK-NEXT:    umov w9, v1.h[1]
+; CHECK-NEXT:    fmov w10, s0
+; CHECK-NEXT:    and w0, w9, #0x1
+; CHECK-NEXT:    strb w10, [x8]
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat(i32 235)
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  store <4 x i1> %icmp, ptr %p, align 4
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_splat_rhs_unknown_idx(<4 x i32> %a, i32 %c) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs_unknown_idx:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    movi v1.4s, #127
+; CHECK-NEXT:    add x8, sp, #8
+; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-NEXT:    bfi x8, x0, #1, #2
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    str d0, [sp, #8]
+; CHECK-NEXT:    ldrh w8, [x8]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat(i32 127)
+  %ext = extractelement <4 x i1> %icmp, i32 %c
+  ret i1 %ext
+}
+
diff --git a/llvm/test/CodeGen/X86/vselect.ll b/llvm/test/CodeGen/X86/vselect.ll
index 9acd995d612c31..be6ee8f6899584 100644
--- a/llvm/test/CodeGen/X86/vselect.ll
+++ b/llvm/test/CodeGen/X86/vselect.ll
@@ -796,3 +796,29 @@ define i64 @vselect_any_extend_vector_inreg_crash(ptr %x) {
   ret i64 %4
 }
 
+; Tests the scalarizeBinOp code in DAGCombiner
+define void @scalarize_binop(<1 x i1> %a) {
+; SSE-LABEL: scalarize_binop:
+; SSE:       # %bb.0: # %bb0
+; SSE-NEXT:    .p2align 4
+; SSE-NEXT:  .LBB35_1: # %bb1
+; SSE-NEXT:    # =>This Inner Loop Header: Depth=1
+; SSE-NEXT:    jmp .LBB35_1
+;
+; AVX-LABEL: scalarize_binop:
+; AVX:       # %bb.0: # %bb0
+; AVX-NEXT:    .p2align 4
+; AVX-NEXT:  .LBB35_1: # %bb1
+; AVX-NEXT:    # =>This Inner Loop Header: Depth=1
+; AVX-NEXT:    jmp .LBB35_1
+bb0:
+  br label %bb1
+
+bb1:
+  %b = select <1 x i1> %a, <1 x i1> zeroinitializer, <1 x i1> splat (i1 true)
+  br label %bb2
+
+bb2:
+  %c...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Dec 5, 2024

@llvm/pr-subscribers-llvm-selectiondag

Author: David Sherwood (david-arm)

Changes

For IR like this:

%icmp = icmp ult <4 x i32> %a, splat (i32 5)
%res = extractelement <4 x i1> %icmp, i32 1

where there is only one use of %icmp we can take a similar approach
to what we already do for binary ops such add, sub, etc. and convert
this into

%ext = extractelement <4 x i32> %a, i32 1
%res = icmp ult i32 %ext, 5

For AArch64 targets at least the scalar boolean result will almost
certainly need to be in a GPR anyway, since it will probably be
used by branches for control flow. I've tried to reuse existing code
in scalarizeExtractedBinop to also work for setcc.

NOTE: The optimisations don't apply for tests such as
extract_icmp_v4i32_splat_rhs in the file

CodeGen/AArch64/extract-vector-cmp.ll

because scalarizeExtractedBinOp only works if one of the input
operands is a constant.


Patch is 20.06 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/118823.diff

10 Files Affected:

  • (modified) llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp (+27-16)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp (+17)
  • (modified) llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h (+1)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+4)
  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+1-1)
  • (modified) llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp (+1-1)
  • (modified) llvm/lib/Target/X86/X86ISelLowering.cpp (+1-1)
  • (modified) llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll (+24-22)
  • (added) llvm/test/CodeGen/AArch64/extract-vector-cmp.ll (+204)
  • (modified) llvm/test/CodeGen/X86/vselect.ll (+26)
diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 48018ac29bd089..1aab0fce4e1e53 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -22751,16 +22751,22 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
 
 /// Transform a vector binary operation into a scalar binary operation by moving
 /// the math/logic after an extract element of a vector.
-static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
-                                       const SDLoc &DL, bool LegalOperations) {
+static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
+                                       const SDLoc &DL, bool LegalTypes) {
   const TargetLowering &TLI = DAG.getTargetLoweringInfo();
   SDValue Vec = ExtElt->getOperand(0);
   SDValue Index = ExtElt->getOperand(1);
   auto *IndexC = dyn_cast<ConstantSDNode>(Index);
-  if (!IndexC || !TLI.isBinOp(Vec.getOpcode()) || !Vec.hasOneUse() ||
+  unsigned Opc = Vec.getOpcode();
+  if (!IndexC || !Vec.hasOneUse() || (!TLI.isBinOp(Opc) && Opc != ISD::SETCC) ||
       Vec->getNumValues() != 1)
     return SDValue();
 
+  EVT ResVT = ExtElt->getValueType(0);
+  if (Opc == ISD::SETCC &&
+      (ResVT != Vec.getValueType().getVectorElementType() || LegalTypes))
+    return SDValue();
+
   // Targets may want to avoid this to prevent an expensive register transfer.
   if (!TLI.shouldScalarizeBinop(Vec))
     return SDValue();
@@ -22771,19 +22777,24 @@ static SDValue scalarizeExtractedBinop(SDNode *ExtElt, SelectionDAG &DAG,
   SDValue Op0 = Vec.getOperand(0);
   SDValue Op1 = Vec.getOperand(1);
   APInt SplatVal;
-  if (isAnyConstantBuildVector(Op0, true) ||
-      ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
-      isAnyConstantBuildVector(Op1, true) ||
-      ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
-    // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
-    // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
-    EVT VT = ExtElt->getValueType(0);
-    SDValue Ext0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
-    SDValue Ext1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
-    return DAG.getNode(Vec.getOpcode(), DL, VT, Ext0, Ext1);
-  }
+  if (!isAnyConstantBuildVector(Op0, true) &&
+      !ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
+      !isAnyConstantBuildVector(Op1, true) &&
+      !ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
+    return SDValue();
 
-  return SDValue();
+  // extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
+  // extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
+  if (Opc == ISD::SETCC) {
+    EVT OpVT = Op0.getValueType().getVectorElementType();
+    Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
+    Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
+    return DAG.getSetCC(DL, ResVT, Op0, Op1,
+                        cast<CondCodeSDNode>(Vec->getOperand(2))->get());
+  }
+  Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op0, Index);
+  Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, ResVT, Op1, Index);
+  return DAG.getNode(Opc, DL, ResVT, Op0, Op1);
 }
 
 // Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
@@ -23016,7 +23027,7 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
     }
   }
 
-  if (SDValue BO = scalarizeExtractedBinop(N, DAG, DL, LegalOperations))
+  if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL, LegalTypes))
     return BO;
 
   if (VecVT.isScalableVector())
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 986d69e6c7a9e0..f8021c49a138ac 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2835,6 +2835,7 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::SELECT_CC:    SplitRes_SELECT_CC(N, Lo, Hi); break;
   case ISD::UNDEF:        SplitRes_UNDEF(N, Lo, Hi); break;
   case ISD::FREEZE:       SplitRes_FREEZE(N, Lo, Hi); break;
+  case ISD::SETCC:        ExpandIntRes_SETCC(N, Lo, Hi); break;
 
   case ISD::BITCAST:            ExpandRes_BITCAST(N, Lo, Hi); break;
   case ISD::BUILD_PAIR:         ExpandRes_BUILD_PAIR(N, Lo, Hi); break;
@@ -3316,6 +3317,22 @@ static std::pair<ISD::CondCode, ISD::NodeType> getExpandedMinMaxOps(int Op) {
   }
 }
 
+void DAGTypeLegalizer::ExpandIntRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi) {
+  SDLoc DL(N);
+
+  SDValue LHS = N->getOperand(0);
+  SDValue RHS = N->getOperand(1);
+  EVT NewVT = getSetCCResultType(LHS.getValueType());
+
+  // Taking the same approach as ScalarizeVecRes_SETCC
+  SDValue Res = DAG.getNode(ISD::SETCC, DL, NewVT, LHS, RHS, N->getOperand(2));
+
+  ISD::NodeType ExtendCode =
+      TargetLowering::getExtendForContent(TLI.getBooleanContents(NewVT));
+  Res = DAG.getExtOrTrunc(Res, DL, N->getValueType(0), ExtendCode);
+  SplitInteger(Res, Lo, Hi);
+}
+
 void DAGTypeLegalizer::ExpandIntRes_MINMAX(SDNode *N,
                                            SDValue &Lo, SDValue &Hi) {
   SDLoc DL(N);
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 1703149aca7463..571a710cc92a34 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -487,6 +487,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   void ExpandIntRes_MINMAX            (SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandIntRes_CMP               (SDNode *N, SDValue &Lo, SDValue &Hi);
+  void ExpandIntRes_SETCC             (SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandIntRes_SADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_UADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index cb0b9e965277aa..d51b36f7e49946 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1348,6 +1348,10 @@ class AArch64TargetLowering : public TargetLowering {
   unsigned getMinimumJumpTableEntries() const override;
 
   bool softPromoteHalfType() const override { return true; }
+
+  bool shouldScalarizeBinop(SDValue VecOp) const override {
+    return VecOp.getOpcode() == ISD::SETCC;
+  }
 };
 
 namespace AArch64 {
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index c2b2daad1b8987..cfa9620c419de7 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -2093,7 +2093,7 @@ bool RISCVTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
index c765d2b1ab95bc..7712570869ff6c 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -429,7 +429,7 @@ bool WebAssemblyTargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp
index c18a4ac9acb1e4..8d693ac64321ff 100644
--- a/llvm/lib/Target/X86/X86ISelLowering.cpp
+++ b/llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -3306,7 +3306,7 @@ bool X86TargetLowering::shouldScalarizeBinop(SDValue VecOp) const {
 
   // Assume target opcodes can't be scalarized.
   // TODO - do we have any exceptions?
-  if (Opc >= ISD::BUILTIN_OP_END)
+  if (Opc >= ISD::BUILTIN_OP_END || !isBinOp(Opc))
     return false;
 
   // If the vector op is not supported, try to convert to scalar.
diff --git a/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll b/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
index 5a5dee0b53d439..4cb1d5b2fb345d 100644
--- a/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
+++ b/llvm/test/CodeGen/AArch64/dag-combine-concat-vectors.ll
@@ -5,7 +5,7 @@
 
 declare void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8>, <vscale x 16 x ptr>, i32 immarg, <vscale x 16 x i1>)
 
-define fastcc i8 @allocno_reload_assign() {
+define fastcc i8 @allocno_reload_assign(ptr %p) {
 ; CHECK-LABEL: allocno_reload_assign:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fmov d0, xzr
@@ -14,8 +14,8 @@ define fastcc i8 @allocno_reload_assign() {
 ; CHECK-NEXT:    cmpeq p0.d, p0/z, z0.d, #0
 ; CHECK-NEXT:    uzp1 p0.s, p0.s, p0.s
 ; CHECK-NEXT:    uzp1 p0.h, p0.h, p0.h
-; CHECK-NEXT:    uzp1 p0.b, p0.b, p0.b
-; CHECK-NEXT:    mov z0.b, p0/z, #1 // =0x1
+; CHECK-NEXT:    uzp1 p8.b, p0.b, p0.b
+; CHECK-NEXT:    mov z0.b, p8/z, #1 // =0x1
 ; CHECK-NEXT:    fmov w8, s0
 ; CHECK-NEXT:    mov z0.b, #0 // =0x0
 ; CHECK-NEXT:    uunpklo z1.h, z0.b
@@ -30,34 +30,35 @@ define fastcc i8 @allocno_reload_assign() {
 ; CHECK-NEXT:    punpklo p1.h, p0.b
 ; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    punpklo p2.h, p1.b
-; CHECK-NEXT:    punpkhi p3.h, p1.b
+; CHECK-NEXT:    punpkhi p4.h, p1.b
 ; CHECK-NEXT:    uunpklo z0.d, z2.s
 ; CHECK-NEXT:    uunpkhi z1.d, z2.s
-; CHECK-NEXT:    punpklo p5.h, p0.b
+; CHECK-NEXT:    punpklo p6.h, p0.b
 ; CHECK-NEXT:    uunpklo z2.d, z3.s
 ; CHECK-NEXT:    uunpkhi z3.d, z3.s
-; CHECK-NEXT:    punpkhi p7.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:    uunpklo z4.d, z5.s
 ; CHECK-NEXT:    uunpkhi z5.d, z5.s
 ; CHECK-NEXT:    uunpklo z6.d, z7.s
 ; CHECK-NEXT:    uunpkhi z7.d, z7.s
-; CHECK-NEXT:    punpklo p0.h, p2.b
-; CHECK-NEXT:    punpkhi p1.h, p2.b
-; CHECK-NEXT:    punpklo p2.h, p3.b
-; CHECK-NEXT:    punpkhi p3.h, p3.b
-; CHECK-NEXT:    punpklo p4.h, p5.b
-; CHECK-NEXT:    punpkhi p5.h, p5.b
-; CHECK-NEXT:    punpklo p6.h, p7.b
-; CHECK-NEXT:    punpkhi p7.h, p7.b
+; CHECK-NEXT:    punpklo p1.h, p2.b
+; CHECK-NEXT:    punpkhi p2.h, p2.b
+; CHECK-NEXT:    punpklo p3.h, p4.b
+; CHECK-NEXT:    punpkhi p4.h, p4.b
+; CHECK-NEXT:    punpklo p5.h, p6.b
+; CHECK-NEXT:    punpkhi p6.h, p6.b
+; CHECK-NEXT:    punpklo p7.h, p0.b
+; CHECK-NEXT:    punpkhi p0.h, p0.b
 ; CHECK-NEXT:  .LBB0_1: // =>This Inner Loop Header: Depth=1
-; CHECK-NEXT:    st1b { z0.d }, p0, [z16.d]
-; CHECK-NEXT:    st1b { z1.d }, p1, [z16.d]
-; CHECK-NEXT:    st1b { z2.d }, p2, [z16.d]
-; CHECK-NEXT:    st1b { z3.d }, p3, [z16.d]
-; CHECK-NEXT:    st1b { z4.d }, p4, [z16.d]
-; CHECK-NEXT:    st1b { z5.d }, p5, [z16.d]
-; CHECK-NEXT:    st1b { z6.d }, p6, [z16.d]
-; CHECK-NEXT:    st1b { z7.d }, p7, [z16.d]
+; CHECK-NEXT:    st1b { z0.d }, p1, [z16.d]
+; CHECK-NEXT:    st1b { z1.d }, p2, [z16.d]
+; CHECK-NEXT:    st1b { z2.d }, p3, [z16.d]
+; CHECK-NEXT:    st1b { z3.d }, p4, [z16.d]
+; CHECK-NEXT:    st1b { z4.d }, p5, [z16.d]
+; CHECK-NEXT:    st1b { z5.d }, p6, [z16.d]
+; CHECK-NEXT:    st1b { z6.d }, p7, [z16.d]
+; CHECK-NEXT:    st1b { z7.d }, p0, [z16.d]
+; CHECK-NEXT:    str p8, [x0]
 ; CHECK-NEXT:    b .LBB0_1
   br label %1
 
@@ -66,6 +67,7 @@ define fastcc i8 @allocno_reload_assign() {
   %constexpr1 = shufflevector <vscale x 16 x i1> %constexpr, <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer
   %constexpr2 = xor <vscale x 16 x i1> %constexpr1, shufflevector (<vscale x 16 x i1> insertelement (<vscale x 16 x i1> poison, i1 true, i64 0), <vscale x 16 x i1> poison, <vscale x 16 x i32> zeroinitializer)
   call void @llvm.masked.scatter.nxv16i8.nxv16p0(<vscale x 16 x i8> zeroinitializer, <vscale x 16 x ptr> zeroinitializer, i32 0, <vscale x 16 x i1> %constexpr2)
+  store <vscale x 16 x i1> %constexpr, ptr %p, align 16
   br label %1
 }
 
diff --git a/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll b/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
new file mode 100644
index 00000000000000..12bd2db2297d77
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/extract-vector-cmp.ll
@@ -0,0 +1,204 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mattr=+sve < %s | FileCheck %s
+
+target triple = "aarch64-unknown-linux-gnu"
+
+
+define i1 @extract_icmp_v4i32_const_splat_rhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #5
+; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat (i32 5)
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_const_splat_lhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_splat_lhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #7
+; CHECK-NEXT:    cset w0, hi
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> splat(i32 7), %a
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_const_vec_rhs(<4 x i32> %a) {
+; CHECK-LABEL: extract_icmp_v4i32_const_vec_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov w8, v0.s[1]
+; CHECK-NEXT:    cmp w8, #234
+; CHECK-NEXT:    cset w0, lo
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, <i32 5, i32 234, i32 -1, i32 7>
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_fcmp_v4f32_const_splat_rhs(<4 x float> %a) {
+; CHECK-LABEL: extract_fcmp_v4f32_const_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    mov s0, v0.s[1]
+; CHECK-NEXT:    fmov s1, #4.00000000
+; CHECK-NEXT:    fcmp s0, s1
+; CHECK-NEXT:    cset w0, lt
+; CHECK-NEXT:    ret
+  %fcmp = fcmp ult <4 x float> %a, splat(float 4.0e+0)
+  %ext = extractelement <4 x i1> %fcmp, i32 1
+  ret i1 %ext
+}
+
+; Tests the code in ExpandIntRes_SETCC
+define i128 @extract_icmp_v1i128(ptr %p) {
+; CHECK-LABEL: extract_icmp_v1i128:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldp x9, x8, [x0]
+; CHECK-NEXT:    mov x1, xzr
+; CHECK-NEXT:    orr x8, x9, x8
+; CHECK-NEXT:    cmp x8, #0
+; CHECK-NEXT:    cset w0, eq
+; CHECK-NEXT:    ret
+  %load = load <1 x i128>, ptr %p, align 16
+  %cmp = icmp eq <1 x i128> %load, zeroinitializer
+  %sext = sext <1 x i1> %cmp to <1 x i128>
+  %res = extractelement <1 x i128> %sext, i32 0
+  ret i128 %res
+}
+
+define void @vector_loop_with_icmp(ptr nocapture noundef writeonly %dest) {
+; CHECK-LABEL: vector_loop_with_icmp:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    index z0.d, #0, #1
+; CHECK-NEXT:    mov w8, #2 // =0x2
+; CHECK-NEXT:    mov w9, #16 // =0x10
+; CHECK-NEXT:    dup v1.2d, x8
+; CHECK-NEXT:    add x8, x0, #4
+; CHECK-NEXT:    mov w10, #1 // =0x1
+; CHECK-NEXT:    b .LBB5_2
+; CHECK-NEXT:  .LBB5_1: // %pred.store.continue6
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    add v0.2d, v0.2d, v1.2d
+; CHECK-NEXT:    subs x9, x9, #2
+; CHECK-NEXT:    add x8, x8, #8
+; CHECK-NEXT:    b.eq .LBB5_6
+; CHECK-NEXT:  .LBB5_2: // %vector.body
+; CHECK-NEXT:    // =>This Inner Loop Header: Depth=1
+; CHECK-NEXT:    fmov x11, d0
+; CHECK-NEXT:    cmp x11, #14
+; CHECK-NEXT:    b.hi .LBB5_4
+; CHECK-NEXT:  // %bb.3: // %pred.store.if
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    stur w10, [x8, #-4]
+; CHECK-NEXT:  .LBB5_4: // %pred.store.continue
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    mov x11, v0.d[1]
+; CHECK-NEXT:    cmp x11, #14
+; CHECK-NEXT:    b.hi .LBB5_1
+; CHECK-NEXT:  // %bb.5: // %pred.store.if5
+; CHECK-NEXT:    // in Loop: Header=BB5_2 Depth=1
+; CHECK-NEXT:    str w10, [x8]
+; CHECK-NEXT:    b .LBB5_1
+; CHECK-NEXT:  .LBB5_6: // %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body:
+  %index = phi i64 [ 0, %entry ], [ %index.next, %pred.store.continue6 ]
+  %vec.ind = phi <2 x i64> [ <i64 0, i64 1>, %entry ], [ %vec.ind.next, %pred.store.continue6 ]
+  %vec.cmp = icmp ult <2 x i64> %vec.ind, <i64 15, i64 15>
+  %c0 = extractelement <2 x i1> %vec.cmp, i64 0
+  br i1 %c0, label %pred.store.if, label %pred.store.continue
+
+pred.store.if:
+  %arrayidx = getelementptr inbounds i32, ptr %dest, i64 %index
+  store i32 1, ptr %arrayidx, align 4
+  br label %pred.store.continue
+
+pred.store.continue:
+  %c1 = extractelement <2 x i1> %vec.cmp, i64 1
+  br i1 %c1, label %pred.store.if5, label %pred.store.continue6
+
+pred.store.if5:
+  %indexp1 = or disjoint i64 %index, 1
+  %arrayidx2 = getelementptr inbounds i32, ptr %dest, i64 %indexp1
+  store i32 1, ptr %arrayidx2, align 4
+  br label %pred.store.continue6
+
+pred.store.continue6:
+  %index.next = add i64 %index, 2
+  %vec.ind.next = add <2 x i64> %vec.ind, <i64 2, i64 2>
+  %index.cmp = icmp eq i64 %index.next, 16
+  br i1 %index.cmp, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup:
+  ret void
+}
+
+
+; Negative tests
+
+define i1 @extract_icmp_v4i32_splat_rhs(<4 x i32> %a, i32 %b) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    dup v1.4s, w0
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    umov w8, v0.h[1]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    ret
+  %ins = insertelement <4 x i32> poison, i32 %b, i32 0
+  %splat = shufflevector <4 x i32> %ins, <4 x i32> poison, <4 x i32> zeroinitializer
+  %icmp = icmp ult <4 x i32> %a, %splat
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_splat_rhs_mul_use(<4 x i32> %a, ptr %p) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs_mul_use:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    movi v1.4s, #235
+; CHECK-NEXT:    adrp x9, .LCPI7_0
+; CHECK-NEXT:    mov x8, x0
+; CHECK-NEXT:    ldr q2, [x9, :lo12:.LCPI7_0]
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v1.4h, v0.4s
+; CHECK-NEXT:    and v0.16b, v0.16b, v2.16b
+; CHECK-NEXT:    addv s0, v0.4s
+; CHECK-NEXT:    umov w9, v1.h[1]
+; CHECK-NEXT:    fmov w10, s0
+; CHECK-NEXT:    and w0, w9, #0x1
+; CHECK-NEXT:    strb w10, [x8]
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat(i32 235)
+  %ext = extractelement <4 x i1> %icmp, i32 1
+  store <4 x i1> %icmp, ptr %p, align 4
+  ret i1 %ext
+}
+
+define i1 @extract_icmp_v4i32_splat_rhs_unknown_idx(<4 x i32> %a, i32 %c) {
+; CHECK-LABEL: extract_icmp_v4i32_splat_rhs_unknown_idx:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    sub sp, sp, #16
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    movi v1.4s, #127
+; CHECK-NEXT:    add x8, sp, #8
+; CHECK-NEXT:    // kill: def $w0 killed $w0 def $x0
+; CHECK-NEXT:    bfi x8, x0, #1, #2
+; CHECK-NEXT:    cmhi v0.4s, v1.4s, v0.4s
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    str d0, [sp, #8]
+; CHECK-NEXT:    ldrh w8, [x8]
+; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    add sp, sp, #16
+; CHECK-NEXT:    ret
+  %icmp = icmp ult <4 x i32> %a, splat(i32 127)
+  %ext = extractelement <4 x i1> %icmp, i32 %c
+  ret i1 %ext
+}
+
diff --git a/llvm/test/CodeGen/X86/vselect.ll b/llvm/test/CodeGen/X86/vselect.ll
index 9acd995d612c31..be6ee8f6899584 100644
--- a/llvm/test/CodeGen/X86/vselect.ll
+++ b/llvm/test/CodeGen/X86/vselect.ll
@@ -796,3 +796,29 @@ define i64 @vselect_any_extend_vector_inreg_crash(ptr %x) {
   ret i64 %4
 }
 
+; Tests the scalarizeBinOp code in DAGCombiner
+define void @scalarize_binop(<1 x i1> %a) {
+; SSE-LABEL: scalarize_binop:
+; SSE:       # %bb.0: # %bb0
+; SSE-NEXT:    .p2align 4
+; SSE-NEXT:  .LBB35_1: # %bb1
+; SSE-NEXT:    # =>This Inner Loop Header: Depth=1
+; SSE-NEXT:    jmp .LBB35_1
+;
+; AVX-LABEL: scalarize_binop:
+; AVX:       # %bb.0: # %bb0
+; AVX-NEXT:    .p2align 4
+; AVX-NEXT:  .LBB35_1: # %bb1
+; AVX-NEXT:    # =>This Inner Loop Header: Depth=1
+; AVX-NEXT:    jmp .LBB35_1
+bb0:
+  br label %bb1
+
+bb1:
+  %b = select <1 x i1> %a, <1 x i1> zeroinitializer, <1 x i1> splat (i1 true)
+  br label %bb2
+
+bb2:
+  %c...
[truncated]

@david-arm
Copy link
Contributor Author

Yet another attempt to reland #117566. The last post-commit failure exposed an existing issue where we were missing a ExpandIntRes_SETCC to handle i128 result types. I've added a new test for this, but it can only be exposed once the new DAG combine kicks in.

Copy link

github-actions bot commented Dec 5, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff f7685af4a5bd188e6d548967d818d8569f10a70d b40acf86cc203c8dbd3129fe2697a83ed5666740 --extensions h,cpp -- llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h llvm/lib/Target/AArch64/AArch64ISelLowering.h llvm/lib/Target/RISCV/RISCVISelLowering.cpp llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp llvm/lib/Target/X86/X86ISelLowering.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
index 69f25ebc88..787225f8ed 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
@@ -2835,7 +2835,9 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
   case ISD::SELECT_CC:    SplitRes_SELECT_CC(N, Lo, Hi); break;
   case ISD::UNDEF:        SplitRes_UNDEF(N, Lo, Hi); break;
   case ISD::FREEZE:       SplitRes_FREEZE(N, Lo, Hi); break;
-  case ISD::SETCC:        ExpandIntRes_SETCC(N, Lo, Hi); break;
+  case ISD::SETCC:
+    ExpandIntRes_SETCC(N, Lo, Hi);
+    break;
 
   case ISD::BITCAST:            ExpandRes_BITCAST(N, Lo, Hi); break;
   case ISD::BUILD_PAIR:         ExpandRes_BUILD_PAIR(N, Lo, Hi); break;
diff --git a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
index 571a710cc9..c672b78b8e 100644
--- a/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
+++ b/llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
@@ -487,7 +487,7 @@ private:
   void ExpandIntRes_MINMAX            (SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandIntRes_CMP               (SDNode *N, SDValue &Lo, SDValue &Hi);
-  void ExpandIntRes_SETCC             (SDNode *N, SDValue &Lo, SDValue &Hi);
+  void ExpandIntRes_SETCC(SDNode *N, SDValue &Lo, SDValue &Hi);
 
   void ExpandIntRes_SADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);
   void ExpandIntRes_UADDSUBO          (SDNode *N, SDValue &Lo, SDValue &Hi);

Copy link
Contributor

@arsenm arsenm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought this already landed?

@paulwalker-arm
Copy link
Collaborator

I thought this already landed?

The patch was reverted because it exposed a bug within legalisation (expansion) of i128 based setcc operations.

@arsenm
Copy link
Contributor

arsenm commented Dec 5, 2024

I thought this already landed?

The patch was reverted because it exposed a bug within legalisation (expansion) of i128 based setcc operations.

Should have the proper git revert with the reapply in the message

@david-arm david-arm changed the title [DAGCombiner] Add support for scalarising extracts of a vector setcc Reapply "[DAGCombiner] Add support for scalarising extracts of a vector setcc" Dec 6, 2024
@david-arm david-arm changed the title Reapply "[DAGCombiner] Add support for scalarising extracts of a vector setcc" Reapply "[DAGCombiner] Add support for scalarising extracts of a vector setcc (#117566)" Dec 6, 2024
Copy link
Collaborator

@paulwalker-arm paulwalker-arm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hopefully third time's the charm.

@david-arm david-arm merged commit 8630a7b into llvm:main Dec 9, 2024
7 of 8 checks passed
david-arm added a commit to david-arm/llvm-project that referenced this pull request Jan 15, 2025
PR llvm#118823 added a DAG combine for extracting elements of a
vector returned from SETCC, however it doesn't correctly deal
with the case where the vector element type is not i1. In
this case we have to take account of the boolean contents,
which are represent differently between vectors and scalars.
For now, I've just restricted the optimisation to i1 types.

Fixes llvm#121372
david-arm added a commit to david-arm/llvm-project that referenced this pull request Jan 16, 2025
PR llvm#118823 added a DAG combine for extracting elements of a
vector returned from SETCC, however it doesn't correctly deal
with the case where the vector element type is not i1. In
this case we have to take account of the boolean contents,
which are represent differently between vectors and scalars.
The code now explicitly performs an inreg sign extend in
order to get the same result.

Fixes llvm#121372
david-arm added a commit that referenced this pull request Jan 21, 2025
PR #118823 added a
DAG combine for extracting elements of a vector returned from
SETCC, however it doesn't correctly deal with the case where
the vector element type is not i1. In this case we have to
take account of the boolean contents, which are represented
differently between vectors and scalars. The code now
explicitly performs an inreg sign extend in order to get the
same result.

Fixes #121372
github-actions bot pushed a commit to arm/arm-toolchain that referenced this pull request Jan 21, 2025
…ases (#123071)

PR llvm/llvm-project#118823 added a
DAG combine for extracting elements of a vector returned from
SETCC, however it doesn't correctly deal with the case where
the vector element type is not i1. In this case we have to
take account of the boolean contents, which are represented
differently between vectors and scalars. The code now
explicitly performs an inreg sign extend in order to get the
same result.

Fixes llvm/llvm-project#121372
@david-arm david-arm deleted the extract_vcmp branch January 28, 2025 11:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants