
Commit 7e1b69d

JamesChesterman authored and NickGuy-Arm committed
[DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op
Generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA to convert:

PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1))
  into PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1)))

and

PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1))
  into PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))).
1 parent a1f369e commit 7e1b69d
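
For a concrete picture of the pattern being matched, here is a minimal sketch of the IR exercised by udot_no_bin_op in the SVE test updated below: an extend feeding the partial-reduction intrinsic directly, with no multiply. The function signature and trailing ret are reconstructed from the test body shown in the diff; only the two middle lines appear there verbatim.

  define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a) {
    %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
    %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
    ret <vscale x 4 x i32> %partial.reduce
  }

This initially lowers to PARTIAL_REDUCE_UMLA(Acc, ZEXT(%a), Splat(1)); the new combine folds the explicit zext into the node's own extension semantics, rewriting it as PARTIAL_REDUCE_UMLA(Acc, %a, TRUNC(Splat(1))), which AArch64 then selects as a single udot against an all-ones vector, as the updated CHECK lines show.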

2 files changed: +73 -83 lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 53 additions & 1 deletion
@@ -618,6 +618,8 @@ namespace {
     SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
     SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
                                  const TargetLowering &TLI);
+    SDValue foldPartialReduceMLAMulOp(SDNode *N);
+    SDValue foldPartialReduceMLANoMulOp(SDNode *N);
 
     SDValue CombineExtLoad(SDNode *N);
     SDValue CombineZExtLogicopShiftLoad(SDNode *N);

@@ -12612,13 +12614,21 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
   return SDValue();
 }
 
+SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+  if (SDValue Res = foldPartialReduceMLAMulOp(N))
+    return Res;
+  if (SDValue Res = foldPartialReduceMLANoMulOp(N))
+    return Res;
+  return SDValue();
+}
+
 // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
 // Splat(1)) into
 // PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
 // Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
 // Splat(1)) into
 // PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
-SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
+SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
   SDLoc DL(N);
 
   SDValue Acc = N->getOperand(0);

@@ -12669,6 +12679,48 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
                      RHSExtOp);
 }
 
+// Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
+// PARTIAL_REDUCE_UMLA(Acc, Op, TRUNC(Splat(1)))
+// Makes PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
+// PARTIAL_REDUCE_SMLA(Acc, Op, TRUNC(Splat(1)))
+SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
+  SDLoc DL(N);
+  SDValue Acc = N->getOperand(0);
+  SDValue Op1 = N->getOperand(1);
+  SDValue Op2 = N->getOperand(2);
+
+  APInt ConstantOne;
+  if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
+      !ConstantOne.isOne())
+    return SDValue();
+
+  unsigned Op1Opcode = Op1.getOpcode();
+  if (!ISD::isExtOpcode(Op1Opcode))
+    return SDValue();
+
+  SDValue UnextOp1 = Op1.getOperand(0);
+  EVT UnextOp1VT = UnextOp1.getValueType();
+
+  if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
+    return SDValue();
+
+  SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
+
+  bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
+
+  bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
+  EVT AccElemVT = Acc.getValueType().getVectorElementType();
+  if (Op1IsSigned != NodeIsSigned &&
+      (Op1.getValueType().getVectorElementType() != AccElemVT ||
+       Op2.getValueType().getVectorElementType() != AccElemVT))
+    return SDValue();
+
+  unsigned NewOpcode =
+      Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
+  return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
+                     TruncOp2);
+}
+
 SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
   auto *SLD = cast<VPStridedLoadSDNode>(N);
   EVT EltVT = SLD->getValueType(0).getVectorElementType();
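
The same fold also covers the wider test cases below, where the partial reduction is not a single legal operation. A hedged sketch of the udot_no_bin_op_8to64 input (signature and ret again reconstructed from the test body in the diff):

  define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8> %a) {
    %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
    %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
    ret <vscale x 4 x i64> %partial.reduce
  }

Because the combine is gated on TLI.isPartialReduceMLALegalOrCustom for the accumulator and un-extended operand types, it still fires here: the updated CHECK lines show the input bytes and the splat-of-one vector each being unpacked to halfwords once and fed to two udot instructions, instead of the long uunpklo/uunpkhi and add chain emitted before.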

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 20 additions & 82 deletions
@@ -620,16 +620,8 @@ define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
 ; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
-; CHECK-NEWLOWERING-NEXT: add z1.s, z2.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z4.s, z0.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: udot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT: ret
   %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)

@@ -645,16 +637,8 @@ define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
 ; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.h, z1.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z3.s
-; CHECK-NEWLOWERING-NEXT: add z1.s, z2.s, z1.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z4.s, z0.s
-; CHECK-NEWLOWERING-NEXT: add z0.s, z1.s, z0.s
+; CHECK-NEWLOWERING-NEXT: mov z2.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sdot z0.s, z1.b, z2.b
 ; CHECK-NEWLOWERING-NEXT: ret
   %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
   %partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)

@@ -670,16 +654,8 @@ define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_wide:
 ; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: udot z0.d, z1.h, z2.h
 ; CHECK-NEWLOWERING-NEXT: ret
 entry:
   %a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>

@@ -696,16 +672,8 @@ define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_wide:
 ; CHECK-NEWLOWERING: // %bb.0: // %entry
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z1.s, z1.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z1.d, z1.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z3.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z1.d, z0.d
+; CHECK-NEWLOWERING-NEXT: mov z2.h, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sdot z0.d, z1.h, z2.h
 ; CHECK-NEWLOWERING-NEXT: ret
 entry:
   %a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>

@@ -727,28 +695,13 @@ define <vscale x 4 x i64> @udot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_8to64:
 ; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: uunpklo z5.h, z2.b
 ; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpklo z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z5.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z25.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: uunpklo z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z4.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z6.d
-; CHECK-NEWLOWERING-NEXT: add z4.d, z25.d, z24.d
-; CHECK-NEWLOWERING-NEXT: add z2.d, z3.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z5.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z7.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT: uunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-NEXT: uunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT: udot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT: udot z1.d, z2.h, z3.h
 ; CHECK-NEWLOWERING-NEXT: ret
   %a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)

@@ -769,28 +722,13 @@ define <vscale x 4 x i64> @sdot_no_bin_op_8to64(<vscale x 4 x i64> %acc, <vscale
 ;
 ; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_8to64:
 ; CHECK-NEWLOWERING: // %bb.0:
-; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z2.b
+; CHECK-NEWLOWERING-NEXT: mov z3.b, #1 // =0x1
+; CHECK-NEWLOWERING-NEXT: sunpklo z5.h, z2.b
 ; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z2.b
-; CHECK-NEWLOWERING-NEXT: sunpklo z4.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpklo z2.s, z2.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
-; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z4.d, z4.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z5.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z24.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z25.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
-; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
-; CHECK-NEWLOWERING-NEXT: sunpklo z5.d, z5.s
-; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z4.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z1.d, z6.d
-; CHECK-NEWLOWERING-NEXT: add z4.d, z25.d, z24.d
-; CHECK-NEWLOWERING-NEXT: add z2.d, z3.d, z2.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z5.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z7.d, z1.d
-; CHECK-NEWLOWERING-NEXT: add z0.d, z4.d, z0.d
-; CHECK-NEWLOWERING-NEXT: add z1.d, z2.d, z1.d
+; CHECK-NEWLOWERING-NEXT: sunpklo z4.h, z3.b
+; CHECK-NEWLOWERING-NEXT: sunpkhi z3.h, z3.b
+; CHECK-NEWLOWERING-NEXT: sdot z0.d, z5.h, z4.h
+; CHECK-NEWLOWERING-NEXT: sdot z1.d, z2.h, z3.h
 ; CHECK-NEWLOWERING-NEXT: ret
   %a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
   %partial.reduce = tail call <vscale x 4 x i64> @llvm.experimental.vector.partial.reduce.add.nxv4i64.nxv16i64(<vscale x 4 x i64> %acc, <vscale x 16 x i64> %a.ext)
