Skip to content

[AArch64] Improve code generation of bool vector reduce operations #115713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 10, 2024

Conversation

Il-Capitano
Copy link
Contributor

  • Avoid unnecessary truncation of comparison results in vecreduce_xor
  • Optimize generated code for vecreduce_and and vecreduce_or by comparing against 0.0 to check if all/any of the values are set

Alive2 proof of vecreduce_and and vecreduce_or transformation: https://alive2.llvm.org/ce/z/SRfPtw

@llvmbot
Copy link
Member

llvmbot commented Nov 11, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Csanád Hajdú (Il-Capitano)

Changes
  • Avoid unnecessary truncation of comparison results in vecreduce_xor
  • Optimize generated code for vecreduce_and and vecreduce_or by comparing against 0.0 to check if all/any of the values are set

Alive2 proof of vecreduce_and and vecreduce_or transformation: https://alive2.llvm.org/ce/z/SRfPtw


Patch is 39.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/115713.diff

8 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+46-5)
  • (modified) llvm/test/CodeGen/AArch64/dag-combine-setcc.ll (+10-16)
  • (modified) llvm/test/CodeGen/AArch64/illegal-floating-point-vector-compares.ll (+2-4)
  • (modified) llvm/test/CodeGen/AArch64/reduce-and.ll (+9-9)
  • (modified) llvm/test/CodeGen/AArch64/reduce-or.ll (+6-9)
  • (modified) llvm/test/CodeGen/AArch64/vecreduce-and-legalization.ll (+3-3)
  • (modified) llvm/test/CodeGen/AArch64/vecreduce-bool.ll (+671-48)
  • (modified) llvm/test/CodeGen/AArch64/vecreduce-umax-legalization.ll (+2-3)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 0814380b188485..0f1995273225a1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -15820,11 +15820,26 @@ static SDValue getVectorBitwiseReduce(unsigned Opcode, SDValue Vec, EVT VT,
       return getVectorBitwiseReduce(Opcode, HalfVec, VT, DL, DAG);
     }
 
-    // Vectors that are less than 64 bits get widened to neatly fit a 64 bit
-    // register, so e.g. <4 x i1> gets lowered to <4 x i16>. Sign extending to
+    // Results of setcc operations get widened to 128 bits for xor reduce if
+    // their input operands are 128 bits wide, otherwise vectors that are less
+    // than 64 bits get widened to neatly fit a 64 bit register, so e.g.
+    // <4 x i1> gets lowered to either <4 x i16> or <4 x i32>. Sign extending to
     // this element size leads to the best codegen, since e.g. setcc results
     // might need to be truncated otherwise.
-    EVT ExtendedVT = MVT::getIntegerVT(std::max(64u / NumElems, 8u));
+    unsigned ExtendedWidth = 64;
+    if (ScalarOpcode == ISD::XOR && Vec.getOpcode() == ISD::SETCC &&
+        Vec.getOperand(0).getValueSizeInBits() >= 128) {
+      ExtendedWidth = 128;
+    }
+    EVT ExtendedVT = MVT::getIntegerVT(std::max(ExtendedWidth / NumElems, 8u));
+
+    // Negate the reduced vector value for reduce and operations that use
+    // fcmp.
+    if (ScalarOpcode == ISD::AND && NumElems < 16) {
+      Vec = DAG.getNode(
+          ISD::XOR, DL, VecVT, Vec,
+          DAG.getSplatVector(VecVT, DL, DAG.getConstant(-1, DL, MVT::i32)));
+    }
 
     // any_ext doesn't work with umin/umax, so only use it for uadd.
     unsigned ExtendOp =
