-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[AArch64][NEON] Lower fixed-width add partial reductions to dot product #107078
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[AArch64][NEON] Lower fixed-width add partial reductions to dot product #107078
Conversation
This PR adds lowering for fixed-width <4 x i32> and <2 x i32> partial reductions to a dot product when Neon and the dot product feature are available. The work is by Max Beck-Jones (@DevM-uk).
@llvm/pr-subscribers-backend-aarch64 Author: Sam Tebbs (SamTebbs33) ChangesThis PR adds lowering for fixed-width <4 x i32> and <2 x i32> partial reductions to a dot product when Neon and the dot product feature are available. The work is by Max Beck-Jones (@DevM-uk). Full diff: https://github.com/llvm/llvm-project/pull/107078.diff 2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 1735ff5cd69748..f3298a326bf4c1 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -1994,7 +1994,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
return true;
EVT VT = EVT::getEVT(I->getType());
- return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
+ return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::v4i32 &&
+ VT != MVT::v2i32;
}
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21781,7 +21782,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
Intrinsic::experimental_vector_partial_reduce_add &&
"Expected a partial reduction node");
- if (!Subtarget->isSVEorStreamingSVEAvailable())
+ if (!Subtarget->isSVEorStreamingSVEAvailable() &&
+ !(Subtarget->isNeonAvailable() && Subtarget->hasDotProd()))
return SDValue();
SDLoc DL(N);
@@ -21818,11 +21820,11 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8)
- return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B);
-
- if (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16)
- return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B);
+ if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
+ (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) ||
+ (ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) ||
+ (ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+ return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
return SDValue();
}
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
new file mode 100644
index 00000000000000..13b731451b60c1
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -0,0 +1,209 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
+; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s
+; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefix=CHECK-NODOTPROD
+
+define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
+; CHECK-LABEL: udot:
+; CHECK: // %bb.0:
+; CHECK-NEXT: udot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+;
+; CHECK-NODOTPROD-LABEL: udot:
+; CHECK-NODOTPROD: // %bb.0:
+; CHECK-NODOTPROD-NEXT: umull v3.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT: umull2 v1.8h, v2.16b, v1.16b
+; CHECK-NODOTPROD-NEXT: ushll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT: uaddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOTPROD-NEXT: uaddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOTPROD-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NODOTPROD-NEXT: ret
+ %u.wide = zext <16 x i8> %u to <16 x i32>
+ %s.wide = zext <16 x i8> %s to <16 x i32>
+ %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
+; CHECK-LABEL: udot_narrow:
+; CHECK: // %bb.0:
+; CHECK-NEXT: udot v0.2s, v2.8b, v1.8b
+; CHECK-NEXT: ret
+;
+; CHECK-NODOTPROD-LABEL: udot_narrow:
+; CHECK-NODOTPROD: // %bb.0:
+; CHECK-NODOTPROD-NEXT: umull v1.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOTPROD-NEXT: ushll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT: ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOTPROD-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOTPROD-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOTPROD-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOTPROD-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOTPROD-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOTPROD-NEXT: uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOTPROD-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOTPROD-NEXT: ret
+ %u.wide = zext <8 x i8> %u to <8 x i32>
+ %s.wide = zext <8 x i8> %s to <8 x i32>
+ %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
+ ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) #0{
+; CHECK-LABEL: sdot:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sdot v0.4s, v2.16b, v1.16b
+; CHECK-NEXT: ret
+;
+; CHECK-NODOTPROD-LABEL: sdot:
+; CHECK-NODOTPROD: // %bb.0:
+; CHECK-NODOTPROD-NEXT: smull v3.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT: smull2 v1.8h, v2.16b, v1.16b
+; CHECK-NODOTPROD-NEXT: sshll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT: saddw v0.4s, v0.4s, v3.4h
+; CHECK-NODOTPROD-NEXT: saddw2 v2.4s, v2.4s, v3.8h
+; CHECK-NODOTPROD-NEXT: saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NODOTPROD-NEXT: ret
+ %u.wide = sext <16 x i8> %u to <16 x i32>
+ %s.wide = sext <16 x i8> %s to <16 x i32>
+ %mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
+; CHECK-LABEL: sdot_narrow:
+; CHECK: // %bb.0:
+; CHECK-NEXT: sdot v0.2s, v2.8b, v1.8b
+; CHECK-NEXT: ret
+;
+; CHECK-NODOTPROD-LABEL: sdot_narrow:
+; CHECK-NODOTPROD: // %bb.0:
+; CHECK-NODOTPROD-NEXT: smull v1.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOTPROD-NEXT: sshll v2.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT: sshll2 v3.4s, v1.8h, #0
+; CHECK-NODOTPROD-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOTPROD-NEXT: saddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOTPROD-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOTPROD-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOTPROD-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOTPROD-NEXT: saddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOTPROD-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOTPROD-NEXT: ret
+ %u.wide = sext <8 x i8> %u to <8 x i32>
+ %s.wide = sext <8 x i8> %s to <8 x i32>
+ %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
+ ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
+; CHECK-LABEL: not_udot:
+; CHECK: // %bb.0:
+; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b
+; CHECK-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEXT: ret
+;
+; CHECK-NODOTPROD-LABEL: not_udot:
+; CHECK-NODOTPROD: // %bb.0:
+; CHECK-NODOTPROD-NEXT: umull v1.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOTPROD-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT: ret
+ %u.wide = zext <8 x i8> %u to <8 x i32>
+ %s.wide = zext <8 x i8> %s to <8 x i32>
+ %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <8 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @not_udot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) #0{
+; CHECK-LABEL: not_udot_narrow:
+; CHECK: // %bb.0:
+; CHECK-NEXT: bic v1.4h, #255, lsl #8
+; CHECK-NEXT: bic v2.4h, #255, lsl #8
+; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT: umull v3.4s, v2.4h, v1.4h
+; CHECK-NEXT: umlal v0.4s, v2.4h, v1.4h
+; CHECK-NEXT: ext v1.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NEXT: ret
+;
+; CHECK-NODOTPROD-LABEL: not_udot_narrow:
+; CHECK-NODOTPROD: // %bb.0:
+; CHECK-NODOTPROD-NEXT: bic v1.4h, #255, lsl #8
+; CHECK-NODOTPROD-NEXT: bic v2.4h, #255, lsl #8
+; CHECK-NODOTPROD-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOTPROD-NEXT: umull v3.4s, v2.4h, v1.4h
+; CHECK-NODOTPROD-NEXT: umlal v0.4s, v2.4h, v1.4h
+; CHECK-NODOTPROD-NEXT: ext v1.16b, v3.16b, v3.16b, #8
+; CHECK-NODOTPROD-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOTPROD-NEXT: ret
+ %u.wide = zext <4 x i8> %u to <4 x i32>
+ %s.wide = zext <4 x i8> %s to <4 x i32>
+ %mult = mul nuw nsw <4 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <4 x i32> %mult)
+ ret <2 x i32> %partial.reduce
+}
+
+define <4 x i32> @not_sdot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
+; CHECK-LABEL: not_sdot:
+; CHECK: // %bb.0:
+; CHECK-NEXT: smull v1.8h, v2.8b, v1.8b
+; CHECK-NEXT: saddw v0.4s, v0.4s, v1.4h
+; CHECK-NEXT: saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NEXT: ret
+;
+; CHECK-NODOTPROD-LABEL: not_sdot:
+; CHECK-NODOTPROD: // %bb.0:
+; CHECK-NODOTPROD-NEXT: smull v1.8h, v2.8b, v1.8b
+; CHECK-NODOTPROD-NEXT: saddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOTPROD-NEXT: saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOTPROD-NEXT: ret
+ %u.wide = sext <8 x i8> %u to <8 x i32>
+ %s.wide = sext <8 x i8> %s to <8 x i32>
+ %mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <8 x i32> %mult)
+ ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @not_sdot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) #0{
+; CHECK-LABEL: not_sdot_narrow:
+; CHECK: // %bb.0:
+; CHECK-NEXT: ushll v2.4s, v2.4h, #0
+; CHECK-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-NEXT: shl v1.4s, v1.4s, #24
+; CHECK-NEXT: shl v2.4s, v2.4s, #24
+; CHECK-NEXT: sshr v1.4s, v1.4s, #24
+; CHECK-NEXT: sshr v2.4s, v2.4s, #24
+; CHECK-NEXT: mul v1.4s, v2.4s, v1.4s
+; CHECK-NEXT: ext v2.16b, v1.16b, v1.16b, #8
+; CHECK-NEXT: add v0.2s, v0.2s, v1.2s
+; CHECK-NEXT: add v0.2s, v2.2s, v0.2s
+; CHECK-NEXT: ret
+;
+; CHECK-NODOTPROD-LABEL: not_sdot_narrow:
+; CHECK-NODOTPROD: // %bb.0:
+; CHECK-NODOTPROD-NEXT: ushll v2.4s, v2.4h, #0
+; CHECK-NODOTPROD-NEXT: ushll v1.4s, v1.4h, #0
+; CHECK-NODOTPROD-NEXT: shl v1.4s, v1.4s, #24
+; CHECK-NODOTPROD-NEXT: shl v2.4s, v2.4s, #24
+; CHECK-NODOTPROD-NEXT: sshr v1.4s, v1.4s, #24
+; CHECK-NODOTPROD-NEXT: sshr v2.4s, v2.4s, #24
+; CHECK-NODOTPROD-NEXT: mul v1.4s, v2.4s, v1.4s
+; CHECK-NODOTPROD-NEXT: ext v2.16b, v1.16b, v1.16b, #8
+; CHECK-NODOTPROD-NEXT: add v0.2s, v0.2s, v1.2s
+; CHECK-NODOTPROD-NEXT: add v0.2s, v2.2s, v0.2s
+; CHECK-NODOTPROD-NEXT: ret
+ %u.wide = sext <4 x i8> %u to <4 x i32>
+ %s.wide = sext <4 x i8> %s to <4 x i32>
+ %mult = mul nuw nsw <4 x i32> %s.wide, %u.wide
+ %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <4 x i32> %mult)
+ ret <2 x i32> %partial.reduce
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
@@ -21781,7 +21782,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N, | |||
Intrinsic::experimental_vector_partial_reduce_add && | |||
"Expected a partial reduction node"); | |||
|
|||
if (!Subtarget->isSVEorStreamingSVEAvailable()) | |||
if (!Subtarget->isSVEorStreamingSVEAvailable() && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What if this is <4 x i32> and we don't have NEON's dot product, but we do have SVE? Will we do something sensible? If we lower fixed-width using SVE, then we should add an extra RUN line to the tests for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for spotting that missing, I've modified the check to also check for the dotprod feature when SVE is available and we're using a fixed-width vector.
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s | ||
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefix=CHECK-NODOTPROD |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
By using --check-prefixes=
you'll only generate a single set of check lines for the common cases (i.e. the negative tests).
For example:
--check-prefixes=CHECK,DOT
--check-prefixes=CHECK,NODOT
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
ret <2 x i32> %partial.reduce | ||
} | ||
|
||
define <4 x i32> @not_sdot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Given shouldExpandPartialReductionIntrinsic
only considers operand types, do not_sdot
and not_sdot_narrow
add anything to the testing? If not then please remove.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
shouldExpandPartialReductionIntrinsic
only considers the return type and tryLowerPartialReductionToDot
checks the operand types so these do test some of those invalid operand types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Restricting my comment to shouldExpandPartialReductionIntrinsic
was an error, but what I meant is the only difference between not_udot
vs not_sdot
and not_udot_narrow
vs not_sdot_narrow
is the type of extension and so there is nothing in not_sdot
and not_sdot_narrow
that's not already being tested by not_udot
and not_udot_narrow
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah I see, done.
if (!Subtarget->isSVEorStreamingSVEAvailable() && | ||
!Subtarget->isNeonAvailable()) | ||
return SDValue(); | ||
|
||
// Fixed-width requires the dotprod feature, both for Neon and SVE | ||
if (!N->getValueType(0).isScalableVT() && !Subtarget->hasDotProd()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'd rather have the association between type and feature explicit. Something more akin to:
if (vt.isScalableVector() && !SVE)
bail;
if (vt.isFixedLnegth() && (!NEON || !DOT)
bail;
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, not sure how I missed this file before but please revert the changes because there's nothing relevant to SVE here.
If you want, perhaps just rename the original file to sve-partial-reduce-dot-product.ll.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done, thank you.
…ct (llvm#107078) This PR adds lowering for fixed-width <4 x i32> and <2 x i32> partial reductions to a dot product when Neon and the dot product feature are available. The work is by Max Beck-Jones (@DevM-uk).
This PR adds lowering for fixed-width <4 x i32> and <2 x i32> partial reductions to a dot product when Neon and the dot product feature are available.
The work is by Max Beck-Jones (@DevM-uk).