Skip to content

Commit d73e6e7

Browse files
[AArch64][SVE] Add dot product codegen for partial reductions with no binary operation on input.
Add codegen for the case where the input type has 4 times as many elements as the output type and the input to the partial reduction does not have a binary operation performed on it.
1 parent 411df3b commit d73e6e7

File tree

3 files changed

+288
-32
lines changed

3 files changed

+288
-32
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -21741,45 +21741,63 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2174121741
// The narrower of the two operands. Used as the accumulator
2174221742
auto NarrowOp = N->getOperand(1);
2174321743
auto MulOp = N->getOperand(2);
21744-
if (MulOp->getOpcode() != ISD::MUL)
21745-
return SDValue();
2174621744

21747-
auto ExtA = MulOp->getOperand(0);
21748-
auto ExtB = MulOp->getOperand(1);
21745+
unsigned MulOpcode = MulOp->getOpcode();
21746+
EVT ReducedVT = N->getValueType(0);
21747+
EVT MulOpVT = MulOp->getValueType(0);
21748+
unsigned Opcode = 0;
21749+
bool AIsSigned, BIsSigned;
21750+
SDValue A, B;
21751+
if (MulOpcode != ISD::MUL && ReducedVT.getVectorElementCount() * 4 ==
21752+
MulOpVT.getVectorElementCount()) {
21753+
if (!ISD::isExtOpcode(MulOpcode))
21754+
return SDValue();
21755+
AIsSigned = MulOpcode == ISD::SIGN_EXTEND;
21756+
BIsSigned = AIsSigned;
21757+
SDValue NewMulOp = MulOp->getOperand(0);
21758+
Opcode = AIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
21759+
A = NewMulOp;
21760+
B = DAG.getConstant(1, DL, NewMulOp.getValueType());
2174921761

21750-
if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
21751-
!ISD::isExtOpcode(ExtB->getOpcode()))
21752-
return SDValue();
21753-
bool AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
21754-
bool BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
21762+
} else {
21763+
if (MulOp->getOpcode() != ISD::MUL)
21764+
return SDValue();
2175521765

21756-
auto A = ExtA->getOperand(0);
21757-
auto B = ExtB->getOperand(0);
21758-
if (A.getValueType() != B.getValueType())
21759-
return SDValue();
21766+
auto ExtA = MulOp->getOperand(0);
21767+
auto ExtB = MulOp->getOperand(1);
2176021768

21761-
EVT ReducedType = N->getValueType(0);
21762-
EVT MulSrcType = A.getValueType();
21769+
if (!ISD::isExtOpcode(ExtA->getOpcode()) ||
21770+
!ISD::isExtOpcode(ExtB->getOpcode()))
21771+
return SDValue();
21772+
AIsSigned = ExtA->getOpcode() == ISD::SIGN_EXTEND;
21773+
BIsSigned = ExtB->getOpcode() == ISD::SIGN_EXTEND;
21774+
21775+
A = ExtA->getOperand(0);
21776+
B = ExtB->getOperand(0);
21777+
if (A.getValueType() != B.getValueType())
21778+
return SDValue();
21779+
}
21780+
21781+
EVT MulSrcVT = A.getValueType();
2176321782

