Skip to content

Commit b2a6814

Browse files
[AArch64][NEON][SVE] Lower i8 to i64 partial reduction to a dot product (#110220)
An i8 to i64 partial reduction can instead be done with an i8 to i32 dot product followed by a sign extension. The sign extension is correct even for the unsigned case, because the 32-bit dot-product lane (at most 4 * 255 * 255) always fits in 31 bits.
1 parent 4b3ba64 commit b2a6814

File tree

3 files changed

+366
-4
lines changed

3 files changed

+366
-4
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1996,8 +1996,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
19961996
return true;
19971997

19981998
EVT VT = EVT::getEVT(I->getType());
1999-
return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::v4i32 &&
2000-
VT != MVT::v2i32;
1999+
return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
2000+
VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
20012001
}
20022002

20032003
bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
@@ -21918,8 +21918,10 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2191821918

2191921919
// Dot products operate on chunks of four elements so there must be four times
2192021920
// as many elements in the wide type
21921-
if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
21921+
if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
21922+
!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
2192221923
!(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
21924+
!(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
2192321925
!(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
2192421926
!(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
2192521927
return SDValue();
@@ -21932,7 +21934,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2193221934

2193321935
bool Scalable = N->getValueType(0).isScalableVT();
2193421936
// There's no nxv2i64 version of usdot
21935-
if (Scalable && ReducedType != MVT::nxv4i32)
21937+
if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
2193621938
return SDValue();
2193721939

2193821940
Opcode = AArch64ISD::USDOT;
@@ -21944,6 +21946,20 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2194421946
else
2194521947
Opcode = AArch64ISD::UDOT;
2194621948

21949+
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
21950+
// product followed by a sign extension (sufficient for both signed and
// unsigned dots, since the 32-bit dot-product result always fits in 31 bits)
21951+
if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
21952+
(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
21953+
EVT ReducedTypeI32 =
21954+
(ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
21955+
21956+
auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
21957+
DAG.getConstant(0, DL, ReducedTypeI32), A, B);
21958+
auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType);
21959+
return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
21960+
Extended);
21961+
}
21962+
2194721963
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
2194821964
}
2194921965

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,162 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
211211
ret <2 x i32> %partial.reduce
212212
}
213213

