[AArch64][SVE] Add dot product codegen for partial reductions with no binary operation on input #120207
@llvm/pr-subscribers-backend-aarch64

Author: James Chesterman (JamesChesterman)

Changes: Add codegen for when the input type has 4 times as many elements as the output type and the input to the partial reduction does not have a binary operation performed on it.

Full diff (3 files affected): https://github.com/llvm/llvm-project/pull/120207.diff
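The transformation rests on a simple identity: a partial reduction of a bare extend equals a dot product against an all-ones vector, since ext(a[i]) == ext(a[i]) * 1. So rather than bailing out when the wide operand is an extend with no mul, the lowering now splats a constant 1 as the second dot-product operand. A minimal sketch of the matched pattern, mirroring the udot_no_bin_op test below (the function name is illustrative):

define <vscale x 4 x i32> @example(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
  ; No binary operation: the zero-extend feeds the partial reduction directly.
  %a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
  %r = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.wide)
  ret <vscale x 4 x i32> %r
}

With this patch the block above lowers to a splat of 1 followed by a single dot product (mov z2.b, #1; udot z0.s, z1.b, z2.b) instead of a chain of unpacks and widening adds.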
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index d1354ccf376609..1e5b80174e8cfd 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -21741,45 +21741,63 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// The narrower of the two operands. Used as the accumulator
auto NarrowOp = N->getOperand(1);
auto MulOp = N->getOperand(2);
- if (MulOp->getOpcode() != ISD::MUL)
- return SDValue();
- auto ExtA = MulOp->getOperand(0);
- auto ExtB = MulOp->getOperand(1);
+ unsigned MulOpcode = MulOp->getOpcode();
+ EVT ReducedVT = N->getValueType(0);
+ EVT MulOpVT = MulOp->getValueType(0);
+ unsigned Opcode = 0;
+ bool AIsSigned, BIsSigned;
+ SDValue A, B;
+ if (MulOpcode != ISD::MUL && ReducedVT.getVectorElementCount() * 4 ==
+ MulOpVT.getVectorElementCount()) {
+ if (!ISD::isExtOpcode(MulOpcode))
+ return SDValue();
+ AIsSigned = MulOpcode == ISD::SIGN_EXTEND;
+ BIsSigned = AIsSigned;
+ SDValue NewMulOp = MulOp->getOperand(0);
+ Opcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
+ A = NewMulOp;
+ B = DAG.getConstant(1, DL, NewMulOp.getValueType());
- if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
- !ISD::isExtOpcode(ExtB->getOpcode()))
- return SDValue();
- bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
- bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+ } else {
+ if (MulOp->getOpcode() != ISD::MUL)
+ return SDValue();
- auto A = ExtA->getOperand(0);
- auto B = ExtB->getOperand(0);
- if (A.getValueType() != B.getValueType())
- return SDValue();
+ auto ExtA = MulOp->getOperand(0);
+ auto ExtB = MulOp->getOperand(1);
- EVT ReducedType = N->getValueType(0);
- EVT MulSrcType = A.getValueType();
+ if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
+ !ISD::isExtOpcode(ExtB->getOpcode()))
+ return SDValue();
+ AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
+ BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
+
+ A = ExtA->getOperand(0);
+ B = ExtB->getOperand(0);
+ if (A.getValueType() != B.getValueType())
+ return SDValue();
+ }
+
+ EVT MulSrcVT = A.getValueType();
// Dot products operate on chunks of four elements so there must be four times
// as many elements in the wide type
- if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
- !(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
- !(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
- !(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
- !(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
- !(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
+ if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
+ !(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
+ !(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
+ !(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
return SDValue();
// If the extensions are mixed, we should lower it to a usdot instead
- unsigned Opcode = 0;
if (AIsSigned != BIsSigned) {
if (!Subtarget->hasMatMulInt8())
return SDValue();
bool Scalable = N->getValueType(0).isScalableVT();
// There's no nxv2i64 version of usdot
- if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
+ if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
return SDValue();
Opcode = AArch64ISD::USDOT;
@@ -21793,19 +21811,19 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
// product followed by a zero / sign extension
- if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
- (ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
- EVT ReducedTypeI32 =
- (ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
-
- auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
- DAG.getConstant(0, DL, ReducedTypeI32), A, B);
- auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType);
+ if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
+ (ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
+ EVT ReducedVTI32 =
+ (ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
+
+ auto DotI32 = DAG.getNode(Opcode, DL, ReducedVTI32,
+ DAG.getConstant(0, DL, ReducedVTI32), A, B);
+ auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
Extended);
}
- return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
+ return DAG.getNode(Opcode, DL, ReducedVT, NarrowOp, A, B);
}
SDValue tryLowerPartialReductionToWideAdd(SDNode *N,
diff --git a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
index c1b9a4c9dbb797..c0987582efc261 100644
--- a/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll
@@ -367,6 +367,166 @@ entry:
ret <4 x i64> %partial.reduce
}
+define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
+; CHECK-DOT-LABEL: udot_no_bin_op:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.16b, #1
+; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: udot_no_bin_op:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v2.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT: ushll v3.4s, v1.4h, #0
+; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT: uaddw2 v2.4s, v3.4s, v2.8h
+; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+ %a.wide = zext <16 x i8> %a to <16 x i32>
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
+ ret <4 x i32> %partial.reduce
+}
+
+define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
+; CHECK-DOT-LABEL: sdot_no_bin_op:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.16b, #1
+; CHECK-DOT-NEXT: sdot v0.4s, v1.16b, v2.16b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: sdot_no_bin_op:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v2.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
+; CHECK-NODOT-NEXT: sshll v3.4s, v1.4h, #0
+; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v2.4h
+; CHECK-NODOT-NEXT: saddw2 v2.4s, v3.4s, v2.8h
+; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
+; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
+; CHECK-NODOT-NEXT: ret
+ %a.wide = sext <16 x i8> %a to <16 x i32>
+ %partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
+ ret <4 x i32> %partial.reduce
+}
+
+define <2 x i32> @udot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
+; CHECK-DOT-LABEL: udot_no_bin_op_narrow:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.8b, #1
+; CHECK-DOT-NEXT: udot v0.2s, v1.8b, v2.8b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: udot_no_bin_op_narrow:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT: ret
+ %a.wide = zext <8 x i8> %a to <8 x i32>
+ %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
+ ret <2 x i32> %partial.reduce
+}
+
+define <2 x i32> @sdot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
+; CHECK-DOT-LABEL: sdot_no_bin_op_narrow:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v2.8b, #1
+; CHECK-DOT-NEXT: sdot v0.2s, v1.8b, v2.8b
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: sdot_no_bin_op_narrow:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
+; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
+; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
+; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
+; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
+; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
+; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
+; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
+; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
+; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
+; CHECK-NODOT-NEXT: ret
+ %a.wide = sext <8 x i8> %a to <8 x i32>
+ %partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
+ ret <2 x i32> %partial.reduce
+}
+
+define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
+; CHECK-DOT-LABEL: udot_no_bin_op_8to64:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v3.16b, #1
+; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT: udot v4.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: udot_no_bin_op_8to64:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: ushll v3.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT: ushll v4.4s, v3.4h, #0
+; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT: ushll2 v3.4s, v3.8h, #0
+; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v4.4s
+; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v4.2s
+; CHECK-NODOT-NEXT: uaddl2 v4.2d, v3.4s, v5.4s
+; CHECK-NODOT-NEXT: uaddl v3.2d, v3.2s, v5.2s
+; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
+; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
+; CHECK-NODOT-NEXT: ret
+ %a.wide = zext <16 x i8> %a to <16 x i64>
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
+ ret <4 x i64> %partial.reduce
+}
+
+define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
+; CHECK-DOT-LABEL: sdot_no_bin_op_8to64:
+; CHECK-DOT: // %bb.0:
+; CHECK-DOT-NEXT: movi v3.16b, #1
+; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
+; CHECK-DOT-NEXT: sdot v4.4s, v2.16b, v3.16b
+; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-DOT-NEXT: ret
+;
+; CHECK-NODOT-LABEL: sdot_no_bin_op_8to64:
+; CHECK-NODOT: // %bb.0:
+; CHECK-NODOT-NEXT: sshll v3.8h, v2.8b, #0
+; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
+; CHECK-NODOT-NEXT: sshll v4.4s, v3.4h, #0
+; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
+; CHECK-NODOT-NEXT: sshll2 v3.4s, v3.8h, #0
+; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
+; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
+; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v4.2s
+; CHECK-NODOT-NEXT: saddl2 v4.2d, v3.4s, v5.4s
+; CHECK-NODOT-NEXT: saddl v3.2d, v3.2s, v5.2s
+; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
+; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
+; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
+; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
+; CHECK-NODOT-NEXT: ret
+ %a.wide = sext <16 x i8> %a to <16 x i64>
+ %partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
+ ret <4 x i64> %partial.reduce
+}
+
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0:
diff --git a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
index 66d6e0388bbf94..56cd2c9d62b04f 100644
--- a/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
+++ b/llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll
@@ -316,6 +316,84 @@ entry:
ret <vscale x 4 x i64> %partial.reduce
}
+define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: udot_no_bin_op:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEXT: udot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: sdot_no_bin_op:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEXT: sdot z0.s, z1.b, z2.b
+; CHECK-NEXT: ret
+ %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
+ %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
+ ret <vscale x 4 x i32> %partial.reduce
+}
+
+define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
+; CHECK-LABEL: udot_no_bin_op_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEXT: udot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
+; CHECK-LABEL: sdot_no_bin_op_wide:
+; CHECK: // %bb.0: // %entry
+; CHECK-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEXT: sdot z0.d, z1.h, z2.h
+; CHECK-NEXT: ret
+entry:
+ %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
+ %partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
+ ret <vscale x 2 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: udot_no_bin_op_8to64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: udot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
+ %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
+define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
+; CHECK-LABEL: sdot_no_bin_op_8to64:
+; CHECK: // %bb.0:
+; CHECK-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEXT: mov z4.s, #0 // =0x0
+; CHECK-NEXT: sdot z4.s, z2.b, z3.b
+; CHECK-NEXT: sunpklo z2.d, z4.s
+; CHECK-NEXT: sunpkhi z3.d, z4.s
+; CHECK-NEXT: add z0.d, z0.d, z2.d
+; CHECK-NEXT: add z1.d, z1.d, z3.d
+; CHECK-NEXT: ret
+ %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
+ %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
+ ret <vscale x 4 x i64> %partial.reduce
+}
+
define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
; CHECK-LABEL: not_udot:
; CHECK: // %bb.0: // %entry
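One wrinkle carried over from the existing mul path: for the (nx)v16i8 to (nx)v4i64 cases there is no i64 dot-product instruction, so the lowering emits an i32 dot product and then widens. The sign extension is correct even for udot because four i8-by-i8 products cannot come close to the i32 sign bit. A rough IR-level equivalent, using the NEON dot-product intrinsic purely for illustration (the patch itself builds SelectionDAG nodes, not this IR):

define <4 x i64> @sketch_8to64(<4 x i64> %acc, <16 x i8> %a) {
  ; i32 dot product of %a against all-ones, starting from a zero accumulator
  %dot = call <4 x i32> @llvm.aarch64.neon.udot.v4i32.v16i8(<4 x i32> zeroinitializer, <16 x i8> %a, <16 x i8> <i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1, i8 1>)
  ; widen the i32 partial sums and fold in the i64 accumulator
  %ext = sext <4 x i32> %dot to <4 x i64>
  %res = add <4 x i64> %acc, %ext
  ret <4 x i64> %res
}

This matches the CHECK-DOT output of udot_no_bin_op_8to64 above: a movi/udot pair into a zeroed register, then saddw/saddw2 to widen and accumulate.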
Force-pushed from d73e6e7 to cc4fb09: renaming variables.
Looks good to me 👍 See what Sander thinks though.
LGTM, just a nit:
is not performed where it is not necessary.