Skip to content

Commit adf8348

Browse files
NickGuy-ArmJamesChesterman
authored andcommitted
[DAGCombiner] Add DAG combine for PARTIAL_REDUCE_MLA when no mul op (llvm#131326)
Generic DAG combine for ISD::PARTIAL_REDUCE_U/SMLA to convert: PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_UMLA(Acc, UnextOp1, TRUNC(Splat(1))) and PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into PARTIAL_REDUCE_SMLA(Acc, UnextOp1, TRUNC(Splat(1))). --------- Co-authored-by: James Chesterman <[email protected]>
1 parent 9e3b128 commit adf8348

File tree

2 files changed

+56
-41
lines changed

2 files changed

+56
-41
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,8 @@ namespace {
618618
SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
619619
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
620620
const TargetLowering &TLI);
621+
SDValue foldPartialReduceMLAMulOp(SDNode *N);
622+
SDValue foldPartialReduceAdd(SDNode *N);
621623

622624
SDValue CombineExtLoad(SDNode *N);
623625
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12601,12 +12603,20 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
1260112603
return SDValue();
1260212604
}
1260312605

12606+
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12607+
if (SDValue Res = foldPartialReduceMLAMulOp(N))
12608+
return Res;
12609+
if (SDValue Res = foldPartialReduceAdd(N))
12610+
return Res;
12611+
return SDValue();
12612+
}
12613+
1260412614
// partial_reduce_*mla(acc, mul(ext(a), ext(b)), splat(1))
1260512615
// -> partial_reduce_*mla(acc, a, b)
1260612616
//
1260712617
// partial_reduce_*mla(acc, mul(ext(x), splat(C)), splat(1))
1260812618
// -> partial_reduce_*mla(acc, x, C)
12609-
SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12619+
SDValue DAGCombiner::foldPartialReduceMLAMulOp(SDNode *N) {
1261012620
SDLoc DL(N);
1261112621
auto *Context = DAG.getContext();
1261212622
SDValue Acc = N->getOperand(0);
@@ -12672,6 +12682,43 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
1267212682
RHSExtOp);
1267312683
}
1267412684

