Skip to content

[AArch64][SVE] Add dot product codegen for partial reductions with no binary operation on input #120207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 27 additions & 17 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21953,21 +21953,35 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
SDLoc DL(N);

SDValue Op2 = N->getOperand(2);
if (Op2->getOpcode() != ISD::MUL ||
!ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) ||
!ISD::isExtOpcode(Op2->getOperand(1)->getOpcode()))
return SDValue();
unsigned Op2Opcode = Op2->getOpcode();
SDValue MulOpLHS, MulOpRHS;
bool MulOpLHSIsSigned, MulOpRHSIsSigned;
if (ISD::isExtOpcode(Op2Opcode)) {
MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
MulOpLHS = Op2->getOperand(0);
MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
} else if (Op2Opcode == ISD::MUL) {
SDValue ExtMulOpLHS = Op2->getOperand(0);
SDValue ExtMulOpRHS = Op2->getOperand(1);

unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
!ISD::isExtOpcode(ExtMulOpRHSOpcode))
return SDValue();

SDValue Acc = N->getOperand(1);
SDValue Mul = N->getOperand(2);
SDValue ExtMulOpLHS = Mul->getOperand(0);
SDValue ExtMulOpRHS = Mul->getOperand(1);
MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;

SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
MulOpLHS = ExtMulOpLHS->getOperand(0);
MulOpRHS = ExtMulOpRHS->getOperand(0);

if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
return SDValue();
} else
return SDValue();

SDValue Acc = N->getOperand(1);
EVT ReducedVT = N->getValueType(0);
EVT MulSrcVT = MulOpLHS.getValueType();

Expand All @@ -21981,8 +21995,6 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
!(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
return SDValue();

bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND;
bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND;
// If the extensions are mixed, we should lower it to a usdot instead
unsigned Opcode = 0;
if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
Expand All @@ -21998,10 +22010,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// USDOT expects the signed operand to be last
if (!MulOpRHSIsSigned)
std::swap(MulOpLHS, MulOpRHS);
} else if (MulOpLHSIsSigned)
Opcode = AArch64ISD::SDOT;
else
Opcode = AArch64ISD::UDOT;
} else
Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;

// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension
Expand Down
248 changes: 248 additions & 0 deletions llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,166 @@ entry:
ret <4 x i64> %partial.reduce
}

; Partial reduction of a plain zext (no mul feeding the reduction):
; 16 x u8 widened to i32 and accumulated into 4 x i32 lanes.
; With the dot-product feature (CHECK-DOT) this is expected to lower to a
; udot against a constant-1 splat (movi v2.16b, #1), i.e. treating the
; missing multiply as "* 1". Without it (CHECK-NODOT) the input is widened
; with ushll/ushll2 and folded in via uaddw/uaddw2 chains.
define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
; CHECK-DOT-LABEL: udot_no_bin_op:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v2.16b, #1
; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: udot_no_bin_op:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: ushll v2.8h, v1.8b, #0
; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-NODOT-NEXT: ushll v3.4s, v1.4h, #0
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h
; CHECK-NODOT-NEXT: uaddw2 v2.4s, v3.4s, v2.8h
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
  ; The reduction input is the extend itself — no binary op in between.
%a.wide = zext <16 x i8> %a to <16 x i32>
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
ret <4 x i32> %partial.reduce
}

; Signed twin of udot_no_bin_op: a sext with no multiply feeding the
; partial reduction. CHECK-DOT expects sdot against a constant-1 splat;
; CHECK-NODOT expects a sshll/saddw widening chain instead.
define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
; CHECK-DOT-LABEL: sdot_no_bin_op:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v2.16b, #1
; CHECK-DOT-NEXT: sdot v0.4s, v1.16b, v2.16b
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: sdot_no_bin_op:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: sshll v2.8h, v1.8b, #0
; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
; CHECK-NODOT-NEXT: sshll v3.4s, v1.4h, #0
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v2.4h
; CHECK-NODOT-NEXT: saddw2 v2.4s, v3.4s, v2.8h
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
; CHECK-NODOT-NEXT: ret
  ; Sign-extend is the only op between the source vector and the reduction.
%a.wide = sext <16 x i8> %a to <16 x i32>
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
ret <4 x i32> %partial.reduce
}

