Skip to content

Commit 458c91d

Browse files
authored
[AArch64][NEON] Lower fixed-width add partial reductions to dot product (#107078)
This PR adds lowering for fixed-width <4 x i32> and <2 x i32> partial reductions to a dot product when Neon and the dot product feature are available. The work is by Max Beck-Jones (@DevM-uk).
1 parent c2e53b2 commit 458c91d

File tree

3 files changed

+146
-7
lines changed

3 files changed

+146
-7
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1995,7 +1995,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
19951995
return true;
19961996

19971997
EVT VT = EVT::getEVT(I->getType());
1998-
return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
1998+
return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::v4i32 &&
1999+
VT != MVT::v2i32;
19992000
}
20002001

20012002
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21807,7 +21808,10 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2180721808
Intrinsic::experimental_vector_partial_reduce_add &&
2180821809
"Expected a partial reduction node");
2180921810

21810-
if (!Subtarget->isSVEorStreamingSVEAvailable())
21811+
bool Scalable = N->getValueType(0).isScalableVector();
21812+
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
21813+
return SDValue();
21814+
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
2181121815
return SDValue();
2181221816

2181321817
SDLoc DL(N);
@@ -21844,11 +21848,11 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2184421848

2184521849
// Dot products operate on chunks of four elements so there must be four times
2184621850
// as many elements in the wide type
21847-
if (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8)
21848-
return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B);
21849-
21850-
if (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16)
21851-
return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B);
21851+
if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
21852+
(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) ||
21853+
(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) ||
21854+
(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
21855+
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
2185221856

2185321857
return SDValue();
2185421858
}
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
3+
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
4+
5+
define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
6+
; CHECK-DOT-LABEL: udot:
7+
; CHECK-DOT: // %bb.0:
8+
; CHECK-DOT-NEXT: udot v0.4s, v2.16b, v1.16b
9+
; CHECK-DOT-NEXT: ret
10+
;
11+
; CHECK-NODOT-LABEL: udot:
12+
; CHECK-NODOT: // %bb.0:
13+
; CHECK-NODOT-NEXT: umull v3.8h, v2.8b, v1.8b
14+
; CHECK-NODOT-NEXT: umull2 v1.8h, v2.16b, v1.16b
15+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
16+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v3.4h
17+
; CHECK-NODOT-NEXT: uaddw2 v2.4s, v2.4s, v3.8h
18+
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
19+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
20+
; CHECK-NODOT-NEXT: ret
21+
%u.wide = zext <16 x i8> %u to <16 x i32>
22+
%s.wide = zext <16 x i8> %s to <16 x i32>
23+
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
24+
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
25+
ret <4 x i32> %partial.reduce
26+
}
27+
28+
define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
29+
; CHECK-DOT-LABEL: udot_narrow:
30+
; CHECK-DOT: // %bb.0:
31+
; CHECK-DOT-NEXT: udot v0.2s, v2.8b, v1.8b
32+
; CHECK-DOT-NEXT: ret
33+
;
34+
; CHECK-NODOT-LABEL: udot_narrow:
35+
; CHECK-NODOT: // %bb.0:
36+
; CHECK-NODOT-NEXT: umull v1.8h, v2.8b, v1.8b
37+
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
38+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
39+
; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
40+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
41+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
42+
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
43+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
44+
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
45+
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
46+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
47+
; CHECK-NODOT-NEXT: ret
48+
%u.wide = zext <8 x i8> %u to <8 x i32>
49+
%s.wide = zext <8 x i8> %s to <8 x i32>
50+
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
51+
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
52+
ret <2 x i32> %partial.reduce
53+
}
54+
55+
define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
56+
; CHECK-DOT-LABEL: sdot:
57+
; CHECK-DOT: // %bb.0:
58+
; CHECK-DOT-NEXT: sdot v0.4s, v2.16b, v1.16b
59+
; CHECK-DOT-NEXT: ret
60+
;
61+
; CHECK-NODOT-LABEL: sdot:
62+
; CHECK-NODOT: // %bb.0:
63+
; CHECK-NODOT-NEXT: smull v3.8h, v2.8b, v1.8b
64+
; CHECK-NODOT-NEXT: smull2 v1.8h, v2.16b, v1.16b
65+
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
66+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v3.4h
67+
; CHECK-NODOT-NEXT: saddw2 v2.4s, v2.4s, v3.8h
68+
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
69+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
70+
; CHECK-NODOT-NEXT: ret
71+
%u.wide = sext <16 x i8> %u to <16 x i32>
72+
%s.wide = sext <16 x i8> %s to <16 x i32>
73+
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
74+
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
75+
ret <4 x i32> %partial.reduce
76+
}
77+
78+
define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
79+
; CHECK-DOT-LABEL: sdot_narrow:
80+
; CHECK-DOT: // %bb.0:
81+
; CHECK-DOT-NEXT: sdot v0.2s, v2.8b, v1.8b
82+
; CHECK-DOT-NEXT: ret
83+
;
84+
; CHECK-NODOT-LABEL: sdot_narrow:
85+
; CHECK-NODOT: // %bb.0:
86+
; CHECK-NODOT-NEXT: smull v1.8h, v2.8b, v1.8b
87+
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
88+
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
89+
; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
90+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
91+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
92+
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
93+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
94+
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
95+
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
96+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
97+
; CHECK-NODOT-NEXT: ret
98+
%u.wide = sext <8 x i8> %u to <8 x i32>
99+
%s.wide = sext <8 x i8> %s to <8 x i32>
100+
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
101+
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
102+
ret <2 x i32> %partial.reduce
103+
}
104+
105+
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
106+
; CHECK-LABEL: not_udot:
107+
; CHECK: // %bb.0:
108+
; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b
109+
; CHECK-NEXT: uaddw v0.4s, v0.4s, v1.4h
110+
; CHECK-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
111+
; CHECK-NEXT: ret
112+
%u.wide = zext <8 x i8> %u to <8 x i32>
113+
%s.wide = zext <8 x i8> %s to <8 x i32>
114+
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
115+
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <8 x i32> %mult)
116+
ret <4 x i32> %partial.reduce
117+
}
118+
119+
define <2 x i32> @not_udot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) {
120+
; CHECK-LABEL: not_udot_narrow:
121+
; CHECK: // %bb.0:
122+
; CHECK-NEXT: bic v1.4h, #255, lsl #8
123+
; CHECK-NEXT: bic v2.4h, #255, lsl #8
124+
; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
125+
; CHECK-NEXT: umull v3.4s, v2.4h, v1.4h
126+
; CHECK-NEXT: umlal v0.4s, v2.4h, v1.4h
127+
; CHECK-NEXT: ext v1.16b, v3.16b, v3.16b, #8
128+
; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
129+
; CHECK-NEXT: ret
130+
%u.wide = zext <4 x i8> %u to <4 x i32>
131+
%s.wide = zext <4 x i8> %s to <4 x i32>
132+
%mult = mul nuw nsw <4 x i32> %s.wide, %u.wide
133+
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <4 x i32> %mult)
134+
ret <2 x i32> %partial.reduce
135+
}

0 commit comments

Comments
 (0)