12685+
// partial.reduce.umla(acc, zext(op), splat(1))
12686+
// -> partial.reduce.umla(acc, op, splat(trunc(1)))
12687+
// partial.reduce.smla(acc, sext(op), splat(1))
12688+
// -> partial.reduce.smla(acc, op, splat(trunc(1)))
12689+
SDValue DAGCombiner::foldPartialReduceAdd(SDNode *N) {
12690+
SDLoc DL(N);
12691+
SDValue Acc = N->getOperand(0);
12692+
SDValue Op1 = N->getOperand(1);
12693+
SDValue Op2 = N->getOperand(2);
12694+
12695+
APInt ConstantOne;
12696+
if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
12697+
!ConstantOne.isOne())
12698+
return SDValue();
12699+
12700+
unsigned Op1Opcode = Op1.getOpcode();
12701+
if (!ISD::isExtOpcode(Op1Opcode))
12702+
return SDValue();
12703+
12704+
SDValue UnextOp1 = Op1.getOperand(0);
12705+
EVT UnextOp1VT = UnextOp1.getValueType();
12706+
if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
12707+
return SDValue();
12708+
12709+
bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12710+
bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12711+
EVT AccElemVT = Acc.getValueType().getVectorElementType();
12712+
if (Op1IsSigned != NodeIsSigned &&
12713+
Op1.getValueType().getVectorElementType() != AccElemVT)
12714+
return SDValue();
12715+
12716+
unsigned NewOpcode =
12717+
Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12718+
return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
12719+
DAG.getConstant(1, DL, UnextOp1VT));
12720+
}
12721+
1267512722
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
1267612723
auto *SLD = cast<VPStridedLoadSDNode>(N);
1267712724
EVT EltVT = SLD->getValueType(0).getVectorElementType();

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 8 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -516,16 +516,8 @@ define <vscale x 4 x i32> @udot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
516516
;
517517
; CHECK-NEWLOWERING-LABEL: udot_no_bin_op:
518518
; CHECK-NEWLOWERING: // %bb.0:
519-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.h, z1.b
520-
; CHECK-NEWLOWERING-NEXT: uunpklo z1.h, z1.b
521-
; CHECK-NEWLOWERING-NEXT: uunpklo z3.s, z2.h
522-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z1.h
523-
; CHECK-NEWLOWERING-NEXT: uunpklo z1.s, z1.h
524-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
525-
; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
526-
; CHECK-NEWLOWERING-NEXT: add z1.s, z4.s, z3.s
527-
; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
528-
; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
519+
; CHECK-NEWLOWERING-NEXT: mov z2.b, #1 // =0x1
520+
; CHECK-NEWLOWERING-NEXT: udot z0.s, z1.b, z2.b
529521
; CHECK-NEWLOWERING-NEXT: ret
530522
%a.ext = zext <vscale x 16 x i8> %a to <vscale x 16 x i32>
531523
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -541,16 +533,8 @@ define <vscale x 4 x i32> @sdot_no_bin_op(<vscale x 4 x i32> %acc, <vscale x 16
541533
;
542534
; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op:
543535
; CHECK-NEWLOWERING: // %bb.0:
544-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.h, z1.b
545-
; CHECK-NEWLOWERING-NEXT: sunpklo z1.h, z1.b
546-
; CHECK-NEWLOWERING-NEXT: sunpklo z3.s, z2.h
547-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z1.h
548-
; CHECK-NEWLOWERING-NEXT: sunpklo z1.s, z1.h
549-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
550-
; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
551-
; CHECK-NEWLOWERING-NEXT: add z1.s, z4.s, z3.s
552-
; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z1.s
553-
; CHECK-NEWLOWERING-NEXT: add z0.s, z0.s, z2.s
536+
; CHECK-NEWLOWERING-NEXT: mov z2.b, #1 // =0x1
537+
; CHECK-NEWLOWERING-NEXT: sdot z0.s, z1.b, z2.b
554538
; CHECK-NEWLOWERING-NEXT: ret
555539
%a.ext = sext <vscale x 16 x i8> %a to <vscale x 16 x i32>
556540
%partial.reduce = tail call <vscale x 4 x i32> @llvm.experimental.vector.partial.reduce.add.nxv4i32.nxv16i32(<vscale x 4 x i32> %acc, <vscale x 16 x i32> %a.ext)
@@ -566,16 +550,8 @@ define <vscale x 2 x i64> @udot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
566550
;
567551
; CHECK-NEWLOWERING-LABEL: udot_no_bin_op_wide:
568552
; CHECK-NEWLOWERING: // %bb.0: // %entry
569-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z1.h
570-
; CHECK-NEWLOWERING-NEXT: uunpklo z1.s, z1.h
571-
; CHECK-NEWLOWERING-NEXT: uunpklo z3.d, z2.s
572-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z1.s
573-
; CHECK-NEWLOWERING-NEXT: uunpklo z1.d, z1.s
574-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
575-
; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
576-
; CHECK-NEWLOWERING-NEXT: add z1.d, z4.d, z3.d
577-
; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
578-
; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
553+
; CHECK-NEWLOWERING-NEXT: mov z2.h, #1 // =0x1
554+
; CHECK-NEWLOWERING-NEXT: udot z0.d, z1.h, z2.h
579555
; CHECK-NEWLOWERING-NEXT: ret
580556
entry:
581557
%a.wide = zext <vscale x 8 x i16> %a to <vscale x 8 x i64>
@@ -592,16 +568,8 @@ define <vscale x 2 x i64> @sdot_no_bin_op_wide(<vscale x 2 x i64> %acc, <vscale
592568
;
593569
; CHECK-NEWLOWERING-LABEL: sdot_no_bin_op_wide:
594570
; CHECK-NEWLOWERING: // %bb.0: // %entry
595-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z1.h
596-
; CHECK-NEWLOWERING-NEXT: sunpklo z1.s, z1.h
597-
; CHECK-NEWLOWERING-NEXT: sunpklo z3.d, z2.s
598-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z1.s
599-
; CHECK-NEWLOWERING-NEXT: sunpklo z1.d, z1.s
600-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
601-
; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
602-
; CHECK-NEWLOWERING-NEXT: add z1.d, z4.d, z3.d
603-
; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z1.d
604-
; CHECK-NEWLOWERING-NEXT: add z0.d, z0.d, z2.d
571+
; CHECK-NEWLOWERING-NEXT: mov z2.h, #1 // =0x1
572+
; CHECK-NEWLOWERING-NEXT: sdot z0.d, z1.h, z2.h
605573
; CHECK-NEWLOWERING-NEXT: ret
606574
entry:
607575
%a.wide = sext <vscale x 8 x i16> %a to <vscale x 8 x i64>

0 commit comments

Comments
 (0)