[AArch64][SVE] Add codegen support for partial reduction lowering to wide add instructions #114406
Conversation
@llvm/pr-subscribers-backend-aarch64
Author: James Chesterman (JamesChesterman)
Changes: For partial reductions where the number of lanes is halved and the bits per lane are doubled, a pair of wide add instructions can be used. (A sketch of why the pair of wide adds covers every input lane follows the diff.)
Full diff: https://github.com/llvm/llvm-project/pull/114406.diff
2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4c0cd1ac3d4512..8efc8244426ef3 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -2042,7 +2042,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
EVT VT = EVT::getEVT(I->getType());
return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
- VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
+ VT != MVT::nxv8i16 && VT != MVT::v4i64 && VT != MVT::v4i32 &&
+ VT != MVT::v2i32 && VT != MVT::v8i16;
}
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21783,6 +21784,62 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
}
+SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
+ const AArch64Subtarget *Subtarget,
+ SelectionDAG &DAG) {
+
+ assert(N->getOpcode() == ISD::INTRINSIC_WO_CHAIN &&
+ getIntrinsicID(N) ==
+ Intrinsic::experimental_vector_partial_reduce_add &&
+ "Expected a partial reduction node");
+
+ bool Scalable = N->getValueType(0).isScalableVector();
+ if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
+ return SDValue();
+
+ SDLoc DL(N);
+
+ auto Accumulator = N->getOperand(1);
+ auto ExtInput = N->getOperand(2);
+
+ EVT AccumulatorType = Accumulator.getValueType();
+ EVT AccumulatorElementType = AccumulatorType.getVectorElementType();
+
+ if (ExtInput.getValueType().getVectorElementType() != AccumulatorElementType)
+ return SDValue();
+
+ unsigned ExtInputOpcode = ExtInput->getOpcode();
+ if (!ISD::isExtOpcode(ExtInputOpcode))
+ return SDValue();
+
+ auto Input = ExtInput->getOperand(0);
+ EVT InputType = Input.getValueType();
+
+ // To do this transformation, output element size needs to be double input
+ // element size, and output number of elements needs to be half the input
+ // number of elements
+ if (!(InputType.getVectorElementType().getSizeInBits() * 2 ==
+ AccumulatorElementType.getSizeInBits()) ||
+ !(AccumulatorType.getVectorElementCount() * 2 ==
+ InputType.getVectorElementCount()) ||
+ !(AccumulatorType.isScalableVector() == InputType.isScalableVector()))
+ return SDValue();
+
+ bool InputIsSigned = ExtInputOpcode == ISD::SIGN_EXTEND;
+ auto BottomIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwb
+ : Intrinsic::aarch64_sve_uaddwb;
+ auto TopIntrinsic = InputIsSigned ? Intrinsic::aarch64_sve_saddwt
+ : Intrinsic::aarch64_sve_uaddwt;
+
+ auto BottomID =
+ DAG.getTargetConstant(BottomIntrinsic, DL, AccumulatorElementType);
+ auto BottomNode = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType,
+ BottomID, Accumulator, Input);
+ auto TopID = DAG.getTargetConstant(TopIntrinsic, DL, AccumulatorElementType);
+ return DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, AccumulatorType, TopID,
+ BottomNode, Input);
+}
+
static SDValue performIntrinsicCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
@@ -21794,6 +21851,8 @@ static SDValue performIntrinsicCombine(SDNode *N,
case Intrinsic::experimental_vector_partial_reduce_add: {
if (auto Dot = tryLowerPartialReductionToDot(N, Subtarget, DAG))
return Dot;
+ if (auto WideAdd = tryLowerPartialReductionToWideAdd(N, Subtarget, DAG))
+ return WideAdd;
return DAG.getPartialReduceAdd(SDLoc(N), N->getValueType(0),
N->getOperand(1), N->getOperand(2));
}
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
new file mode 100644
index 00000000000000..6fe3da2a25c0cd
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-wide-add.ll
@@ -0,0 +1,74 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s
+
+define <vscale x 2 x i64> @signed_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
+; CHECK-LABEL: signed_wide_add_nxv4i32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: saddwb z0.d, z0.d, z1.s
+; CHECK-NEXT: saddwt z0.d, z0.d, z1.s
+; CHECK-NEXT: ret
+entry:
+ %input.wide = sext <vscale x 4 x i32> %input to <vscale x 4 x i64>
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @unsigned_wide_add_nxv4i32(<vscale x 2 x i64> %acc, <vscale x 4 x i32> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv4i32:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: uaddwb z0.d, z0.d, z1.s
+; CHECK-NEXT: uaddwt z0.d, z0.d, z1.s
+; CHECK-NEXT: ret
+entry:
+ %input.wide = zext <vscale x 4 x i32> %input to <vscale x 4 x i64>
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv4i64(<vscale x 2 x i64> %acc, <vscale x 4 x i64> %input.wide)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i32> @signed_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
+; CHECK-LABEL: signed_wide_add_nxv8i16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: saddwb z0.s, z0.s, z1.h
+; CHECK-NEXT: saddwt z0.s, z0.s, z1.h
+; CHECK-NEXT: ret
+entry:
+ %input.wide = sext <vscale x 8 x i16> %input to <vscale x 8 x i32>
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @unsigned_wide_add_nxv8i16(<vscale x 4 x i32> %acc, <vscale x 8 x i16> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv8i16:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: uaddwb z0.s, z0.s, z1.h
+; CHECK-NEXT: uaddwt z0.s, z0.s, z1.h
+; CHECK-NEXT: ret
+entry:
+ %input.wide = zext <vscale x 8 x i16> %input to <vscale x 8 x i32>
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv8i32(<vscale x 4 x i32> %acc, <vscale x 8 x i32> %input.wide)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 8 x i16> @signed_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
+; CHECK-LABEL: signed_wide_add_nxv16i8:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: saddwb z0.h, z0.h, z1.b
+; CHECK-NEXT: saddwt z0.h, z0.h, z1.b
+; CHECK-NEXT: ret
+entry:
+ %input.wide = sext <vscale x 16 x i8> %input to <vscale x 16 x i16>
+ %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
+ ret <vscale x 8 x i16> %partial.reduce
+}
+
+define <vscale x 8 x i16> @unsigned_wide_add_nxv16i8(<vscale x 8 x i16> %acc, <vscale x 16 x i8> %input){
+; CHECK-LABEL: unsigned_wide_add_nxv16i8:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: uaddwb z0.h, z0.h, z1.b
+; CHECK-NEXT: uaddwt z0.h, z0.h, z1.b
+; CHECK-NEXT: ret
+entry:
+ %input.wide = zext <vscale x 16 x i8> %input to <vscale x 16 x i16>
+ %partial.reduce = tail call <vscale x 8 x i16> @llvm.experimental.vector.partial.reduce.add.nxv8i16.nxv16i16(<vscale x 8 x i16> %acc, <vscale x 16 x i16> %input.wide)
+ ret <vscale x 8 x i16> %partial.reduce
+}
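For reference, here is a minimal sketch of the same pattern written with the ACLE SVE2 intrinsics (not part of the patch; the function name is made up for illustration). saddwb widens and adds the even-numbered (bottom) lanes of the narrow vector into the accumulator, and saddwt the odd-numbered (top) lanes, so issuing the pair adds every input lane into the accumulator exactly once, which is all the partial reduction intrinsic requires since it leaves the distribution of the sum across accumulator lanes unspecified.

// Sketch only: the wide-add pair the lowering emits, expressed via the ACLE
// SVE2 intrinsics. Compile with e.g. clang++ -march=armv8-a+sve2.
#include <arm_sve.h>

// Adds every i32 lane of `input`, sign-extended to i64, into `acc`.
svint64_t partial_reduce_sext(svint64_t acc, svint32_t input) {
  acc = svaddwb_s64(acc, input); // acc += sext(bottom/even lanes of input)
  acc = svaddwt_s64(acc, input); // acc += sext(top/odd lanes of input)
  return acc;
}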
Rename variables, eliminate a redundant condition in an if statement and refactor code checking types
Check for non-legal types for the wide add instructions so that instructions that require promotion or splitting just go to the default case.
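For illustration, the guard could look roughly like the sketch below (the helper name and exact form are assumptions, not the patch's actual code): only accumulator types that map directly onto the SVE2 wide adds are accepted, and anything that would need promotion or splitting is rejected so it falls through to the generic getPartialReduceAdd expansion.

// Illustrative only: restrict the combine to accumulator types the SVE2
// wide adds take directly; other types use the default expansion instead.
static bool isLegalWideAddAccumulatorType(EVT VT) {
  return VT == MVT::nxv2i64 || VT == MVT::nxv4i32 || VT == MVT::nxv8i16;
}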
LGTM (just one nit). I'd wait for @huntergr-arm to approve before landing though 🙂
LG