Skip to content

Commit cc4fb09

Browse files
Address review comments. This is a rebase on the NFC variable-renaming
patch that was added earlier.
1 parent 0619820 commit cc4fb09

File tree

3 files changed

+249
-23
lines changed

3 files changed

+249
-23
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -21953,36 +21953,46 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2195321953
SDLoc DL(N);
2195421954

2195521955
SDValue Op2 = N->getOperand(2);
21956-
if (Op2->getOpcode() != ISD::MUL ||
21957-
!ISD::isExtOpcode(Op2->getOperand(0)->getOpcode()) ||
21958-
!ISD::isExtOpcode(Op2->getOperand(1)->getOpcode()))
21959-
return SDValue();
21956+
unsigned Op2Opcode = Op2->getOpcode();
21957+
SDValue MulOpLHS, MulOpRHS;
21958+
bool MulOpLHSIsSigned, MulOpRHSIsSigned;
21959+
if (ISD::isExtOpcode(Op2Opcode)) {
21960+
MulOpLHSIsSigned = MulOpRHSIsSigned = (Op2Opcode == ISD::SIGN_EXTEND);
21961+
MulOpLHS = Op2->getOperand(0);
21962+
MulOpRHS = DAG.getConstant(1, DL, MulOpLHS.getValueType());
21963+
} else if (Op2Opcode == ISD::MUL) {
21964+
SDValue ExtMulOpLHS = Op2->getOperand(0);
21965+
SDValue ExtMulOpRHS = Op2->getOperand(1);
21966+
21967+
unsigned ExtMulOpLHSOpcode = ExtMulOpLHS->getOpcode();
21968+
unsigned ExtMulOpRHSOpcode = ExtMulOpRHS->getOpcode();
21969+
if (!ISD::isExtOpcode(ExtMulOpLHSOpcode) ||
21970+
!ISD::isExtOpcode(ExtMulOpRHSOpcode))
21971+
return SDValue();
2196021972

21961-
SDValue Acc = N->getOperand(1);
21962-
SDValue Mul = N->getOperand(2);
21963-
SDValue ExtMulOpLHS = Mul->getOperand(0);
21964-
SDValue ExtMulOpRHS = Mul->getOperand(1);
21973+
MulOpLHSIsSigned = ExtMulOpLHSOpcode == ISD::SIGN_EXTEND;
21974+
MulOpRHSIsSigned = ExtMulOpRHSOpcode == ISD::SIGN_EXTEND;
2196521975

21966-
SDValue MulOpLHS = ExtMulOpLHS->getOperand(0);
21967-
SDValue MulOpRHS = ExtMulOpRHS->getOperand(0);
21968-
if (MulOpLHS.getValueType() != MulOpRHS.getValueType())
21976+
MulOpLHS = ExtMulOpLHS->getOperand(0);
21977+
MulOpRHS = ExtMulOpRHS->getOperand(0);
21978+
} else
2196921979
return SDValue();
2197021980

21981+
SDValue Acc = N->getOperand(1);
2197121982
EVT ReducedVT = N->getValueType(0);
2197221983
EVT MulSrcVT = MulOpLHS.getValueType();
2197321984