214+
define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
215+
; CHECK-DOT-LABEL: udot_8to64:
216+
; CHECK-DOT: // %bb.0: // %entry
217+
; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
218+
; CHECK-DOT-NEXT: udot v4.4s, v2.16b, v3.16b
219+
; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
220+
; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
221+
; CHECK-DOT-NEXT: ret
222+
;
223+
; CHECK-NODOT-LABEL: udot_8to64:
224+
; CHECK-NODOT: // %bb.0: // %entry
225+
; CHECK-NODOT-NEXT: umull v4.8h, v2.8b, v3.8b
226+
; CHECK-NODOT-NEXT: umull2 v2.8h, v2.16b, v3.16b
227+
; CHECK-NODOT-NEXT: ushll v3.4s, v4.4h, #0
228+
; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
229+
; CHECK-NODOT-NEXT: ushll2 v4.4s, v4.8h, #0
230+
; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
231+
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v3.4s
232+
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v3.2s
233+
; CHECK-NODOT-NEXT: uaddl2 v3.2d, v4.4s, v5.4s
234+
; CHECK-NODOT-NEXT: uaddl v4.2d, v4.2s, v5.2s
235+
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
236+
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
237+
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
238+
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
239+
; CHECK-NODOT-NEXT: ret
240+
entry:
241+
%a.wide = zext <16 x i8> %a to <16 x i64>
242+
%b.wide = zext <16 x i8> %b to <16 x i64>
243+
%mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
244+
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
245+
<4 x i64> %acc, <16 x i64> %mult)
246+
ret <4 x i64> %partial.reduce
247+
}
248+
249+
define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
250+
; CHECK-DOT-LABEL: sdot_8to64:
251+
; CHECK-DOT: // %bb.0: // %entry
252+
; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
253+
; CHECK-DOT-NEXT: sdot v4.4s, v2.16b, v3.16b
254+
; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
255+
; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
256+
; CHECK-DOT-NEXT: ret
257+
;
258+
; CHECK-NODOT-LABEL: sdot_8to64:
259+
; CHECK-NODOT: // %bb.0: // %entry
260+
; CHECK-NODOT-NEXT: smull v4.8h, v2.8b, v3.8b
261+
; CHECK-NODOT-NEXT: smull2 v2.8h, v2.16b, v3.16b
262+
; CHECK-NODOT-NEXT: sshll v3.4s, v4.4h, #0
263+
; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
264+
; CHECK-NODOT-NEXT: sshll2 v4.4s, v4.8h, #0
265+
; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
266+
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v3.4s
267+
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v3.2s
268+
; CHECK-NODOT-NEXT: saddl2 v3.2d, v4.4s, v5.4s
269+
; CHECK-NODOT-NEXT: saddl v4.2d, v4.2s, v5.2s
270+
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
271+
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
272+
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
273+
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
274+
; CHECK-NODOT-NEXT: ret
275+
entry:
276+
%a.wide = sext <16 x i8> %a to <16 x i64>
277+
%b.wide = sext <16 x i8> %b to <16 x i64>
278+
%mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
279+
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
280+
<4 x i64> %acc, <16 x i64> %mult)
281+
ret <4 x i64> %partial.reduce
282+
}
283+
284+
define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
285+
; CHECK-NOI8MM-LABEL: usdot_8to64:
286+
; CHECK-NOI8MM: // %bb.0: // %entry
287+
; CHECK-NOI8MM-NEXT: ushll v4.8h, v2.8b, #0
288+
; CHECK-NOI8MM-NEXT: sshll v5.8h, v3.8b, #0
289+
; CHECK-NOI8MM-NEXT: ushll2 v2.8h, v2.16b, #0
290+
; CHECK-NOI8MM-NEXT: sshll2 v3.8h, v3.16b, #0
291+
; CHECK-NOI8MM-NEXT: ushll v6.4s, v4.4h, #0
292+
; CHECK-NOI8MM-NEXT: sshll v7.4s, v5.4h, #0
293+
; CHECK-NOI8MM-NEXT: ushll2 v4.4s, v4.8h, #0
294+
; CHECK-NOI8MM-NEXT: sshll2 v5.4s, v5.8h, #0
295+
; CHECK-NOI8MM-NEXT: ushll2 v16.4s, v2.8h, #0
296+
; CHECK-NOI8MM-NEXT: sshll2 v17.4s, v3.8h, #0
297+
; CHECK-NOI8MM-NEXT: ushll v2.4s, v2.4h, #0
298+
; CHECK-NOI8MM-NEXT: sshll v3.4s, v3.4h, #0
299+
; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v6.4s, v7.4s
300+
; CHECK-NOI8MM-NEXT: smlal v0.2d, v6.2s, v7.2s
301+
; CHECK-NOI8MM-NEXT: smull v18.2d, v4.2s, v5.2s
302+
; CHECK-NOI8MM-NEXT: smull2 v4.2d, v4.4s, v5.4s
303+
; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v16.4s, v17.4s
304+
; CHECK-NOI8MM-NEXT: smlal v0.2d, v16.2s, v17.2s
305+
; CHECK-NOI8MM-NEXT: smlal2 v4.2d, v2.4s, v3.4s
306+
; CHECK-NOI8MM-NEXT: smlal v18.2d, v2.2s, v3.2s
307+
; CHECK-NOI8MM-NEXT: add v1.2d, v4.2d, v1.2d
308+
; CHECK-NOI8MM-NEXT: add v0.2d, v18.2d, v0.2d
309+
; CHECK-NOI8MM-NEXT: ret
310+
;
311+
; CHECK-I8MM-LABEL: usdot_8to64:
312+
; CHECK-I8MM: // %bb.0: // %entry
313+
; CHECK-I8MM-NEXT: movi v4.2d, #0000000000000000
314+
; CHECK-I8MM-NEXT: usdot v4.4s, v2.16b, v3.16b
315+
; CHECK-I8MM-NEXT: saddw2 v1.2d, v1.2d, v4.4s
316+
; CHECK-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
317+
; CHECK-I8MM-NEXT: ret
318+
entry:
319+
%a.wide = zext <16 x i8> %a to <16 x i64>
320+
%b.wide = sext <16 x i8> %b to <16 x i64>
321+
%mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
322+
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
323+
<4 x i64> %acc, <16 x i64> %mult)
324+
ret <4 x i64> %partial.reduce
325+
}
326+
327+
define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
328+
; CHECK-NOI8MM-LABEL: sudot_8to64:
329+
; CHECK-NOI8MM: // %bb.0: // %entry
330+
; CHECK-NOI8MM-NEXT: sshll v4.8h, v2.8b, #0
331+
; CHECK-NOI8MM-NEXT: ushll v5.8h, v3.8b, #0
332+
; CHECK-NOI8MM-NEXT: sshll2 v2.8h, v2.16b, #0
333+
; CHECK-NOI8MM-NEXT: ushll2 v3.8h, v3.16b, #0
334+
; CHECK-NOI8MM-NEXT: sshll v6.4s, v4.4h, #0
335+
; CHECK-NOI8MM-NEXT: ushll v7.4s, v5.4h, #0
336+
; CHECK-NOI8MM-NEXT: sshll2 v4.4s, v4.8h, #0
337+
; CHECK-NOI8MM-NEXT: ushll2 v5.4s, v5.8h, #0
338+
; CHECK-NOI8MM-NEXT: sshll2 v16.4s, v2.8h, #0
339+
; CHECK-NOI8MM-NEXT: ushll2 v17.4s, v3.8h, #0
340+
; CHECK-NOI8MM-NEXT: sshll v2.4s, v2.4h, #0
341+
; CHECK-NOI8MM-NEXT: ushll v3.4s, v3.4h, #0
342+
; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v6.4s, v7.4s
343+
; CHECK-NOI8MM-NEXT: smlal v0.2d, v6.2s, v7.2s
344+
; CHECK-NOI8MM-NEXT: smull v18.2d, v4.2s, v5.2s
345+
; CHECK-NOI8MM-NEXT: smull2 v4.2d, v4.4s, v5.4s
346+
; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v16.4s, v17.4s
347+
; CHECK-NOI8MM-NEXT: smlal v0.2d, v16.2s, v17.2s
348+
; CHECK-NOI8MM-NEXT: smlal2 v4.2d, v2.4s, v3.4s
349+
; CHECK-NOI8MM-NEXT: smlal v18.2d, v2.2s, v3.2s
350+
; CHECK-NOI8MM-NEXT: add v1.2d, v4.2d, v1.2d
351+
; CHECK-NOI8MM-NEXT: add v0.2d, v18.2d, v0.2d
352+
; CHECK-NOI8MM-NEXT: ret
353+
;
354+
; CHECK-I8MM-LABEL: sudot_8to64:
355+
; CHECK-I8MM: // %bb.0: // %entry
356+
; CHECK-I8MM-NEXT: movi v4.2d, #0000000000000000
357+
; CHECK-I8MM-NEXT: usdot v4.4s, v3.16b, v2.16b
358+
; CHECK-I8MM-NEXT: saddw2 v1.2d, v1.2d, v4.4s
359+
; CHECK-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
360+
; CHECK-I8MM-NEXT: ret
361+
entry:
362+
%a.wide = sext <16 x i8> %a to <16 x i64>
363+
%b.wide = zext <16 x i8> %b to <16 x i64>
364+
%mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
365+
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
366+
<4 x i64> %acc, <16 x i64> %mult)
367+
ret <4 x i64> %partial.reduce
368+
}
369+
214370
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
215371
; CHECK-LABEL: not_udot:
216372
; CHECK: // %bb.0:

0 commit comments

Comments
 (0)