Skip to content

Commit 66ddce5

Browse files
SamTebbs33fhahn
authored andcommitted
[AArch64][NEON] Lower fixed-width add partial reductions to dot product (llvm#107078)
This PR adds lowering for fixed-width <4 x i32> and <2 x i32> partial reductions to a dot product when Neon and the dot product feature are available. The work is by Max Beck-Jones (@DevM-uk).
1 parent 198f6b9 commit 66ddce5

File tree

3 files changed

+146
-7
lines changed

3 files changed

+146
-7
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1977,7 +1977,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
19771977
return true;
19781978

19791979
EVT VT = EVT::getEVT(I->getType());
1980-
return VT != MVT::nxv4i32 && VT != MVT::nxv2i64;
1980+
return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::v4i32 &&
1981+
VT != MVT::v2i32;
19811982
}
19821983

19831984
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21268,7 +21269,10 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2126821269
Intrinsic::experimental_vector_partial_reduce_add &&
2126921270
"Expected a partial reduction node");
2127021271

21271-
if (!Subtarget->isSVEorStreamingSVEAvailable())
21272+
bool Scalable = N->getValueType(0).isScalableVector();
21273+
if (Scalable && !Subtarget->isSVEorStreamingSVEAvailable())
21274+
return SDValue();
21275+
if (!Scalable && (!Subtarget->isNeonAvailable() || !Subtarget->hasDotProd()))
2127221276
return SDValue();
2127321277

2127421278
SDLoc DL(N);
@@ -21305,11 +21309,11 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2130521309

2130621310
// Dot products operate on chunks of four elements so there must be four times
2130721311
// as many elements in the wide type
21308-
if (ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8)
21309-
return DAG.getNode(Opcode, DL, MVT::nxv4i32, NarrowOp, A, B);
21310-
21311-
if (ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16)
21312-
return DAG.getNode(Opcode, DL, MVT::nxv2i64, NarrowOp, A, B);
21312+
if ((ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) ||
21313+
(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) ||
21314+
(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) ||
21315+
(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
21316+
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
2131321317

2131421318
return SDValue();
2131521319
}
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -mtriple aarch64 -mattr=+neon,+dotprod < %s | FileCheck %s --check-prefixes=CHECK,CHECK-DOT
3+
; RUN: llc -mtriple aarch64 -mattr=+neon < %s | FileCheck %s --check-prefixes=CHECK,CHECK-NODOT
4+
5+
define <4 x i32> @udot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
6+
; CHECK-DOT-LABEL: udot:
7+
; CHECK-DOT: // %bb.0:
8+
; CHECK-DOT-NEXT: udot v0.4s, v2.16b, v1.16b
9+
; CHECK-DOT-NEXT: ret
10+
;
11+
; CHECK-NODOT-LABEL: udot:
12+
; CHECK-NODOT: // %bb.0:
13+
; CHECK-NODOT-NEXT: umull v3.8h, v2.8b, v1.8b
14+
; CHECK-NODOT-NEXT: umull2 v1.8h, v2.16b, v1.16b
15+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
16+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v3.4h
17+
; CHECK-NODOT-NEXT: uaddw2 v2.4s, v2.4s, v3.8h
18+
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
19+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
20+
; CHECK-NODOT-NEXT: ret
21+
%u.wide = zext <16 x i8> %u to <16 x i32>
22+
%s.wide = zext <16 x i8> %s to <16 x i32>
23+
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
24+
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
25+
ret <4 x i32> %partial.reduce
26+
}
27+
28+
define <2 x i32> @udot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
29+
; CHECK-DOT-LABEL: udot_narrow:
30+
; CHECK-DOT: // %bb.0:
31+
; CHECK-DOT-NEXT: udot v0.2s, v2.8b, v1.8b
32+
; CHECK-DOT-NEXT: ret
33+
;
34+
; CHECK-NODOT-LABEL: udot_narrow:
35+
; CHECK-NODOT: // %bb.0:
36+
; CHECK-NODOT-NEXT: umull v1.8h, v2.8b, v1.8b
37+
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
38+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
39+
; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
40+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
41+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
42+
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
43+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
44+
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
45+
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
46+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
47+
; CHECK-NODOT-NEXT: ret
48+
%u.wide = zext <8 x i8> %u to <8 x i32>
49+
%s.wide = zext <8 x i8> %s to <8 x i32>
50+
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
51+
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
52+
ret <2 x i32> %partial.reduce
53+
}
54+
55+
define <4 x i32> @sdot(<4 x i32> %acc, <16 x i8> %u, <16 x i8> %s) {
56+
; CHECK-DOT-LABEL: sdot:
57+
; CHECK-DOT: // %bb.0:
58+
; CHECK-DOT-NEXT: sdot v0.4s, v2.16b, v1.16b
59+
; CHECK-DOT-NEXT: ret
60+
;
61+
; CHECK-NODOT-LABEL: sdot:
62+
; CHECK-NODOT: // %bb.0:
63+
; CHECK-NODOT-NEXT: smull v3.8h, v2.8b, v1.8b
64+
; CHECK-NODOT-NEXT: smull2 v1.8h, v2.16b, v1.16b
65+
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
66+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v3.4h
67+
; CHECK-NODOT-NEXT: saddw2 v2.4s, v2.4s, v3.8h
68+
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
69+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
70+
; CHECK-NODOT-NEXT: ret
71+
%u.wide = sext <16 x i8> %u to <16 x i32>
72+
%s.wide = sext <16 x i8> %s to <16 x i32>
73+
%mult = mul nuw nsw <16 x i32> %s.wide, %u.wide
74+
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %mult)
75+
ret <4 x i32> %partial.reduce
76+
}
77+
78+
define <2 x i32> @sdot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
79+
; CHECK-DOT-LABEL: sdot_narrow:
80+
; CHECK-DOT: // %bb.0:
81+
; CHECK-DOT-NEXT: sdot v0.2s, v2.8b, v1.8b
82+
; CHECK-DOT-NEXT: ret
83+
;
84+
; CHECK-NODOT-LABEL: sdot_narrow:
85+
; CHECK-NODOT: // %bb.0:
86+
; CHECK-NODOT-NEXT: smull v1.8h, v2.8b, v1.8b
87+
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
88+
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
89+
; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
90+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
91+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
92+
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
93+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
94+
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
95+
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
96+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
97+
; CHECK-NODOT-NEXT: ret
98+
%u.wide = sext <8 x i8> %u to <8 x i32>
99+
%s.wide = sext <8 x i8> %s to <8 x i32>
100+
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
101+
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <8 x i32> %mult)
102+
ret <2 x i32> %partial.reduce
103+
}
104+
105+
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) {
106+
; CHECK-LABEL: not_udot:
107+
; CHECK: // %bb.0:
108+
; CHECK-NEXT: umull v1.8h, v2.8b, v1.8b
109+
; CHECK-NEXT: uaddw v0.4s, v0.4s, v1.4h
110+
; CHECK-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
111+
; CHECK-NEXT: ret
112+
%u.wide = zext <8 x i8> %u to <8 x i32>
113+
%s.wide = zext <8 x i8> %s to <8 x i32>
114+
%mult = mul nuw nsw <8 x i32> %s.wide, %u.wide
115+
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <8 x i32> %mult)
116+
ret <4 x i32> %partial.reduce
117+
}
118+
119+
define <2 x i32> @not_udot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) {
120+
; CHECK-LABEL: not_udot_narrow:
121+
; CHECK: // %bb.0:
122+
; CHECK-NEXT: bic v1.4h, #255, lsl #8
123+
; CHECK-NEXT: bic v2.4h, #255, lsl #8
124+
; CHECK-NEXT: // kill: def $d0 killed $d0 def $q0
125+
; CHECK-NEXT: umull v3.4s, v2.4h, v1.4h
126+
; CHECK-NEXT: umlal v0.4s, v2.4h, v1.4h
127+
; CHECK-NEXT: ext v1.16b, v3.16b, v3.16b, #8
128+
; CHECK-NEXT: add v0.2s, v1.2s, v0.2s
129+
; CHECK-NEXT: ret
130+
%u.wide = zext <4 x i8> %u to <4 x i32>
131+
%s.wide = zext <4 x i8> %s to <4 x i32>
132+
%mult = mul nuw nsw <4 x i32> %s.wide, %u.wide
133+
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <4 x i32> %mult)
134+
ret <2 x i32> %partial.reduce
135+
}

0 commit comments

Comments
 (0)