2197421985
// Dot products operate on chunks of four elements so there must be four times
2197521986
// as many elements in the wide type
21976-
if (!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
21977-
!(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
21978-
!(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
21979-
!(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
21980-
!(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
21981-
!(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8))
21987+
if ((!(ReducedVT == MVT::nxv4i64 && MulSrcVT == MVT::nxv16i8) &&
21988+
!(ReducedVT == MVT::nxv4i32 && MulSrcVT == MVT::nxv16i8) &&
21989+
!(ReducedVT == MVT::nxv2i64 && MulSrcVT == MVT::nxv8i16) &&
21990+
!(ReducedVT == MVT::v4i64 && MulSrcVT == MVT::v16i8) &&
21991+
!(ReducedVT == MVT::v4i32 && MulSrcVT == MVT::v16i8) &&
21992+
!(ReducedVT == MVT::v2i32 && MulSrcVT == MVT::v8i8)) ||
21993+
(MulOpLHS.getValueType() != MulOpRHS.getValueType()))
2198221994
return SDValue();
2198321995

21984-
bool MulOpLHSIsSigned = ExtMulOpLHS->getOpcode() == ISD::SIGN_EXTEND;
21985-
bool MulOpRHSIsSigned = ExtMulOpRHS->getOpcode() == ISD::SIGN_EXTEND;
2198621996
// If the extensions are mixed, we should lower it to a usdot instead
2198721997
unsigned Opcode = 0;
2198821998
if (MulOpLHSIsSigned != MulOpRHSIsSigned) {
@@ -21998,10 +22008,8 @@ SDValue tryLowerPartialReductionToDot(SDNode *N,
2199822008
// USDOT expects the signed operand to be last
2199922009
if (!MulOpRHSIsSigned)
2200022010
std::swap(MulOpLHS, MulOpRHS);
22001-
} else if (MulOpLHSIsSigned)
22002-
Opcode = AArch64ISD::SDOT;
22003-
else
22004-
Opcode = AArch64ISD::UDOT;
22011+
} else
22012+
Opcode = MulOpLHSIsSigned ? AArch64ISD::SDOT : AArch64ISD::UDOT;
2200522013

2200622014
// Partial reduction lowering for (nx)v16i8 to (nx)v4i64 requires an i32 dot
2200722015
// product followed by a zero / sign extension

llvm/test/CodeGen/AArch64/neon-partial-reduce-dot-product.ll

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,3 +558,91 @@ define <2 x i32> @not_udot_narrow(<2 x i32> %acc, <4 x i8> %u, <4 x i8> %s) {
558558
%partial.reduce = tail call <2 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<2 x i32> %acc, <4 x i32> %mult)
559559
ret <2 x i32> %partial.reduce
560560
}
561+
562+
define <2 x i64> @udot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
563+
; CHECK-LABEL: udot_different_types:
564+
; CHECK: // %bb.0: // %entry
565+
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
566+
; CHECK-NEXT: ushll v3.4s, v1.4h, #0
567+
; CHECK-NEXT: ushll2 v1.4s, v1.8h, #0
568+
; CHECK-NEXT: ushll v4.4s, v2.4h, #0
569+
; CHECK-NEXT: ushll2 v2.4s, v2.8h, #0
570+
; CHECK-NEXT: umull v5.2d, v1.2s, v2.2s
571+
; CHECK-NEXT: umlal v0.2d, v3.2s, v4.2s
572+
; CHECK-NEXT: umlal2 v0.2d, v1.4s, v2.4s
573+
; CHECK-NEXT: umlal2 v5.2d, v3.4s, v4.4s
574+
; CHECK-NEXT: add v0.2d, v5.2d, v0.2d
575+
; CHECK-NEXT: ret
576+
entry:
577+
%a.wide = zext <8 x i16> %a to <8 x i64>
578+
%b.wide = zext <8 x i8> %b to <8 x i64>
579+
%mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
580+
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
581+
ret <2 x i64> %partial.reduce
582+
}
583+
584+
define <2 x i64> @sdot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
585+
; CHECK-LABEL: sdot_different_types:
586+
; CHECK: // %bb.0: // %entry
587+
; CHECK-NEXT: sshll v2.8h, v2.8b, #0
588+
; CHECK-NEXT: sshll v3.4s, v1.4h, #0
589+
; CHECK-NEXT: sshll2 v1.4s, v1.8h, #0
590+
; CHECK-NEXT: sshll v4.4s, v2.4h, #0
591+
; CHECK-NEXT: sshll2 v2.4s, v2.8h, #0
592+
; CHECK-NEXT: smull v5.2d, v1.2s, v2.2s
593+
; CHECK-NEXT: smlal v0.2d, v3.2s, v4.2s
594+
; CHECK-NEXT: smlal2 v0.2d, v1.4s, v2.4s
595+
; CHECK-NEXT: smlal2 v5.2d, v3.4s, v4.4s
596+
; CHECK-NEXT: add v0.2d, v5.2d, v0.2d
597+
; CHECK-NEXT: ret
598+
entry:
599+
%a.wide = sext <8 x i16> %a to <8 x i64>
600+
%b.wide = sext <8 x i8> %b to <8 x i64>
601+
%mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
602+
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
603+
ret <2 x i64> %partial.reduce
604+
}
605+
606+
define <2 x i64> @usdot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
607+
; CHECK-LABEL: usdot_different_types:
608+
; CHECK: // %bb.0: // %entry
609+
; CHECK-NEXT: sshll v2.8h, v2.8b, #0
610+
; CHECK-NEXT: ushll v3.4s, v1.4h, #0
611+
; CHECK-NEXT: ushll2 v1.4s, v1.8h, #0
612+
; CHECK-NEXT: sshll v4.4s, v2.4h, #0
613+
; CHECK-NEXT: sshll2 v2.4s, v2.8h, #0
614+
; CHECK-NEXT: smull v5.2d, v1.2s, v2.2s
615+
; CHECK-NEXT: smlal v0.2d, v3.2s, v4.2s
616+
; CHECK-NEXT: smlal2 v0.2d, v1.4s, v2.4s
617+
; CHECK-NEXT: smlal2 v5.2d, v3.4s, v4.4s
618+
; CHECK-NEXT: add v0.2d, v5.2d, v0.2d
619+
; CHECK-NEXT: ret
620+
entry:
621+
%a.wide = zext <8 x i16> %a to <8 x i64>
622+
%b.wide = sext <8 x i8> %b to <8 x i64>
623+
%mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
624+
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
625+
ret <2 x i64> %partial.reduce
626+
}
627+
628+
define <2 x i64> @sudot_different_types(<2 x i64> %acc, <8 x i16> %a, <8 x i8> %b){
629+
; CHECK-LABEL: sudot_different_types:
630+
; CHECK: // %bb.0: // %entry
631+
; CHECK-NEXT: ushll v2.8h, v2.8b, #0
632+
; CHECK-NEXT: sshll v3.4s, v1.4h, #0
633+
; CHECK-NEXT: sshll2 v1.4s, v1.8h, #0
634+
; CHECK-NEXT: ushll v4.4s, v2.4h, #0
635+
; CHECK-NEXT: ushll2 v2.4s, v2.8h, #0
636+
; CHECK-NEXT: smull v5.2d, v1.2s, v2.2s
637+
; CHECK-NEXT: smlal v0.2d, v3.2s, v4.2s
638+
; CHECK-NEXT: smlal2 v0.2d, v1.4s, v2.4s
639+
; CHECK-NEXT: smlal2 v5.2d, v3.4s, v4.4s
640+
; CHECK-NEXT: add v0.2d, v5.2d, v0.2d
641+
; CHECK-NEXT: ret
642+
entry:
643+
%a.wide = sext <8 x i16> %a to <8 x i64>
644+
%b.wide = zext <8 x i8> %b to <8 x i64>
645+
%mult = mul nuw nsw <8 x i64> %a.wide, %b.wide
646+
%partial.reduce = tail call <2 x i64> @llvm.experimental.vector.partial.reduce.add.v2i64.v8i64(<2 x i64> %acc, <8 x i64> %mult)
647+
ret <2 x i64> %partial.reduce
648+
}

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,3 +497,133 @@ entry:
497497
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
498498
ret <vscale x 2 x i64> %partial.reduce
499499
}
500+
501+
define <vscale x 2 x i64> @udot_different_types(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i8> %b){
502+
; CHECK-LABEL: udot_different_types:
503+
; CHECK: // %bb.0: // %entry
504+
; CHECK-NEXT: and z2.h, z2.h, #0xff
505+
; CHECK-NEXT: uunpklo z3.s, z1.h
506+
; CHECK-NEXT: uunpkhi z1.s, z1.h
507+
; CHECK-NEXT: ptrue p0.d
508+
; CHECK-NEXT: uunpklo z4.s, z2.h
509+
; CHECK-NEXT: uunpkhi z2.s, z2.h
510+
; CHECK-NEXT: uunpklo z5.d, z3.s
511+
; CHECK-NEXT: uunpkhi z3.d, z3.s
512+
; CHECK-NEXT: uunpklo z7.d, z1.s
513+
; CHECK-NEXT: uunpkhi z1.d, z1.s
514+
; CHECK-NEXT: uunpklo z6.d, z4.s
515+
; CHECK-NEXT: uunpkhi z4.d, z4.s
516+
; CHECK-NEXT: uunpklo z24.d, z2.s
517+
; CHECK-NEXT: uunpkhi z2.d, z2.s
518+
; CHECK-NEXT: mul z3.d, z3.d, z4.d
519+
; CHECK-NEXT: mla z0.d, p0/m, z5.d, z6.d
520+
; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
521+
; CHECK-NEXT: movprfx z1, z3
522+
; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
523+
; CHECK-NEXT: add z0.d, z1.d, z0.d
524+
; CHECK-NEXT: ret
525+
entry:
526+
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
527+
%b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i64>
528+
%mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
529+
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
530+
ret <vscale x 2 x i64> %partial.reduce
531+
}
532+
533+
define <vscale x 2 x i64> @sdot_different_types(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i8> %b){
534+
; CHECK-LABEL: sdot_different_types:
535+
; CHECK: // %bb.0: // %entry
536+
; CHECK-NEXT: ptrue p0.h
537+
; CHECK-NEXT: sunpklo z3.s, z1.h
538+
; CHECK-NEXT: sunpkhi z1.s, z1.h
539+
; CHECK-NEXT: sxtb z2.h, p0/m, z2.h
540+
; CHECK-NEXT: ptrue p0.d
541+
; CHECK-NEXT: sunpklo z5.d, z3.s
542+
; CHECK-NEXT: sunpkhi z3.d, z3.s
543+
; CHECK-NEXT: sunpklo z7.d, z1.s
544+
; CHECK-NEXT: sunpklo z4.s, z2.h
545+
; CHECK-NEXT: sunpkhi z2.s, z2.h
546+
; CHECK-NEXT: sunpkhi z1.d, z1.s
547+
; CHECK-NEXT: sunpklo z6.d, z4.s
548+
; CHECK-NEXT: sunpkhi z4.d, z4.s
549+
; CHECK-NEXT: sunpklo z24.d, z2.s
550+
; CHECK-NEXT: sunpkhi z2.d, z2.s
551+
; CHECK-NEXT: mul z3.d, z3.d, z4.d
552+
; CHECK-NEXT: mla z0.d, p0/m, z5.d, z6.d
553+
; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
554+
; CHECK-NEXT: movprfx z1, z3
555+
; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
556+
; CHECK-NEXT: add z0.d, z1.d, z0.d
557+
; CHECK-NEXT: ret
558+
entry:
559+
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
560+
%b.wide = sext <vscale x 8 x i8> %b to <vscale x 8 x i64>
561+
%mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
562+
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
563+
ret <vscale x 2 x i64> %partial.reduce
564+
}
565+
566+
define <vscale x 2 x i64> @usdot_different_types(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i8> %b){
567+
; CHECK-LABEL: usdot_different_types:
568+
; CHECK: // %bb.0: // %entry
569+
; CHECK-NEXT: ptrue p0.h
570+
; CHECK-NEXT: uunpklo z3.s, z1.h
571+
; CHECK-NEXT: uunpkhi z1.s, z1.h
572+
; CHECK-NEXT: sxtb z2.h, p0/m, z2.h
573+
; CHECK-NEXT: ptrue p0.d
574+
; CHECK-NEXT: uunpklo z5.d, z3.s
575+
; CHECK-NEXT: uunpkhi z3.d, z3.s
576+
; CHECK-NEXT: uunpklo z7.d, z1.s
577+
; CHECK-NEXT: sunpklo z4.s, z2.h
578+
; CHECK-NEXT: sunpkhi z2.s, z2.h
579+
; CHECK-NEXT: uunpkhi z1.d, z1.s
580+
; CHECK-NEXT: sunpklo z6.d, z4.s
581+
; CHECK-NEXT: sunpkhi z4.d, z4.s
582+
; CHECK-NEXT: sunpklo z24.d, z2.s
583+
; CHECK-NEXT: sunpkhi z2.d, z2.s
584+
; CHECK-NEXT: mul z3.d, z3.d, z4.d
585+
; CHECK-NEXT: mla z0.d, p0/m, z5.d, z6.d
586+
; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
587+
; CHECK-NEXT: movprfx z1, z3
588+
; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
589+
; CHECK-NEXT: add z0.d, z1.d, z0.d
590+
; CHECK-NEXT: ret
591+
entry:
592+
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
593+
%b.wide = sext <vscale x 8 x i8> %b to <vscale x 8 x i64>
594+
%mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
595+
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
596+
ret <vscale x 2 x i64> %partial.reduce
597+
}
598+
599+
define <vscale x 2 x i64> @sudot_different_types(<vscale x 2 x i64> %acc, <vscale x 8 x i16> %a, <vscale x 8 x i8> %b){
600+
; CHECK-LABEL: sudot_different_types:
601+
; CHECK: // %bb.0: // %entry
602+
; CHECK-NEXT: and z2.h, z2.h, #0xff
603+
; CHECK-NEXT: sunpklo z3.s, z1.h
604+
; CHECK-NEXT: sunpkhi z1.s, z1.h
605+
; CHECK-NEXT: ptrue p0.d
606+
; CHECK-NEXT: uunpklo z4.s, z2.h
607+
; CHECK-NEXT: uunpkhi z2.s, z2.h
608+
; CHECK-NEXT: sunpklo z5.d, z3.s
609+
; CHECK-NEXT: sunpkhi z3.d, z3.s
610+
; CHECK-NEXT: sunpklo z7.d, z1.s
611+
; CHECK-NEXT: sunpkhi z1.d, z1.s
612+
; CHECK-NEXT: uunpklo z6.d, z4.s
613+
; CHECK-NEXT: uunpkhi z4.d, z4.s
614+
; CHECK-NEXT: uunpklo z24.d, z2.s
615+
; CHECK-NEXT: uunpkhi z2.d, z2.s
616+
; CHECK-NEXT: mul z3.d, z3.d, z4.d
617+
; CHECK-NEXT: mla z0.d, p0/m, z5.d, z6.d
618+
; CHECK-NEXT: mla z0.d, p0/m, z1.d, z2.d
619+
; CHECK-NEXT: movprfx z1, z3
620+
; CHECK-NEXT: mla z1.d, p0/m, z7.d, z24.d
621+
; CHECK-NEXT: add z0.d, z1.d, z0.d
622+
; CHECK-NEXT: ret
623+
entry:
624+
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>
625+
%b.wide = zext <vscale x 8 x i8> %b to <vscale x 8 x i64>
626+
%mult = mul nuw nsw <vscale x 8 x i64> %a.wide, %b.wide
627+
%partial.reduce = tail call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv8i64(<vscale x 2 x i64> %acc, <vscale x 8 x i64> %mult)
628+
ret <vscale x 2 x i64> %partial.reduce
629+
}

0 commit comments

Comments (0)