; 64-bit variant of udot_no_bin_op: 8 x u8 zero-extended and partially
; reduced into 2 x i32. CHECK-DOT expects the .2s/.8b udot form with a
; constant-1 splat; CHECK-NODOT falls back to widening plus ext/uaddw
; shuffling to reassemble the two-lane result.
define <2 x i32> @udot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
; CHECK-DOT-LABEL: udot_no_bin_op_narrow:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v2.8b, #1
; CHECK-DOT-NEXT: udot v0.2s, v1.8b, v2.8b
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: udot_no_bin_op_narrow:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
%a.wide = zext <8 x i8> %a to <8 x i32>
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
ret <2 x i32> %partial.reduce
}

; Signed twin of udot_no_bin_op_narrow: 8 x s8 sign-extended and partially
; reduced into 2 x i32. CHECK-DOT expects the .2s/.8b sdot against a
; constant-1 splat; CHECK-NODOT expects the sshll/saddw fallback.
define <2 x i32> @sdot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
; CHECK-DOT-LABEL: sdot_no_bin_op_narrow:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v2.8b, #1
; CHECK-DOT-NEXT: sdot v0.2s, v1.8b, v2.8b
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: sdot_no_bin_op_narrow:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
; CHECK-NODOT-NEXT: ret
%a.wide = sext <8 x i8> %a to <8 x i32>
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
ret <2 x i32> %partial.reduce
}

; i8 -> i64 no-binop case: the widening is wider than one dot can do, so
; CHECK-DOT expects a udot into a fresh zero i32 accumulator followed by a
; widening add of that i32 dot result into the i64 accumulator.
; NOTE(review): the extension of the dot result uses saddw/saddw2 even
; though the inputs are unsigned — the udot of u8 values into a zeroed
; accumulator cannot go negative, so sign- and zero-extension agree here.
define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
; CHECK-DOT-LABEL: udot_no_bin_op_8to64:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v3.16b, #1
; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
; CHECK-DOT-NEXT: udot v4.4s, v2.16b, v3.16b
; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: udot_no_bin_op_8to64:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: ushll v3.8h, v2.8b, #0
; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
; CHECK-NODOT-NEXT: ushll v4.4s, v3.4h, #0
; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
; CHECK-NODOT-NEXT: ushll2 v3.4s, v3.8h, #0
; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v4.4s
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v4.2s
; CHECK-NODOT-NEXT: uaddl2 v4.2d, v3.4s, v5.4s
; CHECK-NODOT-NEXT: uaddl v3.2d, v3.2s, v5.2s
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
; CHECK-NODOT-NEXT: ret
%a.wide = zext <16 x i8> %a to <16 x i64>
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
ret <4 x i64> %partial.reduce
}

; Signed twin of udot_no_bin_op_8to64: sdot into a zeroed i32 accumulator,
; then a signed widening add (saddw/saddw2) into the i64 accumulator.
define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
; CHECK-DOT-LABEL: sdot_no_bin_op_8to64:
; CHECK-DOT: // %bb.0:
; CHECK-DOT-NEXT: movi v3.16b, #1
; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
; CHECK-DOT-NEXT: sdot v4.4s, v2.16b, v3.16b
; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-DOT-NEXT: ret
;
; CHECK-NODOT-LABEL: sdot_no_bin_op_8to64:
; CHECK-NODOT: // %bb.0:
; CHECK-NODOT-NEXT: sshll v3.8h, v2.8b, #0
; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
; CHECK-NODOT-NEXT: sshll v4.4s, v3.4h, #0
; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
; CHECK-NODOT-NEXT: sshll2 v3.4s, v3.8h, #0
; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v4.2s
; CHECK-NODOT-NEXT: saddl2 v4.2d, v3.4s, v5.4s
; CHECK-NODOT-NEXT: saddl v3.2d, v3.2s, v5.2s
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
; CHECK-NODOT-NEXT: ret
%a.wide = sext <16 x i8> %a to <16 x i64>
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
ret <4 x i64> %partial.reduce
}

define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0:
Expand Down Expand Up @@ -398,3 +558,91 @@ define <2 x i32> @not_udot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) {
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <4 x i32> %mult)
ret <2 x i32> %partial.reduce
}

; Negative test: the two multiply operands have different source element
; types (i16 vs i8), so no dot-product instruction is expected — the CHECK
; lines pin the generic umull/umlal widening expansion instead.
define <2 x i64> @udot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
; CHECK-LABEL: udot_different_types:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
; CHECK-NEXT: ushll v3.4s, v1.4h, #0
; CHECK-NEXT: ushll2 v1.4s, v1.8h, #0
; CHECK-NEXT: ushll v4.4s, v2.4h, #0
; CHECK-NEXT: ushll2 v2.4s, v2.8h, #0
; CHECK-NEXT: umull v5.2d, v1.2s, v2.2s
; CHECK-NEXT: umlal v0.2d, v3.2s, v4.2s
; CHECK-NEXT: umlal2 v0.2d, v1.4s, v2.4s
; CHECK-NEXT: umlal2 v5.2d, v3.4s, v4.4s
; CHECK-NEXT: add v0.2d, v5.2d, v0.2d
; CHECK-NEXT: ret
entry:
  ; Mismatched extend sources: <8 x i16> vs <8 x i8>.
