
[AArch64] Improve code generation of bool vector reduce operations #115713


Merged: 3 commits, Dec 10, 2024
52 changes: 47 additions & 5 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -15841,11 +15841,27 @@ static SDValue getVectorBitwiseReduce(unsigned Opcode, SDValue Vec, EVT VT,
return getVectorBitwiseReduce(Opcode, HalfVec, VT, DL, DAG);
}

// Vectors that are less than 64 bits get widened to neatly fit a 64 bit
// register, so e.g. <4 x i1> gets lowered to <4 x i16>. Sign extending to
// Results of setcc operations get widened to 128 bits for xor reduce if
// their input operands are 128 bits wide, otherwise vectors that are less
// than 64 bits get widened to neatly fit a 64 bit register, so e.g.
// <4 x i1> gets lowered to either <4 x i16> or <4 x i32>. Sign extending to
// this element size leads to the best codegen, since e.g. setcc results
// might need to be truncated otherwise.
EVT ExtendedVT = MVT::getIntegerVT(std::max(64u / NumElems, 8u));
unsigned ExtendedWidth = 64;
if (ScalarOpcode == ISD::XOR && Vec.getOpcode() == ISD::SETCC &&
Vec.getOperand(0).getValueSizeInBits() >= 128) {
ExtendedWidth = 128;
}
EVT ExtendedVT = MVT::getIntegerVT(std::max(ExtendedWidth / NumElems, 8u));

// Negate the reduced vector value for reduce and operations that use
// fcmp.
if (ScalarOpcode == ISD::AND && NumElems < 16) {
Vec = DAG.getNode(
ISD::XOR, DL, VecVT, Vec,
DAG.getSplatVector(
VecVT, DL, DAG.getConstant(APInt::getAllOnes(32), DL, MVT::i32)));
Contributor (review comment on the line above):
FYI there is a DAG.getAllOnesConstant(DL, MVT::i32) helper for this.
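
For illustration, a minimal sketch of how the splat construction in the hunk above could use that helper; it assumes the same DAG, DL, VecVT and Vec variables as the patch and is not the committed code:

// Negate Vec by XOR-ing with an all-ones splat; getAllOnesConstant avoids
// spelling out APInt::getAllOnes(32) by hand.
SDValue AllOnes = DAG.getAllOnesConstant(DL, MVT::i32);
Vec = DAG.getNode(ISD::XOR, DL, VecVT, Vec,
                  DAG.getSplatVector(VecVT, DL, AllOnes));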

}

// any_ext doesn't work with umin/umax, so only use it for uadd.
unsigned ExtendOp =
@@ -15854,10 +15870,36 @@ static SDValue getVectorBitwiseReduce(unsigned Opcode, SDValue Vec, EVT VT,
ExtendOp, DL, VecVT.changeVectorElementType(ExtendedVT), Vec);
switch (ScalarOpcode) {
case ISD::AND:
Result = DAG.getNode(ISD::VECREDUCE_UMIN, DL, ExtendedVT, Extended);
if (NumElems < 16) {
// Check if all lanes of the negated bool vector value are zero by
// comparing against 0.0 with ordered and equal predicate. The only
// non-zero bit pattern that compares ordered and equal to 0.0 is -0.0,
// where only the sign bit is set. However the bool vector is
// sign-extended so that each bit in a lane is either zero or one,
// meaning that it is impossible to get the bit pattern of -0.0.
assert(Extended.getValueSizeInBits() == 64);
Extended = DAG.getBitcast(MVT::f64, Extended);
Result =
DAG.getSetCC(DL, MVT::i32, Extended,
DAG.getConstantFP(0.0, DL, MVT::f64), ISD::SETOEQ);
} else {
Result = DAG.getNode(ISD::VECREDUCE_UMIN, DL, ExtendedVT, Extended);
}
break;
case ISD::OR:
Result = DAG.getNode(ISD::VECREDUCE_UMAX, DL, ExtendedVT, Extended);
if (NumElems < 16) {
// Check if any lane of the bool vector is set by comparing against 0.0.
// NaN bit patterns are handled by using the 'unordered or not equal'
// predicate. Similarly to the reduce and case, -0.0 doesn't have to be
// handled here (see explanation above).
assert(Extended.getValueSizeInBits() == 64);
Extended = DAG.getBitcast(MVT::f64, Extended);
Result =
DAG.getSetCC(DL, MVT::i32, Extended,
DAG.getConstantFP(0.0, DL, MVT::f64), ISD::SETUNE);
} else {
Result = DAG.getNode(ISD::VECREDUCE_UMAX, DL, ExtendedVT, Extended);
}
break;
case ISD::XOR:
Result = DAG.getNode(ISD::VECREDUCE_ADD, DL, ExtendedVT, Extended);
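
To make the reasoning in the AND and OR comments above concrete: with ExtendedWidth = 64 and NumElems = 8, each i1 lane is sign-extended to max(64/8, 8) = 8 bits, so the whole bool vector fits in one 64-bit register. The standalone program below is an illustrative addition (not part of the patch) that enumerates every <8 x i1> lane mask and checks that the negated, sign-extended vector compares ordered-equal to 0.0 exactly when all lanes are set, that the vector compares unordered-or-not-equal to 0.0 exactly when any lane is set, and that the -0.0 bit pattern mentioned in the comments can never arise from 0x00/0xFF lanes.

// Compile with assertions enabled (no -DNDEBUG).
#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  for (unsigned mask = 0; mask < 256; ++mask) {
    // Sign-extend each i1 lane to an i8 lane: 0x00 or 0xFF.
    uint64_t vec = 0;
    for (unsigned lane = 0; lane < 8; ++lane)
      if (mask & (1u << lane))
        vec |= uint64_t(0xFF) << (8 * lane);
    uint64_t negated = ~vec;

    // Lanes of 0x00/0xFF can never form the -0.0 pattern (sign bit only),
    // the one non-zero pattern that compares ordered-equal to +0.0.
    assert(vec != 0x8000000000000000ULL && negated != 0x8000000000000000ULL);

    // reduce_and lowering: negate, then "fcmp d0, #0.0; cset w0, eq" (OEQ).
    double nd;
    std::memcpy(&nd, &negated, sizeof(nd));
    bool reduce_and = (nd == 0.0); // ordered-equal: false for NaN patterns
    assert(reduce_and == (mask == 0xFF));

    // reduce_or lowering: "fcmp d0, #0.0; cset w0, ne" (UNE).
    double d;
    std::memcpy(&d, &vec, sizeof(d));
    bool reduce_or = !(d == 0.0); // unordered-or-not-equal: true for NaNs
    assert(reduce_or == (mask != 0));
  }
  return 0;
}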
26 changes: 10 additions & 16 deletions llvm/test/CodeGen/AArch64/dag-combine-setcc.ll
@@ -5,10 +5,8 @@ define i1 @combine_setcc_eq_vecreduce_or_v8i1(<8 x i8> %a) {
; CHECK-LABEL: combine_setcc_eq_vecreduce_or_v8i1:
; CHECK: // %bb.0:
; CHECK-NEXT: cmeq v0.8b, v0.8b, #0
; CHECK-NEXT: mov w8, #1 // =0x1
; CHECK-NEXT: umaxv b0, v0.8b
; CHECK-NEXT: fmov w9, s0
; CHECK-NEXT: bic w0, w8, w9
; CHECK-NEXT: fcmp d0, #0.0
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
%cmp1 = icmp eq <8 x i8> %a, zeroinitializer
%cast = bitcast <8 x i1> %cmp1 to i8
@@ -73,9 +71,8 @@ define i1 @combine_setcc_ne_vecreduce_or_v8i1(<8 x i8> %a) {
; CHECK-LABEL: combine_setcc_ne_vecreduce_or_v8i1:
; CHECK: // %bb.0:
; CHECK-NEXT: cmtst v0.8b, v0.8b, v0.8b
; CHECK-NEXT: umaxv b0, v0.8b
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: and w0, w8, #0x1
; CHECK-NEXT: fcmp d0, #0.0
; CHECK-NEXT: cset w0, ne
; CHECK-NEXT: ret
%cmp1 = icmp ne <8 x i8> %a, zeroinitializer
%cast = bitcast <8 x i1> %cmp1 to i8
@@ -132,10 +129,9 @@ define i1 @combine_setcc_ne_vecreduce_or_v64i1(<64 x i8> %a) {
define i1 @combine_setcc_eq_vecreduce_and_v8i1(<8 x i8> %a) {
; CHECK-LABEL: combine_setcc_eq_vecreduce_and_v8i1:
; CHECK: // %bb.0:
; CHECK-NEXT: cmeq v0.8b, v0.8b, #0
; CHECK-NEXT: uminv b0, v0.8b
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: and w0, w8, #0x1
; CHECK-NEXT: cmtst v0.8b, v0.8b, v0.8b
; CHECK-NEXT: fcmp d0, #0.0
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
%cmp1 = icmp eq <8 x i8> %a, zeroinitializer
%cast = bitcast <8 x i1> %cmp1 to i8
@@ -192,11 +188,9 @@ define i1 @combine_setcc_eq_vecreduce_and_v64i1(<64 x i8> %a) {
define i1 @combine_setcc_ne_vecreduce_and_v8i1(<8 x i8> %a) {
; CHECK-LABEL: combine_setcc_ne_vecreduce_and_v8i1:
; CHECK: // %bb.0:
; CHECK-NEXT: cmtst v0.8b, v0.8b, v0.8b
; CHECK-NEXT: mov w8, #1 // =0x1
; CHECK-NEXT: uminv b0, v0.8b
; CHECK-NEXT: fmov w9, s0
; CHECK-NEXT: bic w0, w8, w9
; CHECK-NEXT: cmeq v0.8b, v0.8b, #0
; CHECK-NEXT: fcmp d0, #0.0
; CHECK-NEXT: cset w0, ne
; CHECK-NEXT: ret
%cmp1 = icmp ne <8 x i8> %a, zeroinitializer
%cast = bitcast <8 x i1> %cmp1 to i8
@@ -9,13 +9,11 @@ define i1 @unordered_floating_point_compare_on_v8f32(<8 x float> %a_vec) {
; CHECK: // %bb.0:
; CHECK-NEXT: fcmgt v1.4s, v1.4s, #0.0
; CHECK-NEXT: fcmgt v0.4s, v0.4s, #0.0
; CHECK-NEXT: mov w8, #1 // =0x1
; CHECK-NEXT: uzp1 v0.8h, v0.8h, v1.8h
; CHECK-NEXT: mvn v0.16b, v0.16b
; CHECK-NEXT: xtn v0.8b, v0.8h
; CHECK-NEXT: umaxv b0, v0.8b
; CHECK-NEXT: fmov w9, s0
; CHECK-NEXT: bic w0, w8, w9
; CHECK-NEXT: fcmp d0, #0.0
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
%a_cmp = fcmp ule <8 x float> %a_vec, zeroinitializer
%cmp_result = bitcast <8 x i1> %a_cmp to i8
18 changes: 9 additions & 9 deletions llvm/test/CodeGen/AArch64/reduce-and.ll
@@ -20,11 +20,11 @@ define i1 @test_redand_v1i1(<1 x i1> %a) {
define i1 @test_redand_v2i1(<2 x i1> %a) {
; CHECK-LABEL: test_redand_v2i1:
; CHECK: // %bb.0:
; CHECK-NEXT: mvn v0.8b, v0.8b
; CHECK-NEXT: shl v0.2s, v0.2s, #31
; CHECK-NEXT: cmlt v0.2s, v0.2s, #0
; CHECK-NEXT: uminp v0.2s, v0.2s, v0.2s
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: and w0, w8, #0x1
; CHECK-NEXT: fcmp d0, #0.0
Contributor (review comment on the line above):
I think the lowering is quite clever here, but is there now an issue with serialisation if you have multiple reductions? Suppose your IR looks like this:

  %or_result1 = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %a)
  %or_result2 = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %b)
  %or_result = or i1 %or_result1, %or_result2
  ret i1 %or_result

I haven't checked the code with and without this patch, but I imagine previously we could quite happily have interleaved the instructions like this:

  uminp v0 ...
  uminp v1 ...
  fmov w8, s0
  fmov w9, s1
  and w0, w8, 0x1
  and w1, w9, 0x1
  or w0, w0, w1

whereas now due to the single CC register we have to serialise:

  fcmp d0, #0.0
  cset w0, eq
  fcmp d1, #0.0
  cset w1, eq
  or w0, w0, w1
...

However, I can see how this new version is efficient if the result is then used for control flow:

  %or_result = call i1 @llvm.vector.reduce.and.v2i1(<2 x i1> %a)
  br i1 %or_result, ...

Do you have any examples showing where this patch helps improve performance?

Contributor (author):
To measure the performance difference, I originally tested a simple loop similar to this:

#include <arm_neon.h>
#include <cstddef>

void test(bool *dest, float32x4_t *p, std::size_t n) {
    for (std::size_t i = 0; i < n; ++i) {
        dest[i] = __builtin_reduce_or(p[i] < 0.0);
    }
}

My change typically resulted in a 10-15% improvement on various CPUs.
My original motivating use case was something like this:

float32x4_t x = ...;
// ...
x = f(x);
if (__builtin_reduce_or(x < 0.0)) return;
x = g(x);
if (__builtin_reduce_or(x < 0.0)) return;
// ...

which should benefit a bit more from this change, since the reduction result is used for control flow.

For the case of or(reduce_and(x), reduce_and(y)), you make a good point. Currently LLVM generates this:

        uminv   b0, v0.8b
        uminv   b1, v1.8b
        fmov    w8, s0
        fmov    w9, s1
        orr     w8, w8, w9
        and     w0, w8, #0x1
        ret

With my change LLVM generates something like this:

        fcmp d0, #0.0
        cset w8, eq
        fcmp d1, #0.0
        csinc w0, w8, wzr, ne
        ret

In this example both snippets have a maximum dependency length of 4, but if we were to or together more reduce_and operations, the generated code would get worse after this change, which could be a concern. I'm not sure how common such a pattern is in real code, though.

Similar patterns like and(reduce_and(x), reduce_and(y)) get folded to reduce_and(and(x, y)) by instcombine, so that shouldn't be an issue.

I guess the question is how common the pattern you bring up is in real-world code, and whether the potential regression there outweighs the improvements in the other cases. I don't really have a good answer to that, though.

Contributor:
Thanks for explaining! And just to be clear, I'm not saying this will definitely be an issue, and I won't hold up the patch for it. I was just curious to see the motivating examples, that's all.
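
As a source-level counterpart to the or(reduce_and(x), reduce_and(y)) IR discussed above, here is a hypothetical reproducer in the style of the earlier loop example; the function and variable names are illustrative, and it assumes Clang's __builtin_reduce_and and NEON vector types:

#include <arm_neon.h>

// Two independent and-reductions whose i1 results feed a single 'or i1'.
// With the new lowering each reduction becomes an fcmp/cset pair on the
// NZCV flags, which is the serialisation discussed above.
bool either_all_negative(float32x4_t a, float32x4_t b) {
  bool all_a = __builtin_reduce_and(a < 0.0f);
  bool all_b = __builtin_reduce_and(b < 0.0f);
  return all_a | all_b; // plain '|' keeps both reductions unconditional
}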

; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
;
; GISEL-LABEL: test_redand_v2i1:
@@ -42,11 +42,11 @@ define i1 @test_redand_v2i1(<2 x i1> %a) {
define i1 @test_redand_v4i1(<4 x i1> %a) {
; CHECK-LABEL: test_redand_v4i1:
; CHECK: // %bb.0:
; CHECK-NEXT: mvn v0.8b, v0.8b
; CHECK-NEXT: shl v0.4h, v0.4h, #15
; CHECK-NEXT: cmlt v0.4h, v0.4h, #0
; CHECK-NEXT: uminv h0, v0.4h
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: and w0, w8, #0x1
; CHECK-NEXT: fcmp d0, #0.0
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
;
; GISEL-LABEL: test_redand_v4i1:
@@ -68,11 +68,11 @@ define i1 @test_redand_v4i1(<4 x i1> %a) {
define i1 @test_redand_v8i1(<8 x i1> %a) {
; CHECK-LABEL: test_redand_v8i1:
; CHECK: // %bb.0:
; CHECK-NEXT: mvn v0.8b, v0.8b
; CHECK-NEXT: shl v0.8b, v0.8b, #7
; CHECK-NEXT: cmlt v0.8b, v0.8b, #0
; CHECK-NEXT: uminv b0, v0.8b
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: and w0, w8, #0x1
; CHECK-NEXT: fcmp d0, #0.0
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
;
; GISEL-LABEL: test_redand_v8i1:
15 changes: 6 additions & 9 deletions llvm/test/CodeGen/AArch64/reduce-or.ll
@@ -22,9 +22,8 @@ define i1 @test_redor_v2i1(<2 x i1> %a) {
; CHECK: // %bb.0:
; CHECK-NEXT: shl v0.2s, v0.2s, #31
; CHECK-NEXT: cmlt v0.2s, v0.2s, #0
; CHECK-NEXT: umaxp v0.2s, v0.2s, v0.2s
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: and w0, w8, #0x1
; CHECK-NEXT: fcmp d0, #0.0
; CHECK-NEXT: cset w0, ne
; CHECK-NEXT: ret
;
; GISEL-LABEL: test_redor_v2i1:
@@ -44,9 +43,8 @@ define i1 @test_redor_v4i1(<4 x i1> %a) {
; CHECK: // %bb.0:
; CHECK-NEXT: shl v0.4h, v0.4h, #15
; CHECK-NEXT: cmlt v0.4h, v0.4h, #0
; CHECK-NEXT: umaxv h0, v0.4h
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: and w0, w8, #0x1
; CHECK-NEXT: fcmp d0, #0.0
; CHECK-NEXT: cset w0, ne
; CHECK-NEXT: ret
;
; GISEL-LABEL: test_redor_v4i1:
@@ -70,9 +68,8 @@ define i1 @test_redor_v8i1(<8 x i1> %a) {
; CHECK: // %bb.0:
; CHECK-NEXT: shl v0.8b, v0.8b, #7
; CHECK-NEXT: cmlt v0.8b, v0.8b, #0
; CHECK-NEXT: umaxv b0, v0.8b
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: and w0, w8, #0x1
; CHECK-NEXT: fcmp d0, #0.0
; CHECK-NEXT: cset w0, ne
; CHECK-NEXT: ret
;
; GISEL-LABEL: test_redor_v8i1:
6 changes: 3 additions & 3 deletions llvm/test/CodeGen/AArch64/vecreduce-and-legalization.ll
@@ -139,11 +139,11 @@ define i32 @test_v3i32(<3 x i32> %a) nounwind {
define i1 @test_v4i1(<4 x i1> %a) nounwind {
; CHECK-LABEL: test_v4i1:
; CHECK: // %bb.0:
; CHECK-NEXT: mvn v0.8b, v0.8b
; CHECK-NEXT: shl v0.4h, v0.4h, #15
; CHECK-NEXT: cmlt v0.4h, v0.4h, #0
; CHECK-NEXT: uminv h0, v0.4h
; CHECK-NEXT: fmov w8, s0
; CHECK-NEXT: and w0, w8, #0x1
; CHECK-NEXT: fcmp d0, #0.0
; CHECK-NEXT: cset w0, eq
; CHECK-NEXT: ret
%b = call i1 @llvm.vector.reduce.and.v4i1(<4 x i1> %a)
ret i1 %b