2176421783
// Dot products operate on chunks of four elements so there must be four times
2176521784
// as many elements in the wide type
21766-
if (!(ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) &&
21767-
!(ReducedType == MVT::nxv4i32 && MulSrcType == MVT::nxv16i8) &&
21768-
!(ReducedType == MVT::nxv2i64 && MulSrcType == MVT::nxv8i16) &&
21769-
!(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8) &&
21770-
!(ReducedType == MVT::v4i32 && MulSrcType == MVT::v16i8) &&
21771-
!(ReducedType == MVT::v2i32 && MulSrcType == MVT::v8i8))
21785+
if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
21786+
!(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
21787+
!(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
21788+
!(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
21789+
!(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
21790+
!(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
2177221791
return SDValue();
2177321792

2177421793
// If the extensions are mixed, we should lower it to a usdot instead
21775-
unsigned Opcode = 0;
2177621794
if (AIsSigned != BIsSigned) {
2177721795
if (!Subtarget->hasMatMulInt8())
2177821796
return SDValue();
2177921797

2178021798
bool Scalable = N->getValueType(0).isScalableVT();
2178121799
// There's no nxv2i64 version of usdot
21782-
if (Scalable && ReducedType != MVT::nxv4i32 && ReducedType != MVT::nxv4i64)
21800+
if (Scalable && ReducedVT != MVT::nxv4i32 && ReducedVT != MVT::nxv4i64)
2178321801
return SDValue();
2178421802

2178521803
Opcode = AArch64ISD::USDOT;
@@ -21793,19 +21811,19 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2179321811

2179421812
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
2179521813
// product followed by a zero / sign extension
21796-
if ((ReducedType == MVT::nxv4i64 && MulSrcType == MVT::nxv16i8) ||
21797-
(ReducedType == MVT::v4i64 && MulSrcType == MVT::v16i8)) {
21798-
EVT ReducedTypeI32 =
21799-
(ReducedType.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
21800-
21801-
auto DotI32 = DAG.getNode(Opcode, DL, ReducedTypeI32,
21802-
DAG.getConstant(0, DL, ReducedTypeI32), A, B);
21803-
auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedType);
21814+
if ((ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) ||
21815+
(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8)) {
21816+
EVT ReducedVTI32 =
21817+
(ReducedVT.isScalableVector()) ? MVT::nxv4i32 : MVT::v4i32;
21818+
21819+
auto DotI32 = DAG.getNode(Opcode, DL, ReducedVTI32,
21820+
DAG.getConstant(0, DL, ReducedVTI32), A, B);
21821+
auto Extended = DAG.getSExtOrTrunc(DotI32, DL, ReducedVT);
2180421822
return DAG.getNode(ISD::ADD, DL, NarrowOp.getValueType(), NarrowOp,
2180521823
Extended);
2180621824
}
2180721825

21808-
return DAG.getNode(Opcode, DL, ReducedType, NarrowOp, A, B);
21826+
return DAG.getNode(Opcode, DL, ReducedVT, NarrowOp, A, B);
2180921827
}
2181021828

2181121829
SDValue tryLowerPartialReductionToWideAdd(SDNode *N,

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,166 @@ entry:
367367
ret <4 x i64> %partial.reduce
368368
}
369369

370+
define <4 x i32> @udot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
371+
; CHECK-DOT-LABEL: udot_no_bin_op:
372+
; CHECK-DOT: // %bb.0:
373+
; CHECK-DOT-NEXT: movi v2.16b, #1
374+
; CHECK-DOT-NEXT: udot v0.4s, v1.16b, v2.16b
375+
; CHECK-DOT-NEXT: ret
376+
;
377+
; CHECK-NODOT-LABEL: udot_no_bin_op:
378+
; CHECK-NODOT: // %bb.0:
379+
; CHECK-NODOT-NEXT: ushll v2.8h, v1.8b, #0
380+
; CHECK-NODOT-NEXT: ushll2 v1.8h, v1.16b, #0
381+
; CHECK-NODOT-NEXT: ushll v3.4s, v1.4h, #0
382+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v2.4h
383+
; CHECK-NODOT-NEXT: uaddw2 v2.4s, v3.4s, v2.8h
384+
; CHECK-NODOT-NEXT: uaddw2 v0.4s, v0.4s, v1.8h
385+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
386+
; CHECK-NODOT-NEXT: ret
387+
%a.wide = zext <16 x i8> %a to <16 x i32>
388+
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
389+
ret <4 x i32> %partial.reduce
390+
}
391+
392+
define <4 x i32> @sdot_no_bin_op(<4 x i32> %acc, <16 x i8> %a){
393+
; CHECK-DOT-LABEL: sdot_no_bin_op:
394+
; CHECK-DOT: // %bb.0:
395+
; CHECK-DOT-NEXT: movi v2.16b, #1
396+
; CHECK-DOT-NEXT: sdot v0.4s, v1.16b, v2.16b
397+
; CHECK-DOT-NEXT: ret
398+
;
399+
; CHECK-NODOT-LABEL: sdot_no_bin_op:
400+
; CHECK-NODOT: // %bb.0:
401+
; CHECK-NODOT-NEXT: sshll v2.8h, v1.8b, #0
402+
; CHECK-NODOT-NEXT: sshll2 v1.8h, v1.16b, #0
403+
; CHECK-NODOT-NEXT: sshll v3.4s, v1.4h, #0
404+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v2.4h
405+
; CHECK-NODOT-NEXT: saddw2 v2.4s, v3.4s, v2.8h
406+
; CHECK-NODOT-NEXT: saddw2 v0.4s, v0.4s, v1.8h
407+
; CHECK-NODOT-NEXT: add v0.4s, v2.4s, v0.4s
408+
; CHECK-NODOT-NEXT: ret
409+
%a.wide = sext <16 x i8> %a to <16 x i32>
410+
%partial.reduce = tail call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> %acc, <16 x i32> %a.wide)
411+
ret <4 x i32> %partial.reduce
412+
}
413+
414+
define <2 x i32> @udot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
415+
; CHECK-DOT-LABEL: udot_no_bin_op_narrow:
416+
; CHECK-DOT: // %bb.0:
417+
; CHECK-DOT-NEXT: movi v2.8b, #1
418+
; CHECK-DOT-NEXT: udot v0.2s, v1.8b, v2.8b
419+
; CHECK-DOT-NEXT: ret
420+
;
421+
; CHECK-NODOT-LABEL: udot_no_bin_op_narrow:
422+
; CHECK-NODOT: // %bb.0:
423+
; CHECK-NODOT-NEXT: ushll v1.8h, v1.8b, #0
424+
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
425+
; CHECK-NODOT-NEXT: ushll v2.4s, v1.4h, #0
426+
; CHECK-NODOT-NEXT: ushll2 v3.4s, v1.8h, #0
427+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
428+
; CHECK-NODOT-NEXT: uaddw v0.4s, v0.4s, v1.4h
429+
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
430+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
431+
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
432+
; CHECK-NODOT-NEXT: uaddw v1.4s, v2.4s, v4.4h
433+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
434+
; CHECK-NODOT-NEXT: ret
435+
%a.wide = zext <8 x i8> %a to <8 x i32>
436+
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
437+
ret <2 x i32> %partial.reduce
438+
}
439+
440+
define <2 x i32> @sdot_no_bin_op_narrow(<2 x i32> %acc, <8 x i8> %a){
441+
; CHECK-DOT-LABEL: sdot_no_bin_op_narrow:
442+
; CHECK-DOT: // %bb.0:
443+
; CHECK-DOT-NEXT: movi v2.8b, #1
444+
; CHECK-DOT-NEXT: sdot v0.2s, v1.8b, v2.8b
445+
; CHECK-DOT-NEXT: ret
446+
;
447+
; CHECK-NODOT-LABEL: sdot_no_bin_op_narrow:
448+
; CHECK-NODOT: // %bb.0:
449+
; CHECK-NODOT-NEXT: sshll v1.8h, v1.8b, #0
450+
; CHECK-NODOT-NEXT: // kill: def $d0 killed $d0 def $q0
451+
; CHECK-NODOT-NEXT: sshll v2.4s, v1.4h, #0
452+
; CHECK-NODOT-NEXT: sshll2 v3.4s, v1.8h, #0
453+
; CHECK-NODOT-NEXT: ext v4.16b, v1.16b, v1.16b, #8
454+
; CHECK-NODOT-NEXT: saddw v0.4s, v0.4s, v1.4h
455+
; CHECK-NODOT-NEXT: ext v3.16b, v3.16b, v3.16b, #8
456+
; CHECK-NODOT-NEXT: ext v2.16b, v2.16b, v2.16b, #8
457+
; CHECK-NODOT-NEXT: add v0.2s, v3.2s, v0.2s
458+
; CHECK-NODOT-NEXT: saddw v1.4s, v2.4s, v4.4h
459+
; CHECK-NODOT-NEXT: add v0.2s, v1.2s, v0.2s
460+
; CHECK-NODOT-NEXT: ret
461+
%a.wide = sext <8 x i8> %a to <8 x i32>
462+
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v2i32.v8i32(<2 x i32> %acc, <8 x i32> %a.wide)
463+
ret <2 x i32> %partial.reduce
464+
}
465+
466+
define <4 x i64> @udot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
467+
; CHECK-DOT-LABEL: udot_no_bin_op_8to64:
468+
; CHECK-DOT: // %bb.0:
469+
; CHECK-DOT-NEXT: movi v3.16b, #1
470+
; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
471+
; CHECK-DOT-NEXT: udot v4.4s, v2.16b, v3.16b
472+
; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
473+
; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
474+
; CHECK-DOT-NEXT: ret
475+
;
476+
; CHECK-NODOT-LABEL: udot_no_bin_op_8to64:
477+
; CHECK-NODOT: // %bb.0:
478+
; CHECK-NODOT-NEXT: ushll v3.8h, v2.8b, #0
479+
; CHECK-NODOT-NEXT: ushll2 v2.8h, v2.16b, #0
480+
; CHECK-NODOT-NEXT: ushll v4.4s, v3.4h, #0
481+
; CHECK-NODOT-NEXT: ushll v5.4s, v2.4h, #0
482+
; CHECK-NODOT-NEXT: ushll2 v3.4s, v3.8h, #0
483+
; CHECK-NODOT-NEXT: ushll2 v2.4s, v2.8h, #0
484+
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v4.4s
485+
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v4.2s
486+
; CHECK-NODOT-NEXT: uaddl2 v4.2d, v3.4s, v5.4s
487+
; CHECK-NODOT-NEXT: uaddl v3.2d, v3.2s, v5.2s
488+
; CHECK-NODOT-NEXT: uaddw2 v1.2d, v1.2d, v2.4s
489+
; CHECK-NODOT-NEXT: uaddw v0.2d, v0.2d, v2.2s
490+
; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
491+
; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
492+
; CHECK-NODOT-NEXT: ret
493+
%a.wide = zext <16 x i8> %a to <16 x i64>
494+
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
495+
ret <4 x i64> %partial.reduce
496+
}
497+
498+
define <4 x i64> @sdot_no_bin_op_8to64(<4 x i64> %acc, <16 x i8> %a){
499+
; CHECK-DOT-LABEL: sdot_no_bin_op_8to64:
500+
; CHECK-DOT: // %bb.0:
501+
; CHECK-DOT-NEXT: movi v3.16b, #1
502+
; CHECK-DOT-NEXT: movi v4.2d, #0000000000000000
503+
; CHECK-DOT-NEXT: sdot v4.4s, v2.16b, v3.16b
504+
; CHECK-DOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
505+
; CHECK-DOT-NEXT: saddw v0.2d, v0.2d, v4.2s
506+
; CHECK-DOT-NEXT: ret
507+
;
508+
; CHECK-NODOT-LABEL: sdot_no_bin_op_8to64:
509+
; CHECK-NODOT: // %bb.0:
510+
; CHECK-NODOT-NEXT: sshll v3.8h, v2.8b, #0
511+
; CHECK-NODOT-NEXT: sshll2 v2.8h, v2.16b, #0
512+
; CHECK-NODOT-NEXT: sshll v4.4s, v3.4h, #0
513+
; CHECK-NODOT-NEXT: sshll v5.4s, v2.4h, #0
514+
; CHECK-NODOT-NEXT: sshll2 v3.4s, v3.8h, #0
515+
; CHECK-NODOT-NEXT: sshll2 v2.4s, v2.8h, #0
516+
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v4.4s
517+
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v4.2s
518+
; CHECK-NODOT-NEXT: saddl2 v4.2d, v3.4s, v5.4s
519+
; CHECK-NODOT-NEXT: saddl v3.2d, v3.2s, v5.2s
520+
; CHECK-NODOT-NEXT: saddw2 v1.2d, v1.2d, v2.4s
521+
; CHECK-NODOT-NEXT: saddw v0.2d, v0.2d, v2.2s
522+
; CHECK-NODOT-NEXT: add v1.2d, v4.2d, v1.2d
523+
; CHECK-NODOT-NEXT: add v0.2d, v3.2d, v0.2d
524+
; CHECK-NODOT-NEXT: ret
525+
%a.wide = sext <16 x i8> %a to <16 x i64>
526+
%partial.reduce = tail call <4 x i64> @llvm.experimental.vector.partial.reduce.add.v4i64.v16i64(<4 x i64> %acc, <16 x i64> %a.wide)
527+
ret <4 x i64> %partial.reduce
528+
}
529+
370530
define <4 x i32> @not_udot(<4 x i32> %acc, <8 x i8> %u, <8 x i8> %s) #0{
371531
; CHECK-LABEL: not_udot:
372532
; CHECK: // %bb.0:

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,84 @@ entry:
316316
ret <vscale x 4 x i64> %partial.reduce
317317
}
318318

319+
define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
320+
; CHECK-LABEL: udot_no_bin_op:
321+
; CHECK: // %bb.0:
322+
; CHECK-NEXT: mov z2.b, #1 // =0x1
323+
; CHECK-NEXT: udot z0.s, z1.b, z2.b
324+
; CHECK-NEXT: ret
325+
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
326+
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
327+
ret <vscale x 4 x i32> %partial.reduce
328+
}
329+
330+
define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a){
331+
; CHECK-LABEL: sdot_no_bin_op:
332+
; CHECK: // %bb.0:
333+
; CHECK-NEXT: mov z2.b, #1 // =0x1
334+
; CHECK-NEXT: sdot z0.s, z1.b, z2.b
335+
; CHECK-NEXT: ret
336+
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
337+
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
338+
ret <vscale x 4 x i32> %partial.reduce
339+
}
340+
341+
define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
342+
; CHECK-LABEL: udot_no_bin_op_wide:
343+
; CHECK: // %bb.0: // %entry
344+
; CHECK-NEXT: mov z2.h, #1 // =0x1
345+
; CHECK-NEXT: udot z0.d, z1.h, z2.h
346+
; CHECK-NEXT: ret
347+
entry:
348+
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
349+
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
350+
ret <vscale x 2 x i64> %partial.reduce
351+
}
352+
353+
define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i16> %b){
354+
; CHECK-LABEL: sdot_no_bin_op_wide:
355+
; CHECK: // %bb.0: // %entry
356+
; CHECK-NEXT: mov z2.h, #1 // =0x1
357+
; CHECK-NEXT: sdot z0.d, z1.h, z2.h
358+
; CHECK-NEXT: ret
359+
entry:
360+
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
361+
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %a.wide)
362+
ret <vscale x 2 x i64> %partial.reduce
363+
}
364+
365+
define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
366+
; CHECK-LABEL: udot_no_bin_op_8to64:
367+
; CHECK: // %bb.0:
368+
; CHECK-NEXT: mov z3.b, #1 // =0x1
369+
; CHECK-NEXT: mov z4.s, #0 // =0x0
370+
; CHECK-NEXT: udot z4.s, z2.b, z3.b
371+
; CHECK-NEXT: sunpklo z2.d, z4.s
372+
; CHECK-NEXT: sunpkhi z3.d, z4.s
373+
; CHECK-NEXT: add z0.d, z0.d, z2.d
374+
; CHECK-NEXT: add z1.d, z1.d, z3.d
375+
; CHECK-NEXT: ret
376+
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
377+
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
378+
ret <vscale x 4 x i64> %partial.reduce
379+
}
380+
381+
define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a){
382+
; CHECK-LABEL: sdot_no_bin_op_8to64:
383+
; CHECK: // %bb.0:
384+
; CHECK-NEXT: mov z3.b, #1 // =0x1
385+
; CHECK-NEXT: mov z4.s, #0 // =0x0
386+
; CHECK-NEXT: sdot z4.s, z2.b, z3.b
387+
; CHECK-NEXT: sunpklo z2.d, z4.s
388+
; CHECK-NEXT: sunpkhi z3.d, z4.s
389+
; CHECK-NEXT: add z0.d, z0.d, z2.d
390+
; CHECK-NEXT: add z1.d, z1.d, z3.d
391+
; CHECK-NEXT: ret
392+
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
393+
%partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
394+
ret <vscale x 4 x i64> %partial.reduce
395+
}
396+
319397
define <vscale x 4 x i32> @not_udot(<vscale x 4 x i32> %acc, <vscale x 8 x i8> %a, <vscale x 8 x i8> %b) {
320398
; CHECK-LABEL: not_udot:
321399
; CHECK: // %bb.0: // %entry

0 commit comments

Comments
 (0)