%a.wide = zext <8 x i16> %a to <8 x i64>
%b.wide = zext <8 x i8> %b to <8 x i64>
%mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
ret <2 x i64> %partial.reduce
}

; Negative test (signed/signed): mismatched source element types (i16 vs
; i8) block the sdot lowering; expect the smull/smlal widening expansion.
define <2 x i64> @sdot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
; CHECK-LABEL: sdot_different_types:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sshll v2.8h, v2.8b, #0
; CHECK-NEXT: sshll v3.4s, v1.4h, #0
; CHECK-NEXT: sshll2 v1.4s, v1.8h, #0
; CHECK-NEXT: sshll v4.4s, v2.4h, #0
; CHECK-NEXT: sshll2 v2.4s, v2.8h, #0
; CHECK-NEXT: smull v5.2d, v1.2s, v2.2s
; CHECK-NEXT: smlal v0.2d, v3.2s, v4.2s
; CHECK-NEXT: smlal2 v0.2d, v1.4s, v2.4s
; CHECK-NEXT: smlal2 v5.2d, v3.4s, v4.4s
; CHECK-NEXT: add v0.2d, v5.2d, v0.2d
; CHECK-NEXT: ret
entry:
%a.wide = sext <8 x i16> %a to <8 x i64>
%b.wide = sext <8 x i8> %b to <8 x i64>
%mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
ret <2 x i64> %partial.reduce
}

; Negative test (unsigned * signed): mismatched source element types keep
; this off the usdot path; expect a mixed ushll/sshll widening expansion.
define <2 x i64> @usdot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
; CHECK-LABEL: usdot_different_types:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: sshll v2.8h, v2.8b, #0
; CHECK-NEXT: ushll v3.4s, v1.4h, #0
; CHECK-NEXT: ushll2 v1.4s, v1.8h, #0
; CHECK-NEXT: sshll v4.4s, v2.4h, #0
; CHECK-NEXT: sshll2 v2.4s, v2.8h, #0
; CHECK-NEXT: smull v5.2d, v1.2s, v2.2s
; CHECK-NEXT: smlal v0.2d, v3.2s, v4.2s
; CHECK-NEXT: smlal2 v0.2d, v1.4s, v2.4s
; CHECK-NEXT: smlal2 v5.2d, v3.4s, v4.4s
; CHECK-NEXT: add v0.2d, v5.2d, v0.2d
; CHECK-NEXT: ret
entry:
  ; zext LHS, sext RHS — the mixed-extension (usdot) shape.
%a.wide = zext <8 x i16> %a to <8 x i64>
%b.wide = sext <8 x i8> %b to <8 x i64>
%mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
ret <2 x i64> %partial.reduce
}

; Negative test (signed * unsigned): mirror of usdot_different_types with
; the extension kinds swapped; mismatched source element types again force
; the generic widening multiply-accumulate expansion.
define <2 x i64> @sudot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
; CHECK-LABEL: sudot_different_types:
; CHECK: // %bb.0: // %entry
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
; CHECK-NEXT: sshll v3.4s, v1.4h, #0
; CHECK-NEXT: sshll2 v1.4s, v1.8h, #0
; CHECK-NEXT: ushll v4.4s, v2.4h, #0
; CHECK-NEXT: ushll2 v2.4s, v2.8h, #0
; CHECK-NEXT: smull v5.2d, v1.2s, v2.2s
; CHECK-NEXT: smlal v0.2d, v3.2s, v4.2s
; CHECK-NEXT: smlal2 v0.2d, v1.4s, v2.4s
; CHECK-NEXT: smlal2 v5.2d, v3.4s, v4.4s
; CHECK-NEXT: add v0.2d, v5.2d, v0.2d
; CHECK-NEXT: ret
entry:
  ; sext LHS, zext RHS — the opposite mixed-extension order.
%a.wide = sext <8 x i16> %a to <8 x i64>
%b.wide = zext <8 x i8> %b to <8 x i64>
%mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
ret <2 x i64> %partial.reduce
}
Loading
Loading