[AArch64][NEON][SVE] Lower i8 to i64 partial reduction to a dot product #110220

Merged: 5 commits, Oct 1, 2024
llvm/lib/Target/AArch64/AArch64ISelLowering.cpp: 20 additions, 4 deletions
@@ -1996,8 +1996,8 @@ bool AArch64TargetLowering::shouldExpandPartialReductionIntrinsic(
     return true;
 
   EVT VT = EVT::getEVT(I->getType());
-  return VT != MVT::nxv4i32 && VT != MVT::nxv2i64 && VT != MVT::v4i32 &&
-         VT != MVT::v2i32;
+  return VT != MVT::nxv4i64 && VT != MVT::nxv4i32 && VT != MVT::nxv2i64 &&
+         VT != MVT::v4i64 && VT != MVT::v4i32 && VT != MVT::v2i32;
 }
 
 bool AArch64TargetLowering::shouldExpandCttzElements(EVT VT) const {
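
Returning false from this hook for the two new 64-bit result types (nxv4i64 and v4i64, alongside the existing nxv4i32, nxv2i64, v4i32 and v2i32) keeps the corresponding llvm.experimental.vector.partial.reduce.add calls out of generic expansion, so the dot-product combine below gets a chance to lower them. For reference, here is a hypothetical scalar C++ model of the fixed-width v16i64-to-v4i64 reduction; the helper name and the chunked lane grouping are assumptions for illustration (a four-way dot product uses this grouping, but the intrinsic itself leaves the grouping unspecified):

#include <array>
#include <cassert>
#include <cstdint>

// Hypothetical scalar model of
//   @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64,
// assuming the chunked lane grouping that a 4-way dot product produces.
std::array<int64_t, 4> partialReduceAdd(std::array<int64_t, 4> Acc,
                                        const std::array<int64_t, 16> &Wide) {
  for (int Lane = 0; Lane < 4; ++Lane)
    for (int K = 0; K < 4; ++K)
      Acc[Lane] += Wide[4 * Lane + K]; // four wide elements feed one lane
  return Acc;
}

int main() {
  std::array<int64_t, 16> Wide;
  for (int I = 0; I < 16; ++I)
    Wide[I] = I;
  // Lane 0 accumulates 0+1+2+3 = 6 on top of the accumulator value 100.
  assert(partialReduceAdd({100, 0, 0, 0}, Wide)[0] == 106);
}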
@@ -21918,8 +21918,10 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
 
   // Dot products operate on chunks of four elements so there must be four times
   // as many elements in the wide type
-  if (!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
+  if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
+      !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
       !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
+      !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
       !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
       !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
     return SDValue();
@@ -21932,7 +21934,7 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
 
     bool Scalable = N->getValueType(0).isScalableVT();
     // There's no nxv2i64 version of usdot
-    if (Scalable && ReducedType != MVT::nxv4i32)
+    if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
       return SDValue();
 
     Opcode = AArch64ISD::USDOT;
@@ -21944,6 +21946,20 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
   else
     Opcode = AArch64ISD::UDOT;
 
+  // Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
+  // product followed by a zero / sign extension
+  if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
+      (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
+    EVT ReducedTypeI32 =
+        (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+
+    auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
+                              DAG.getConstant(0, DL, ReducedTypeI32), A, B);
+    auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType);
+    return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
+                       Extended);
+  }
+
   return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
 }
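
The new block is sound because a 32-bit dot-product lane only ever sums four byte-sized products: for udot the lane value is at most 4 * 255 * 255 = 260100, and the signed and mixed-sign variants are bounded even more tightly, so the i32 accumulator cannot overflow. Sign extension via getSExtOrTrunc is also correct for the unsigned case, since a udot lane is non-negative and below 2^31, where sext and zext agree; this is why the tests below check saddw/saddw2 even after udot. A minimal scalar sketch of that argument for the unsigned case (the function names are illustrative, not LLVM APIs):

#include <cassert>
#include <cstdint>

// One i64 result lane computed the way the IR writes it:
// zero-extend to i64, multiply, accumulate.
int64_t directI64(int64_t Acc, const uint8_t *A, const uint8_t *B) {
  for (int K = 0; K < 4; ++K)
    Acc += (int64_t)A[K] * (int64_t)B[K];
  return Acc;
}

// The same lane computed the way the lowering does it:
// an i32 dot lane starting from zero, then sign-extend and add.
int64_t viaDot(int64_t Acc, const uint8_t *A, const uint8_t *B) {
  int32_t Lane = 0;
  for (int K = 0; K < 4; ++K)
    Lane += (int32_t)A[K] * (int32_t)B[K]; // at most 4 * 255 * 255 = 260100
  return Acc + (int64_t)Lane;              // sext == zext here: Lane >= 0
}

int main() {
  uint8_t A[4] = {255, 255, 255, 255}, B[4] = {255, 255, 255, 255};
  assert(directI64(-7, A, B) == viaDot(-7, A, B)); // worst case agrees
}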
llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll: 156 additions, 0 deletions
@@ -211,6 +211,162 @@ define <2 x i32> @sudot_narrow(<2 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
ret <2 x i32> %partial.reduce
}

define <4 x i64> @udot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
; CHECK-DOT-LABEL: udot_8to64:
; CHECK-DOT: // %bb.0: // %entry
; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
; CHECK-DOT-NEXT: udot v4.4s, v2.16b, v3.16b
; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: udot_8to64:
; CHECK-NODOT: // %bb.0: // %entry
; CHECK-NODOT-NEXT: umull v4.8h, v2.8b, v3.8b
; CHECK-NODOT-NEXT: umull2 v2.8h, v2.16b, v3.16b
; CHECK-NODOT-NEXT: ushll v3.4s, v4.4h, #0
; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
; CHECK-NODOT-NEXT: ushll2 v4.4s, v4.8h, #0
; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v3.4s
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v3.2s
; CHECK-NODOT-NEXT: uaddl2 v3.2d, v4.4s, v5.4s
; CHECK-NODOT-NEXT: uaddl v4.2d, v4.2s, v5.2s
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
; CHECK-NODOT-NEXT: ret
entry:
%a.wide = zext <16 x i8> %a to <16 x i64>
%b.wide = zext <16 x i8> %b to <16 x i64>
%mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
<4 x i64> %acc, <16 x i64> %mult)
ret <4 x i64> %partial.reduce
}

define <4 x i64> @sdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
; CHECK-DOT-LABEL: sdot_8to64:
; CHECK-DOT: // %bb.0: // %entry
; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
; CHECK-DOT-NEXT: sdot v4.4s, v2.16b, v3.16b
; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: sdot_8to64:
; CHECK-NODOT: // %bb.0: // %entry
; CHECK-NODOT-NEXT: smull v4.8h, v2.8b, v3.8b
; CHECK-NODOT-NEXT: smull2 v2.8h, v2.16b, v3.16b
; CHECK-NODOT-NEXT: sshll v3.4s, v4.4h, #0
; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
; CHECK-NODOT-NEXT: sshll2 v4.4s, v4.8h, #0
; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v3.4s
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v3.2s
; CHECK-NODOT-NEXT: saddl2 v3.2d, v4.4s, v5.4s
; CHECK-NODOT-NEXT: saddl v4.2d, v4.2s, v5.2s
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
; CHECK-NODOT-NEXT: add v1.2d, v3.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v4.2d, v0.2d
; CHECK-NODOT-NEXT: ret
entry:
%a.wide = sext <16 x i8> %a to <16 x i64>
%b.wide = sext <16 x i8> %b to <16 x i64>
%mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
<4 x i64> %acc, <16 x i64> %mult)
ret <4 x i64> %partial.reduce
}
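
The signed test shows why the final widening has to be a sign extension: an sdot lane accumulates four s8 by s8 products and can be negative (its range is [-65024, 65536], still comfortably inside i32). A small check in the same spirit as the sketch above:

#include <cassert>
#include <cstdint>

// Signed case: an sdot lane may be negative, so only sign extension to i64
// preserves its value (zero extension would not).
int main() {
  int8_t A[4] = {-128, -128, -128, -128}, B[4] = {127, 127, 127, 127};
  int32_t Lane = 0;
  for (int K = 0; K < 4; ++K)
    Lane += (int32_t)A[K] * (int32_t)B[K]; // most negative lane: -65024
  assert(Lane == -65024);
  assert((int64_t)Lane == -65024);         // sext preserves the value
  assert((int64_t)(uint32_t)Lane != Lane); // zext would not
}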

define <4 x i64> @usdot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b){
; CHECK-NOI8MM-LABEL: usdot_8to64:
; CHECK-NOI8MM: // %bb.0: // %entry
; CHECK-NOI8MM-NEXT: ushll v4.8h, v2.8b, #0
; CHECK-NOI8MM-NEXT: sshll v5.8h, v3.8b, #0
; CHECK-NOI8MM-NEXT: ushll2 v2.8h, v2.16b, #0
; CHECK-NOI8MM-NEXT: sshll2 v3.8h, v3.16b, #0
; CHECK-NOI8MM-NEXT: ushll v6.4s, v4.4h, #0
; CHECK-NOI8MM-NEXT: sshll v7.4s, v5.4h, #0
; CHECK-NOI8MM-NEXT: ushll2 v4.4s, v4.8h, #0
; CHECK-NOI8MM-NEXT: sshll2 v5.4s, v5.8h, #0
; CHECK-NOI8MM-NEXT: ushll2 v16.4s, v2.8h, #0
; CHECK-NOI8MM-NEXT: sshll2 v17.4s, v3.8h, #0
; CHECK-NOI8MM-NEXT: ushll v2.4s, v2.4h, #0
; CHECK-NOI8MM-NEXT: sshll v3.4s, v3.4h, #0
; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v6.4s, v7.4s
; CHECK-NOI8MM-NEXT: smlal v0.2d, v6.2s, v7.2s
; CHECK-NOI8MM-NEXT: smull v18.2d, v4.2s, v5.2s
; CHECK-NOI8MM-NEXT: smull2 v4.2d, v4.4s, v5.4s
; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v16.4s, v17.4s
; CHECK-NOI8MM-NEXT: smlal v0.2d, v16.2s, v17.2s
; CHECK-NOI8MM-NEXT: smlal2 v4.2d, v2.4s, v3.4s
; CHECK-NOI8MM-NEXT: smlal v18.2d, v2.2s, v3.2s
; CHECK-NOI8MM-NEXT: add v1.2d, v4.2d, v1.2d
; CHECK-NOI8MM-NEXT: add v0.2d, v18.2d, v0.2d
; CHECK-NOI8MM-NEXT: ret
;
; CHECK-I8MM-LABEL: usdot_8to64:
; CHECK-I8MM: // %bb.0: // %entry
; CHECK-I8MM-NEXT: movi v4.2d, #0000000000000000
; CHECK-I8MM-NEXT: usdot v4.4s, v2.16b, v3.16b
; CHECK-I8MM-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-I8MM-NEXT: ret
entry:
%a.wide = zext <16 x i8> %a to <16 x i64>
%b.wide = sext <16 x i8> %b to <16 x i64>
%mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
<4 x i64> %acc, <16 x i64> %mult)
ret <4 x i64> %partial.reduce
}

define <4 x i64> @sudot_8to64(<4 x i64> %acc, <16 x i8> %a, <16 x i8> %b) {
; CHECK-NOI8MM-LABEL: sudot_8to64:
; CHECK-NOI8MM: // %bb.0: // %entry
; CHECK-NOI8MM-NEXT: sshll v4.8h, v2.8b, #0
; CHECK-NOI8MM-NEXT: ushll v5.8h, v3.8b, #0
; CHECK-NOI8MM-NEXT: sshll2 v2.8h, v2.16b, #0
; CHECK-NOI8MM-NEXT: ushll2 v3.8h, v3.16b, #0
; CHECK-NOI8MM-NEXT: sshll v6.4s, v4.4h, #0
; CHECK-NOI8MM-NEXT: ushll v7.4s, v5.4h, #0
; CHECK-NOI8MM-NEXT: sshll2 v4.4s, v4.8h, #0
; CHECK-NOI8MM-NEXT: ushll2 v5.4s, v5.8h, #0
; CHECK-NOI8MM-NEXT: sshll2 v16.4s, v2.8h, #0
; CHECK-NOI8MM-NEXT: ushll2 v17.4s, v3.8h, #0
; CHECK-NOI8MM-NEXT: sshll v2.4s, v2.4h, #0
; CHECK-NOI8MM-NEXT: ushll v3.4s, v3.4h, #0
; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v6.4s, v7.4s
; CHECK-NOI8MM-NEXT: smlal v0.2d, v6.2s, v7.2s
; CHECK-NOI8MM-NEXT: smull v18.2d, v4.2s, v5.2s
; CHECK-NOI8MM-NEXT: smull2 v4.2d, v4.4s, v5.4s
; CHECK-NOI8MM-NEXT: smlal2 v1.2d, v16.4s, v17.4s
; CHECK-NOI8MM-NEXT: smlal v0.2d, v16.2s, v17.2s
; CHECK-NOI8MM-NEXT: smlal2 v4.2d, v2.4s, v3.4s
; CHECK-NOI8MM-NEXT: smlal v18.2d, v2.2s, v3.2s
; CHECK-NOI8MM-NEXT: add v1.2d, v4.2d, v1.2d
; CHECK-NOI8MM-NEXT: add v0.2d, v18.2d, v0.2d
; CHECK-NOI8MM-NEXT: ret
;
; CHECK-I8MM-LABEL: sudot_8to64:
; CHECK-I8MM: // %bb.0: // %entry
; CHECK-I8MM-NEXT: movi v4.2d, #0000000000000000
; CHECK-I8MM-NEXT: usdot v4.4s, v3.16b, v2.16b
; CHECK-I8MM-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-I8MM-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-I8MM-NEXT: ret
entry:
%a.wide = sext <16 x i8> %a to <16 x i64>
%b.wide = zext <16 x i8> %b to <16 x i64>
%mult = mul nuw nsw <16 x i64> %a.wide, %b.wide
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(
<4 x i64> %acc, <16 x i64> %mult)
ret <4 x i64> %partial.reduce
}
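
There is no separate su-flavoured dot instruction, so the CHECK-I8MM lines above implement sudot by commuting the multiplicands of usdot (note the swapped registers in usdot v4.4s, v3.16b, v2.16b). That works because each element product is the same number in either operand order, as long as the unsigned operand sits in the unsigned slot. A scalar model of the commutation (names illustrative):

#include <cassert>
#include <cstdint>

// One usdot lane: first operand treated as unsigned, second as signed.
int32_t usdotLane(const uint8_t *U, const int8_t *S) {
  int32_t Lane = 0;
  for (int K = 0; K < 4; ++K)
    Lane += (int32_t)U[K] * (int32_t)S[K];
  return Lane;
}

int main() {
  // sudot(%a signed, %b unsigned) == usdot(%b unsigned, %a signed).
  int8_t A[4] = {-128, 5, 127, -1};
  uint8_t B[4] = {255, 0, 200, 7};
  int32_t Expected = 0;
  for (int K = 0; K < 4; ++K)
    Expected += (int32_t)A[K] * (int32_t)B[K]; // signed x unsigned product
  assert(usdotLane(B, A) == Expected);
}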

define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0: