
[AArch64][SVE] Fold zero-extend into add reduction. #102325


Open · wants to merge 2 commits into main

47 changes: 45 additions & 2 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -18166,6 +18166,46 @@ static SDValue performVecReduceAddCombineWithUADDLP(SDNode *N,
return DAG.getNode(ISD::VECREDUCE_ADD, DL, MVT::i32, UADDLP);
}

// Turn vecreduce_add(zext/sext(...)) into SVE's [US]ADDV instruction.
Review comment (Collaborator):

There is an extra space in SVE's [US]ADDV.

static SDValue
performVecReduceAddExtCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
                              const AArch64TargetLowering &TLI) {
  if (N->getOperand(0).getOpcode() != ISD::ZERO_EXTEND &&
      N->getOperand(0).getOpcode() != ISD::SIGN_EXTEND)
    return SDValue();

  SelectionDAG &DAG = DCI.DAG;
  const auto &ST = DAG.getSubtarget<AArch64Subtarget>();
  SDValue VecOp = N->getOperand(0).getOperand(0);
  EVT VecOpVT = VecOp.getValueType();
  if (VecOpVT.getScalarType() == MVT::i1 || !TLI.isTypeLegal(VecOpVT) ||
      (VecOpVT.isFixedLengthVector() &&
       !TLI.useSVEForFixedLengthVectorVT(
           VecOpVT, /*OverrideNEON=*/ST.useSVEForFixedLengthVectors())))
    return SDValue();

  SDLoc DL(N);

  // The input type is legal so map VECREDUCE_ADD to UADDV/SADDV, e.g.
  // i32 (vecreduce_add (zext nxv16i8 %op to nxv16i32))
  // ->
  // i32 (UADDV nxv16i8:%op)
  EVT ElemType = N->getValueType(0);
  SDValue Pg = getPredicateForVector(DAG, DL, VecOpVT);
  if (VecOpVT.isFixedLengthVector()) {
    EVT ContainerVT = getContainerForFixedLengthVector(DAG, VecOpVT);
    VecOp = convertToScalableVector(DAG, ContainerVT, VecOp);
  }
  bool IsSigned = N->getOperand(0).getOpcode() == ISD::SIGN_EXTEND;
  SDValue Res =
      DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, MVT::i64,
                  DAG.getConstant(IsSigned ? Intrinsic::aarch64_sve_saddv
                                           : Intrinsic::aarch64_sve_uaddv,
                                  DL, MVT::i64),
                  Pg, VecOp);
  return DAG.getAnyExtOrTrunc(Res, DL, ElemType);
}

// Turn a v8i8/v16i8 extended vecreduce into a udot/sdot and vecreduce
// vecreduce.add(ext(A)) to vecreduce.add(DOT(zero, A, one))
// vecreduce.add(mul(ext(A), ext(B))) to vecreduce.add(DOT(zero, A, B))
@@ -25888,8 +25928,11 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
    return performInsertVectorEltCombine(N, DCI);
  case ISD::EXTRACT_VECTOR_ELT:
    return performExtractVectorEltCombine(N, DCI, Subtarget);
  case ISD::VECREDUCE_ADD:
    return performVecReduceAddCombine(N, DCI.DAG, Subtarget);
  case ISD::VECREDUCE_ADD: {
    if (SDValue Val = performVecReduceAddCombine(N, DCI.DAG, Subtarget))
      return Val;
    return performVecReduceAddExtCombine(N, DCI, *this);
  }
  case AArch64ISD::UADDV:
    return performUADDVCombine(N, DAG);
  case AArch64ISD::SMULL:
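For context when reading the test changes below, here is a minimal IR sketch of the pattern the new combine handles, distilled from the uaddv_nxv16i8_nxv16i32 test added in sve-int-reduce.ll (the function name is invented for illustration): a zero-extended add reduction over a legal SVE input type now lowers to a single predicated UADDV on the narrow element type instead of a chain of unpacks and adds.

  declare i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32>)

  define i32 @uaddv_zext_sketch(<vscale x 16 x i8> %op) {
    ; zext to i32 elements, then reduce; the combine folds the extend away
    ; and reduces directly over the i8 elements.
    %ext = zext <vscale x 16 x i8> %op to <vscale x 16 x i32>
    %red = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %ext)
    ret i32 %red
  }

  ; Expected lowering with this patch (see uaddv_nxv16i8_nxv16i32 below):
  ;   ptrue p0.b
  ;   uaddv d0, p0, z0.b
  ;   fmov  x0, d0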
28 changes: 11 additions & 17 deletions llvm/test/CodeGen/AArch64/sve-doublereduct.ll
@@ -103,17 +103,12 @@ define i32 @add_i32(<vscale x 8 x i32> %a, <vscale x 4 x i32> %b) {
define i16 @add_ext_i16(<vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-LABEL: add_ext_i16:
; CHECK: // %bb.0:
; CHECK-NEXT: uunpkhi z2.h, z0.b
; CHECK-NEXT: uunpklo z0.h, z0.b
; CHECK-NEXT: uunpkhi z3.h, z1.b
; CHECK-NEXT: uunpklo z1.h, z1.b
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: add z0.h, z0.h, z2.h
; CHECK-NEXT: add z1.h, z1.h, z3.h
; CHECK-NEXT: add z0.h, z0.h, z1.h
; CHECK-NEXT: uaddv d0, p0, z0.h
; CHECK-NEXT: fmov x0, d0
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT: ptrue p0.b
; CHECK-NEXT: uaddv d0, p0, z0.b
; CHECK-NEXT: uaddv d1, p0, z1.b
; CHECK-NEXT: fmov w8, s0
Review comment (Contributor):

Not for this patch, but I wonder if we can improve this further with:

  uaddv d0, p0, z0.b
  uaddv d1, p0, z1.b
  add v0.4s, v0.4s, v1.4s
  fmov w0, s0

The throughput of the NEON add is much higher than that of fmov and the latency is about the same.

; CHECK-NEXT: fmov w9, s1
; CHECK-NEXT: add w0, w8, w9
; CHECK-NEXT: ret
%ae = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
%be = zext <vscale x 16 x i8> %b to <vscale x 16 x i16>
@@ -130,17 +125,16 @@ define i16 @add_ext_v32i16(<vscale x 32 x i8> %a, <vscale x 16 x i8> %b) {
; CHECK-NEXT: uunpklo z4.h, z0.b
; CHECK-NEXT: uunpkhi z1.h, z1.b
; CHECK-NEXT: uunpkhi z0.h, z0.b
; CHECK-NEXT: uunpkhi z5.h, z2.b
; CHECK-NEXT: uunpklo z2.h, z2.b
; CHECK-NEXT: ptrue p0.h
; CHECK-NEXT: ptrue p1.b
; CHECK-NEXT: add z0.h, z0.h, z1.h
; CHECK-NEXT: add z1.h, z4.h, z3.h
; CHECK-NEXT: add z0.h, z1.h, z0.h
; CHECK-NEXT: add z1.h, z2.h, z5.h
; CHECK-NEXT: add z0.h, z0.h, z1.h
; CHECK-NEXT: uaddv d1, p1, z2.b
; CHECK-NEXT: uaddv d0, p0, z0.h
; CHECK-NEXT: fmov x0, d0
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT: fmov w9, s1
; CHECK-NEXT: fmov x8, d0
; CHECK-NEXT: add w0, w8, w9
; CHECK-NEXT: ret
%ae = zext <vscale x 32 x i8> %a to <vscale x 32 x i16>
%be = zext <vscale x 16 x i8> %b to <vscale x 16 x i16>
28 changes: 28 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-fixed-length-int-reduce.ll
@@ -364,6 +364,34 @@ define i64 @uaddv_v32i64(ptr %a) vscale_range(16,0) #0 {
ret i64 %res
}

define i64 @uaddv_zext_v32i8_v32i64(ptr %a) vscale_range(2,0) #0 {
; CHECK-LABEL: uaddv_zext_v32i8_v32i64:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.b, vl32
; CHECK-NEXT: ld1b { z0.b }, p0/z, [x0]
; CHECK-NEXT: uaddv d0, p0, z0.b
; CHECK-NEXT: fmov x0, d0
; CHECK-NEXT: ret
%op = load <32 x i8>, ptr %a
%op.zext = zext <32 x i8> %op to <32 x i64>
%res = call i64 @llvm.vector.reduce.add.v32i64(<32 x i64> %op.zext)
ret i64 %res
}

define i64 @uaddv_zext_v64i8_v64i64(ptr %a) vscale_range(4,0) #0 {
; CHECK-LABEL: uaddv_zext_v64i8_v64i64:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.b, vl64
; CHECK-NEXT: ld1b { z0.b }, p0/z, [x0]
; CHECK-NEXT: uaddv d0, p0, z0.b
; CHECK-NEXT: fmov x0, d0
; CHECK-NEXT: ret
%op = load <64 x i8>, ptr %a
%op.zext = zext <64 x i8> %op to <64 x i64>
%res = call i64 @llvm.vector.reduce.add.v64i64(<64 x i64> %op.zext)
ret i64 %res
}

;
; SMAXV
;
7 changes: 1 addition & 6 deletions llvm/test/CodeGen/AArch64/sve-fixed-vector-zext.ll
@@ -1,4 +1,3 @@

; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-v1 -O3 -aarch64-sve-vector-bits-min=256 -verify-machineinstrs | FileCheck %s --check-prefixes=SVE256
; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-v1 -O3 -aarch64-sve-vector-bits-min=128 -verify-machineinstrs | FileCheck %s --check-prefixes=NEON
; RUN: llc < %s -mtriple=aarch64-none-linux-gnu -mcpu=neoverse-n1 -O3 -verify-machineinstrs | FileCheck %s --check-prefixes=NEON
@@ -9,11 +8,7 @@ define internal i32 @test(ptr nocapture readonly %p1, i32 %i1, ptr nocapture rea
; SVE256: ld1b { z0.h }, p0/z,
; SVE256: ld1b { z1.h }, p0/z,
; SVE256: sub z0.h, z0.h, z1.h
; SVE256-NEXT: sunpklo z1.s, z0.h
; SVE256-NEXT: ext z0.b, z0.b, z0.b, #16
; SVE256-NEXT: sunpklo z0.s, z0.h
; SVE256-NEXT: add z0.s, z1.s, z0.s
; SVE256-NEXT: uaddv d0, p1, z0.s
; SVE256-NEXT: saddv d0, p0, z0.h

; NEON-LABEL: test:
; NEON: ldr q0, [x0, w9, sxtw]
38 changes: 38 additions & 0 deletions llvm/test/CodeGen/AArch64/sve-int-reduce.ll
@@ -188,6 +188,44 @@ define i64 @uaddv_nxv2i64(<vscale x 2 x i64> %a) {
ret i64 %res
}

define i16 @uaddv_nxv16i8_nxv16i16(<vscale x 16 x i8> %a) {
; CHECK-LABEL: uaddv_nxv16i8_nxv16i16:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.b
; CHECK-NEXT: uaddv d0, p0, z0.b
; CHECK-NEXT: fmov x0, d0
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT: ret
%1 = zext <vscale x 16 x i8> %a to <vscale x 16 x i16>
%2 = call i16 @llvm.vector.reduce.add.nxv16i16(<vscale x 16 x i16> %1)
ret i16 %2
}

define i32 @uaddv_nxv16i8_nxv16i32(<vscale x 16 x i8> %a) {
; CHECK-LABEL: uaddv_nxv16i8_nxv16i32:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.b
; CHECK-NEXT: uaddv d0, p0, z0.b
; CHECK-NEXT: fmov x0, d0
; CHECK-NEXT: // kill: def $w0 killed $w0 killed $x0
; CHECK-NEXT: ret
%1 = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
%2 = call i32 @llvm.vector.reduce.add.nxv16i32(<vscale x 16 x i32> %1)
ret i32 %2
}

define i64 @uaddv_nxv16i8_nxv16i64(<vscale x 16 x i8> %a) {
; CHECK-LABEL: uaddv_nxv16i8_nxv16i64:
; CHECK: // %bb.0:
; CHECK-NEXT: ptrue p0.b
; CHECK-NEXT: uaddv d0, p0, z0.b
; CHECK-NEXT: fmov x0, d0
; CHECK-NEXT: ret
%1 = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
%2 = call i64 @llvm.vector.reduce.add.nxv16i64(<vscale x 16 x i64> %1)
ret i64 %2
}

; UMINV

define i8 @umin_nxv16i8(<vscale x 16 x i8> %a) {