@@ -15833,10 +15848,36 @@ static SDValue getVectorBitwiseReduce(unsigned Opcode, SDValue Vec, EVT VT,
         ExtendOp, DL, VecVT.changeVectorElementType(ExtendedVT), Vec);
     switch (ScalarOpcode) {
     case ISD::AND:
-      Result = DAG.getNode(ISD::VECREDUCE_UMIN, DL, ExtendedVT, Extended);
+      if (NumElems < 16) {
+        // Check if all lanes of the negated bool vector value are zero by
+        // comparing against 0.0 with ordered and equal predicate. The only
+        // non-zero bit pattern that compares ordered and equal to 0.0 is -0.0,
+        // where only the sign bit is set. However the bool vector is
+        // sign-extended so that each bit in a lane is either zero or one,
+        // meaning that it is impossible to get the bit pattern of -0.0.
+        assert(Extended.getValueSizeInBits() == 64);
+        Extended = DAG.getBitcast(MVT::f64, Extended);
+        Result = DAG.getNode(ISD::SETCC, DL, MVT::i32, Extended,
+                             DAG.getConstantFP(0.0, DL, MVT::f64),
+                             DAG.getCondCode(ISD::CondCode::SETOEQ));
+      } else {
+        Result = DAG.getNode(ISD::VECREDUCE_UMIN, DL, ExtendedVT, Extended);
+      }
       break;
     case ISD::OR:
-      Result = DAG.getNode(ISD::VECREDUCE_UMAX, DL, ExtendedVT, Extended);
+      if (NumElems < 16) {
+        // Check if any lane of the bool vector is set by comparing against 0.0.
+        // NaN bit patterns are handled by using the 'unordered or not equal'
+        // predicate. Similarly to the reduce and case, -0.0 doesn't have to be
+        // handled here (see explanation above).
+        assert(Extended.getValueSizeInBits() == 64);
+        Extended = DAG.getBitcast(MVT::f64, Extended);
+        Result = DAG.getNode(ISD::SETCC, DL, MVT::i32, Extended,
+                             DAG.getConstantFP(0.0, DL, MVT::f64),
+                             DAG.getCondCode(ISD::CondCode::SETUNE));
+      } else {
+        Result = DAG.getNode(ISD::VECREDUCE_UMAX, DL, ExtendedVT, Extended);
+      }
       break;
     case ISD::XOR:
       Result = DAG.getNode(ISD::VECREDUCE_ADD, DL, ExtendedVT, Extended);
diff --git a/llvm/test/CodeGen/AArch64/dag-combine-setcc.ll b/llvm/test/CodeGen/AArch64/dag-combine-setcc.ll
index a48a4e0e723ebc..fb366564723db6 100644
--- a/llvm/test/CodeGen/AArch64/dag-combine-setcc.ll
+++ b/llvm/test/CodeGen/AArch64/dag-combine-setcc.ll
@@ -5,10 +5,8 @@ define i1 @combine_setcc_eq_vecreduce_or_v8i1(<8 x i8> %a) {
 ; CHECK-LABEL: combine_setcc_eq_vecreduce_or_v8i1:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    cmeq v0.8b, v0.8b, #0
-; CHECK-NEXT:    mov w8, #1 // =0x1
-; CHECK-NEXT:    umaxv b0, v0.8b
-; CHECK-NEXT:    fmov w9, s0
-; CHECK-NEXT:    bic w0, w8, w9
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
   %cmp1 = icmp eq <8 x i8> %a, zeroinitializer
   %cast = bitcast <8 x i1> %cmp1 to i8
@@ -73,9 +71,8 @@ define i1 @combine_setcc_ne_vecreduce_or_v8i1(<8 x i8> %a) {
 ; CHECK-LABEL: combine_setcc_ne_vecreduce_or_v8i1:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    cmtst v0.8b, v0.8b, v0.8b
-; CHECK-NEXT:    umaxv b0, v0.8b
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    cset w0, ne
 ; CHECK-NEXT:    ret
   %cmp1 = icmp ne <8 x i8> %a, zeroinitializer
   %cast = bitcast <8 x i1> %cmp1 to i8
@@ -132,10 +129,9 @@ define i1 @combine_setcc_ne_vecreduce_or_v64i1(<64 x i8> %a) {
 define i1 @combine_setcc_eq_vecreduce_and_v8i1(<8 x i8> %a) {
 ; CHECK-LABEL: combine_setcc_eq_vecreduce_and_v8i1:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cmeq v0.8b, v0.8b, #0
-; CHECK-NEXT:    uminv b0, v0.8b
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    cmtst v0.8b, v0.8b, v0.8b
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
   %cmp1 = icmp eq <8 x i8> %a, zeroinitializer
   %cast = bitcast <8 x i1> %cmp1 to i8
@@ -192,11 +188,9 @@ define i1 @combine_setcc_eq_vecreduce_and_v64i1(<64 x i8> %a) {
 define i1 @combine_setcc_ne_vecreduce_and_v8i1(<8 x i8> %a) {
 ; CHECK-LABEL: combine_setcc_ne_vecreduce_and_v8i1:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cmtst v0.8b, v0.8b, v0.8b
-; CHECK-NEXT:    mov w8, #1 // =0x1
-; CHECK-NEXT:    uminv b0, v0.8b
-; CHECK-NEXT:    fmov w9, s0
-; CHECK-NEXT:    bic w0, w8, w9
+; CHECK-NEXT:    cmeq v0.8b, v0.8b, #0
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    cset w0, ne
 ; CHECK-NEXT:    ret
   %cmp1 = icmp ne <8 x i8> %a, zeroinitializer
   %cast = bitcast <8 x i1> %cmp1 to i8
diff --git a/llvm/test/CodeGen/AArch64/illegal-floating-point-vector-compares.ll b/llvm/test/CodeGen/AArch64/illegal-floating-point-vector-compares.ll
index 767ca91a58bb10..5374d4823034ff 100644
--- a/llvm/test/CodeGen/AArch64/illegal-floating-point-vector-compares.ll
+++ b/llvm/test/CodeGen/AArch64/illegal-floating-point-vector-compares.ll
@@ -9,13 +9,11 @@ define i1 @unordered_floating_point_compare_on_v8f32(<8 x float> %a_vec) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    fcmgt v1.4s, v1.4s, #0.0
 ; CHECK-NEXT:    fcmgt v0.4s, v0.4s, #0.0
-; CHECK-NEXT:    mov w8, #1 // =0x1
 ; CHECK-NEXT:    uzp1 v0.8h, v0.8h, v1.8h
 ; CHECK-NEXT:    mvn v0.16b, v0.16b
 ; CHECK-NEXT:    xtn v0.8b, v0.8h
-; CHECK-NEXT:    umaxv b0, v0.8b
-; CHECK-NEXT:    fmov w9, s0
-; CHECK-NEXT:    bic w0, w8, w9
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
   %a_cmp = fcmp ule <8 x float> %a_vec, zeroinitializer
   %cmp_result = bitcast <8 x i1> %a_cmp to i8
diff --git a/llvm/test/CodeGen/AArch64/reduce-and.ll b/llvm/test/CodeGen/AArch64/reduce-and.ll
index 8ca521327c2e31..62f3e8d184d24d 100644
--- a/llvm/test/CodeGen/AArch64/reduce-and.ll
+++ b/llvm/test/CodeGen/AArch64/reduce-and.ll
@@ -20,11 +20,11 @@ define i1 @test_redand_v1i1(<1 x i1> %a) {
 define i1 @test_redand_v2i1(<2 x i1> %a) {
 ; CHECK-LABEL: test_redand_v2i1:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v0.8b, v0.8b
 ; CHECK-NEXT:    shl v0.2s, v0.2s, #31
 ; CHECK-NEXT:    cmlt v0.2s, v0.2s, #0
-; CHECK-NEXT:    uminp v0.2s, v0.2s, v0.2s
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
 ;
 ; GISEL-LABEL: test_redand_v2i1:
@@ -42,11 +42,11 @@ define i1 @test_redand_v2i1(<2 x i1> %a) {
 define i1 @test_redand_v4i1(<4 x i1> %a) {
 ; CHECK-LABEL: test_redand_v4i1:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v0.8b, v0.8b
 ; CHECK-NEXT:    shl v0.4h, v0.4h, #15
 ; CHECK-NEXT:    cmlt v0.4h, v0.4h, #0
-; CHECK-NEXT:    uminv h0, v0.4h
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
 ;
 ; GISEL-LABEL: test_redand_v4i1:
@@ -68,11 +68,11 @@ define i1 @test_redand_v4i1(<4 x i1> %a) {
 define i1 @test_redand_v8i1(<8 x i1> %a) {
 ; CHECK-LABEL: test_redand_v8i1:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v0.8b, v0.8b
 ; CHECK-NEXT:    shl v0.8b, v0.8b, #7
 ; CHECK-NEXT:    cmlt v0.8b, v0.8b, #0
-; CHECK-NEXT:    uminv b0, v0.8b
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
 ;
 ; GISEL-LABEL: test_redand_v8i1:
diff --git a/llvm/test/CodeGen/AArch64/reduce-or.ll b/llvm/test/CodeGen/AArch64/reduce-or.ll
index aac31ce8b71b75..485cb7c916140c 100644
--- a/llvm/test/CodeGen/AArch64/reduce-or.ll
+++ b/llvm/test/CodeGen/AArch64/reduce-or.ll
@@ -22,9 +22,8 @@ define i1 @test_redor_v2i1(<2 x i1> %a) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    shl v0.2s, v0.2s, #31
 ; CHECK-NEXT:    cmlt v0.2s, v0.2s, #0
-; CHECK-NEXT:    umaxp v0.2s, v0.2s, v0.2s
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    cset w0, ne
 ; CHECK-NEXT:    ret
 ;
 ; GISEL-LABEL: test_redor_v2i1:
@@ -44,9 +43,8 @@ define i1 @test_redor_v4i1(<4 x i1> %a) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    shl v0.4h, v0.4h, #15
 ; CHECK-NEXT:    cmlt v0.4h, v0.4h, #0
-; CHECK-NEXT:    umaxv h0, v0.4h
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    cset w0, ne
 ; CHECK-NEXT:    ret
 ;
 ; GISEL-LABEL: test_redor_v4i1:
@@ -70,9 +68,8 @@ define i1 @test_redor_v8i1(<8 x i1> %a) {
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    shl v0.8b, v0.8b, #7
 ; CHECK-NEXT:    cmlt v0.8b, v0.8b, #0
-; CHECK-NEXT:    umaxv b0, v0.8b
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    cset w0, ne
 ; CHECK-NEXT:    ret
 ;
 ; GISEL-LABEL: test_redor_v8i1:
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-and-legalization.ll b/llvm/test/CodeGen/AArch64/vecreduce-and-legalization.ll
index 7fa416e0dbcd5c..fd81deeb7d913b 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-and-legalization.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-and-legalization.ll
@@ -139,11 +139,11 @@ define i32 @test_v3i32(<3 x i32> %a) nounwind {
 define i1 @test_v4i1(<4 x i1> %a) nounwind {
 ; CHECK-LABEL: test_v4i1:
 ; CHECK:       // %bb.0:
+; CHECK-NEXT:    mvn v0.8b, v0.8b
 ; CHECK-NEXT:    shl v0.4h, v0.4h, #15
 ; CHECK-NEXT:    cmlt v0.4h, v0.4h, #0
-; CHECK-NEXT:    uminv h0, v0.4h
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    and w0, w8, #0x1
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    cset w0, eq
 ; CHECK-NEXT:    ret
   %b = call i1 @llvm.vector.reduce.and.v4i1(<4 x i1> %a)
   ret i1 %b
diff --git a/llvm/test/CodeGen/AArch64/vecreduce-bool.ll b/llvm/test/CodeGen/AArch64/vecreduce-bool.ll
index 58020d28702b2f..10a3ef1658a965 100644
--- a/llvm/test/CodeGen/AArch64/vecreduce-bool.ll
+++ b/llvm/test/CodeGen/AArch64/vecreduce-bool.ll
@@ -15,8 +15,15 @@ declare i1 @llvm.vector.reduce.or.v8i1(<8 x i1> %a)
 declare i1 @llvm.vector.reduce.or.v16i1(<16 x i1> %a)
 declare i1 @llvm.vector.reduce.or.v32i1(<32 x i1> %a)
 
-define i32 @reduce_and_v1(<1 x i8> %a0, i32 %a1, i32 %a2) nounwind {
-; CHECK-LABEL: reduce_and_v1:
+declare i1 @llvm.vector.reduce.xor.v1i1(<1 x i1> %a)
+declare i1 @llvm.vector.reduce.xor.v2i1(<2 x i1> %a)
+declare i1 @llvm.vector.reduce.xor.v4i1(<4 x i1> %a)
+declare i1 @llvm.vector.reduce.xor.v8i1(<8 x i1> %a)
+declare i1 @llvm.vector.reduce.xor.v16i1(<16 x i1> %a)
+declare i1 @llvm.vector.reduce.xor.v32i1(<32 x i1> %a)
+
+define i32 @reduce_and_v1i8(<1 x i8> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v1i8:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
 ; CHECK-NEXT:    smov w8, v0.b[0]
@@ -29,16 +36,14 @@ define i32 @reduce_and_v1(<1 x i8> %a0, i32 %a1, i32 %a2) nounwind {
   ret i32 %z
 }
 
-define i32 @reduce_and_v2(<2 x i8> %a0, i32 %a1, i32 %a2) nounwind {
-; CHECK-LABEL: reduce_and_v2:
+define i32 @reduce_and_v2i8(<2 x i8> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v2i8:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    shl v0.2s, v0.2s, #24
 ; CHECK-NEXT:    sshr v0.2s, v0.2s, #24
-; CHECK-NEXT:    cmlt v0.2s, v0.2s, #0
-; CHECK-NEXT:    uminp v0.2s, v0.2s, v0.2s
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    tst w8, #0x1
-; CHECK-NEXT:    csel w0, w0, w1, ne
+; CHECK-NEXT:    cmge v0.2s, v0.2s, #0
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    csel w0, w0, w1, eq
 ; CHECK-NEXT:    ret
   %x = icmp slt <2 x i8> %a0, zeroinitializer
   %y = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %x)
@@ -46,16 +51,14 @@ define i32 @reduce_and_v2(<2 x i8> %a0, i32 %a1, i32 %a2) nounwind {
   ret i32 %z
 }
 
-define i32 @reduce_and_v4(<4 x i8> %a0, i32 %a1, i32 %a2) nounwind {
-; CHECK-LABEL: reduce_and_v4:
+define i32 @reduce_and_v4i8(<4 x i8> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v4i8:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    shl v0.4h, v0.4h, #8
 ; CHECK-NEXT:    sshr v0.4h, v0.4h, #8
-; CHECK-NEXT:    cmlt v0.4h, v0.4h, #0
-; CHECK-NEXT:    uminv h0, v0.4h
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    tst w8, #0x1
-; CHECK-NEXT:    csel w0, w0, w1, ne
+; CHECK-NEXT:    cmge v0.4h, v0.4h, #0
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    csel w0, w0, w1, eq
 ; CHECK-NEXT:    ret
   %x = icmp slt <4 x i8> %a0, zeroinitializer
   %y = call i1 @llvm.vector.reduce.and.v4i1(<4 x i1> %x)
@@ -63,14 +66,12 @@ define i32 @reduce_and_v4(<4 x i8> %a0, i32 %a1, i32 %a2) nounwind {
   ret i32 %z
 }
 
-define i32 @reduce_and_v8(<8 x i8> %a0, i32 %a1, i32 %a2) nounwind {
-; CHECK-LABEL: reduce_and_v8:
+define i32 @reduce_and_v8i8(<8 x i8> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v8i8:
 ; CHECK:       // %bb.0:
-; CHECK-NEXT:    cmlt v0.8b, v0.8b, #0
-; CHECK-NEXT:    uminv b0, v0.8b
-; CHECK-NEXT:    fmov w8, s0
-; CHECK-NEXT:    tst w8, #0x1
-; CHECK-NEXT:    csel w0, w0, w1, ne
+; CHECK-NEXT:    cmge v0.8b, v0.8b, #0
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    csel w0, w0, w1, eq
 ; CHECK-NEXT:    ret
   %x = icmp slt <8 x i8> %a0, zeroinitializer
   %y = call i1 @llvm.vector.reduce.and.v8i1(<8 x i1> %x)
@@ -78,8 +79,8 @@ define i32 @reduce_and_v8(<8 x i8> %a0, i32 %a1, i32 %a2) nounwind {
   ret i32 %z
 }
 
-define i32 @reduce_and_v16(<16 x i8> %a0, i32 %a1, i32 %a2) nounwind {
-; CHECK-LABEL: reduce_and_v16:
+define i32 @reduce_and_v16i8(<16 x i8> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v16i8:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    cmlt v0.16b, v0.16b, #0
 ; CHECK-NEXT:    uminv b0, v0.16b
@@ -93,8 +94,8 @@ define i32 @reduce_and_v16(<16 x i8> %a0, i32 %a1, i32 %a2) nounwind {
   ret i32 %z
 }
 
-define i32 @reduce_and_v32(<32 x i8> %a0, i32 %a1, i32 %a2) nounwind {
-; CHECK-LABEL: reduce_and_v32:
+define i32 @reduce_and_v32i8(<32 x i8> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v32i8:
 ; CHECK:       // %bb.0:
 ; CHECK-NEXT:    and v0.16b, v0.16b, v1.16b
 ; CHECK-NEXT:    cmlt v0.16b, v0.16b, #0
@@ -109,8 +110,182 @@ define i32 @reduce_and_v32(<32 x i8> %a0, i32 %a1, i32 %a2) nounwind {
   ret i32 %z
 }
 
-define i32 @reduce_or_v1(<1 x i8> %a0, i32 %a1, i32 %a2) nounwind {
-; CHECK-LABEL: reduce_or_v1:
+define i32 @reduce_and_v1i16(<1 x i16> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v1i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    smov w8, v0.h[0]
+; CHECK-NEXT:    cmp w8, #0
+; CHECK-NEXT:    csel w0, w0, w1, lt
+; CHECK-NEXT:    ret
+  %x = icmp slt <1 x i16> %a0, zeroinitializer
+  %y = call i1 @llvm.vector.reduce.and.v1i1(<1 x i1> %x)
+  %z = select i1 %y, i32 %a1, i32 %a2
+  ret i32 %z
+}
+
+define i32 @reduce_and_v2i16(<2 x i16> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v2i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    shl v0.2s, v0.2s, #16
+; CHECK-NEXT:    sshr v0.2s, v0.2s, #16
+; CHECK-NEXT:    cmge v0.2s, v0.2s, #0
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    csel w0, w0, w1, eq
+; CHECK-NEXT:    ret
+  %x = icmp slt <2 x i16> %a0, zeroinitializer
+  %y = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %x)
+  %z = select i1 %y, i32 %a1, i32 %a2
+  ret i32 %z
+}
+
+define i32 @reduce_and_v4i16(<4 x i16> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v4i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmge v0.4h, v0.4h, #0
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    csel w0, w0, w1, eq
+; CHECK-NEXT:    ret
+  %x = icmp slt <4 x i16> %a0, zeroinitializer
+  %y = call i1 @llvm.vector.reduce.and.v4i1(<4 x i1> %x)
+  %z = select i1 %y, i32 %a1, i32 %a2
+  ret i32 %z
+}
+
+define i32 @reduce_and_v8i16(<8 x i16> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v8i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmge v0.8h, v0.8h, #0
+; CHECK-NEXT:    xtn v0.8b, v0.8h
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    csel w0, w0, w1, eq
+; CHECK-NEXT:    ret
+  %x = icmp slt <8 x i16> %a0, zeroinitializer
+  %y = call i1 @llvm.vector.reduce.and.v8i1(<8 x i1> %x)
+  %z = select i1 %y, i32 %a1, i32 %a2
+  ret i32 %z
+}
+
+define i32 @reduce_and_v16i16(<16 x i16> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v16i16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmlt v1.8h, v1.8h, #0
+; CHECK-NEXT:    cmlt v0.8h, v0.8h, #0
+; CHECK-NEXT:    uzp1 v0.16b, v0.16b, v1.16b
+; CHECK-NEXT:    uminv b0, v0.16b
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    tst w8, #0x1
+; CHECK-NEXT:    csel w0, w0, w1, ne
+; CHECK-NEXT:    ret
+  %x = icmp slt <16 x i16> %a0, zeroinitializer
+  %y = call i1 @llvm.vector.reduce.and.v16i1(<16 x i1> %x)
+  %z = select i1 %y, i32 %a1, i32 %a2
+  ret i32 %z
+}
+
+define i32 @reduce_and_v1i32(<1 x i32> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v1i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    fmov w8, s0
+; CHECK-NEXT:    cmp w8, #0
+; CHECK-NEXT:    csel w0, w0, w1, lt
+; CHECK-NEXT:    ret
+  %x = icmp slt <1 x i32> %a0, zeroinitializer
+  %y = call i1 @llvm.vector.reduce.and.v1i1(<1 x i1> %x)
+  %z = select i1 %y, i32 %a1, i32 %a2
+  ret i32 %z
+}
+
+define i32 @reduce_and_v2i32(<2 x i32> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v2i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmge v0.2s, v0.2s, #0
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    csel w0, w0, w1, eq
+; CHECK-NEXT:    ret
+  %x = icmp slt <2 x i32> %a0, zeroinitializer
+  %y = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %x)
+  %z = select i1 %y, i32 %a1, i32 %a2
+  ret i32 %z
+}
+
+define i32 @reduce_and_v4i32(<4 x i32> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v4i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmge v0.4s, v0.4s, #0
+; CHECK-NEXT:    xtn v0.4h, v0.4s
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    csel w0, w0, w1, eq
+; CHECK-NEXT:    ret
+  %x = icmp slt <4 x i32> %a0, zeroinitializer
+  %y = call i1 @llvm.vector.reduce.and.v4i1(<4 x i1> %x)
+  %z = select i1 %y, i32 %a1, i32 %a2
+  ret i32 %z
+}
+
+define i32 @reduce_and_v8i32(<8 x i32> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v8i32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    cmge v1.4s, v1.4s, #0
+; CHECK-NEXT:    cmge v0.4s, v0.4s, #0
+; CHECK-NEXT:    uzp1 v0.8h, v0.8h, v1.8h
+; CHECK-NEXT:    xtn v0.8b, v0.8h
+; CHECK-NEXT:    fcmp d0, #0.0
+; CHECK-NEXT:    csel w0, w0, w1, eq
+; CHECK-NEXT:    ret
+  %x = icmp slt <8 x i32> %a0, zeroinitializer
+  %y = call i1 @llvm.vector.reduce.and.v8i1(<8 x i1> %x)
+  %z = select i1 %y, i32 %a1, i32 %a2
+  ret i32 %z
+}
+
+define i32 @reduce_and_v1i64(<1 x i64> %a0, i32 %a1, i32 %a2) nounwind {
+; CHECK-LABEL: reduce_and_v1i64:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    fmov x8...
[truncated]

Copy link
Contributor

@lawben lawben left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGMT. Using the floating point compare seems like a neat idea. But you should probably wait for the approval of another reviewer, as I'm not sure if I'm missing an edge case or not.

Copy link
Collaborator

@davemgreen davemgreen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation. I hadn't realized it was an i1 not.

LGTM, thanks.

; CHECK-NEXT: uminp v0.2s, v0.2s, v0.2s
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: and w0, w8, #0x1
; CHECK-NEXT: fcmp d0, #0.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the lowering is quite clever here, but is there now an issue with serialisation if you have multiple reductions? Suppose your IR looks like this:

  %or_result1 = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %a)
  %or_result2 = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %b)
  %or_result = or i1 %or_result1, %or_result2
  ret i1 %or_result

I haven't checked the code with and without this patch, but I imagine previously we could quite happily have interleaved the instructions like this:

  uminp v0 ...
  uminp v1 ...
  fmov w8, s0
  fmov w9, s1
  and w0, w8, 0x1
  and w1, w9, 0x1
  or w0, w0, w1

whereas now due to the single CC register we have to serialise:

  fcmp d0, #0.0
  cset w0, eq
  fcmp d1, #0.0
  cset w1, eq
  or w0, w0, w1
...

However, I can see how this new version is efficient if the result is then used for control flow:

  %or_result = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %a)
  br i1 %or_result, ...

Do you have any examples showing where this patch helps improve performance?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For performance differences I originally tested a simple loop similar to this:

void test(bool *dest, float32x4_t *p, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        dest[i] = __builtin_reduce_or(p[i] < 0.0);
    }
}

My change typically resulted in a 10-15% improvement on various CPUs.
My original motivating use case was something like this:

float32x4_t x = ...;
// ...
x = f(x);
if (__builtin_reduce_or(x < 0.0)) return;
x = g(x);
if (__builtin_reduce_or(x < 0.0)) return;
// ...

which should benefit a bit more from this change, since the reduction result is used for control flow.

For the case of or(reduce_and(x), reduce_and(y)), you make a good point. Currently LLVM generates this:

        uminv   b0, v0.8b
        uminv   b1, v1.8b
        fmov    w8, s0
        fmov    w9, s1
        orr     w8, w8, w9
        and     w0, w8, #0x1
        ret

With my change LLVM generates something like this:

        fcmp d0, #0.0
        cset w8, eq
        fcmp d1, #0.0
        csinc w0, w8, wzr, ne
        ret

In this example both snippets have a max dependency length of 4, but if we were to or together more reduce_and operations, the generated code would get worse after this change, which can be a concern. Although I'm not sure how common such a pattern is in real code.

Similar patterns like and(reduce_and(x), reduce_and(y)) get folded to reduce_and(and(x, y)) by instcombine, so that shouldn't be an issue.

I guess the question is how common the pattern you bring up is in real-world code, and does the potential regression in that case outweigh the improvements in other cases. I don't really have a good answer for this though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining! And just to be clear I'm not saying this will definitely be an issue and I won't hold up the patch for it. I was just curious to see the motivating examples, that's all.

@Il-Capitano
Copy link
Contributor Author

Are there any other thoughts on this, or can we go ahead with the merge?

@davemgreen
Copy link
Collaborator

Sorry - We keep forgetting who has commit access. I'll try and remember in the future.
Can you rebase and check the tests are still OK? Some seem to have changed. Thanks

* Avoid unnecessary truncation of comparison results in vecreduce_xor
* Optimize generated code for vecreduce_and and vecreduce_or by
  comparing against 0.0 to check if all/any of the values are set

Alive2 proof of vecreduce_and and vecreduce_or transformation: https://alive2.llvm.org/ce/z/SRfPtw
Required changes:
* Use APInt::getAllOnes() instead of using -1 in DAG.getConstant()
* A new test case was added that is affected by this change
@Il-Capitano Il-Capitano force-pushed the optimize-vecreduce-bool branch from 0ed5fa7 to 3097021 Compare December 10, 2024 10:54
@Il-Capitano
Copy link
Contributor Author

Thanks for the quick reaction.

Two changes were required after the rebase:

@davemgreen
Copy link
Collaborator

Thanks!

@davemgreen davemgreen merged commit 97ff961 into llvm:main Dec 10, 2024
8 checks passed
Vec = DAG.getNode(
ISD::XOR, DL, VecVT, Vec,
DAG.getSplatVector(
VecVT, DL, DAG.getConstant(APInt::getAllOnes(32), DL, MVT::i32)));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI there is a DAG.getAllOnesConstant(DL, MVT::i32) helper for this.

@Il-Capitano
Copy link
Contributor Author

@davemgreen @nikic
It looks like this breaks some things when compiling with -Ofast: https://linaro.atlassian.net/browse/LLVM-1481

I was able to reproduce the issue with this test case: https://github.com/fujitsu/compiler-test-suite/blob/main/C/0030/0030_0066.c

The issue seems to be triggered during LTO. If I compile the test case to assembly, there's no difference between -O3 and -Ofast.

@davemgreen
Copy link
Collaborator

I see, with fast math.. I'll revert this for now if that is OK.

I think it might be conflicting with DAZ/FTZ that gets set by fast-math. The start-up object sets the flags in FPCR.

davemgreen added a commit that referenced this pull request Dec 11, 2024
…tions (#115713)"

This reverts commit 97ff961 as it conflicts
with fast-math and possible denormal flushing.
@Il-Capitano
Copy link
Contributor Author

Il-Capitano commented Dec 12, 2024

You're right, the issue is that the fast-math start-up object sets the FZ bit (bit 24) of FPCR, meaning that denormalized single-precision and double-precision inputs to, and outputs from, floating-point instructions are flushed to zero. This means that we can't reliably use fcmp d0, #0.0 to check for an all-zero bit pattern.

To me it seems there's no way to use the fcmp trick and get correct lowering in all cases, since the setting of FPCR is configured at link time, i.e. a TU compiled with -O3 can be affected if the -mdaz-ftz flag is used in the linking command.

There's one part of this patch that could still be used to optimize boolean reductions (used for reduce_xor here). I'll open a new PR for that in the next few days.

For reference, here's a C++ example that lowers to the same assembly, and shows the different behaviour with and without -mdaz-ftz: https://godbolt.org/z/Ee6MKPjjx

@Il-Capitano Il-Capitano deleted the optimize-vecreduce-bool branch December 12, 2024 13:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants