
[AArch64][SVE] Add codegen support for partial reduction lowering to wide add instructions #114406


Merged
4 commits merged into llvm:main on Nov 12, 2024

Conversation

JamesChesterman (Contributor)

For partial reductions where the number of lanes is halved and the bits per lane are doubled, a pair of wide add instructions can be used.
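The pair in question is the SVE2 bottom/top wide adds (saddwb/saddwt for sign-extended inputs, uaddwb/uaddwt for zero-extended ones): the bottom form adds the extended even-indexed narrow lanes to the wide accumulator, and the top form adds the odd-indexed ones, so chaining the two sums every adjacent pair of extended lanes into the accumulator with half as many lanes. Below is a rough sketch of the signed nxv2i64/nxv4i32 case in IR terms; the combine itself builds the equivalent ISD::INTRINSIC_WO_CHAIN nodes in SelectionDAG rather than emitting IR, and the mangled intrinsic names are assumed from the usual overload scheme.

    define <vscale x 2 x i64> @wide_add_sketch(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input) {
      ; acc += sext(even-indexed lanes of %input)
      %bottom = call <vscale x 2 x i64> @llvm.aarch64.sve.saddwb.nxv2i64(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input)
      ; then add sext(odd-indexed lanes of %input) on top of that
      %res = call <vscale x 2 x i64> @llvm.aarch64.sve.saddwt.nxv2i64(<vscale x 2 x i64> %bottom, <vscale x 4 x i32> %input)
      ret <vscale x 2 x i64> %res
    }

With llc -mattr=+sve2 this selects to the saddwb/saddwt pair checked in the new test file.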

@llvmbot (Member) commented Oct 31, 2024

@llvm/pr-subscribers-backend-aarch64

Author: James Chesterman (JamesChesterman)

Changes

For partial reductions where the number of lanes is halved and the bits per lane are doubled, a pair of wide add instructions can be used.


Full diff: https://github.com/llvm/llvm-project/pull/114406.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+60-1)
  • (added) llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll (+74)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4c0cd1ac3d4512..8efc8244426ef3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2042,7 +2042,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
 
   EVT VT = EVT::getEVT(I->getType());
   return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
-         VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
+         VT != MVT::nxv8i16 && VT != MVT::v4i64 && VT != MVT::v4i32 &&
+         VT != MVT::v2i32 && VT != MVT::v8i16;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21783,6 +21784,62 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
 
+SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
+                                          const AArch64Subtarget *Subtarget,
+                                          SelectionDAG &DAG) {
+
+  assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+         getIntrinsicID(N) ==
+             Intrinsic::experimental_vector_partial_reduce_add &&
+         "Expected a partial reduction node");
+
+  bool Scalable = N->getValueType(0).isScalableVector();
+  if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+    return SDValue();
+
+  SDLoc DL(N);
+
+  auto Accumulator = N->getOperand(1);
+  auto ExtInput = N->getOperand(2);
+
+  EVT AccumulatorType = Accumulator.getValueType();
+  EVT AccumulatorElementType = AccumulatorType.getVectorElementType();
+
+  if (ExtInput.getValueType().getVectorElementType() != AccumulatorElementType)
+    return SDValue();
+
+  unsigned ExtInputOpcode = ExtInput->getOpcode();
+  if (!ISD::isExtOpcode(ExtInputOpcode))
+    return SDValue();
+
+  auto Input = ExtInput->getOperand(0);
+  EVT InputType = Input.getValueType();
+
+  // To do this transformation, the output element size needs to be double the
+  // input element size, and the output element count needs to be half the
+  // input element count.
+  if (!(InputType.getVectorElementType().getSizeInBits() * 2 ==
+        AccumulatorElementType.getSizeInBits()) ||
+      !(AccumulatorType.getVectorElementCount() * 2 ==
+        InputType.getVectorElementCount()) ||
+      !(AccumulatorType.isScalableVector() == InputType.isScalableVector()))
+    return SDValue();
+
+  bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
+  auto BottomIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwb
+                                       : Intrinsic::aarch64_sve_uaddwb;
+  auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
+                                    : Intrinsic::aarch64_sve_uaddwt;
+
+  auto BottomID =
+      DAG.getTargetConstant(BottomIntrinsic, DL, AccumulatorElementType);
+  auto BottomNode = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType,
+                                BottomID, Accumulator, Input);
+  auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccumulatorElementType);
+  return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType, TopID,
+                     BottomNode, Input);
+}
+
 static SDValue performIntrinsicCombine(SDNode *N,
                                        TargetLowering::DAGCombinerInfo &DCI,
                                        const AArch64Subtarget *Subtarget) {
@@ -21794,6 +21851,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
   case Intrinsic::experimental_vector_partial_reduce_add: {
     if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
       return Dot;
+    if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
+      return WideAdd;
     return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
                                    N->getOperand(1), N->getOperand(2));
   }
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
new file mode 100644
index 00000000000000..6fe3da2a25c0cd
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -0,0 +1,74 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+
+define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
+; CHECK-LABEL: signed_wide_add_nxv4i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    saddwb z0.d, z0.d, z1.s
+; CHECK-NEXT:    saddwt z0.d, z0.d, z1.s
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
+    %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
+    ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @unsigned_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv4i32:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uaddwb z0.d, z0.d, z1.s
+; CHECK-NEXT:    uaddwt z0.d, z0.d, z1.s
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = zext <vscale x 4 x i32> %input to <vscale x 4 x i64>
+    %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
+    ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
+; CHECK-LABEL: signed_wide_add_nxv8i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    saddwb z0.s, z0.s, z1.h
+; CHECK-NEXT:    saddwt z0.s, z0.s, z1.h
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
+    %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
+    ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv8i16:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uaddwb z0.s, z0.s, z1.h
+; CHECK-NEXT:    uaddwt z0.s, z0.s, z1.h
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
+    %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
+    ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
+; CHECK-LABEL: signed_wide_add_nxv16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    saddwb z0.h, z0.h, z1.b
+; CHECK-NEXT:    saddwt z0.h, z0.h, z1.b
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
+    %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
+    ret <vscale x 8 x i16> %partial.reduce
+}
+
+define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv16i8:
+; CHECK:       // %bb.0: // %entry
+; CHECK-NEXT:    uaddwb z0.h, z0.h, z1.b
+; CHECK-NEXT:    uaddwt z0.h, z0.h, z1.b
+; CHECK-NEXT:    ret
+entry:
+    %input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
+    %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
+    ret <vscale x 8 x i16> %partial.reduce
+}

  • Rename variables, eliminate a redundant condition in an if statement, and refactor the code that checks types.
  • Check for non-legal types for the wide add instructions so that instructions requiring promotion or splitting just go to the default case.
@MacDue (Member) left a comment:

LGTM (just one nit). I'd wait for @huntergr-arm to approve before landing though 🙂

@huntergr-arm (Collaborator) left a comment:

LG

JamesChesterman merged commit c3c2e1e into llvm:main on Nov 12, 2024
8 checks passed