@@ -618,6 +618,8 @@ namespace {
618
618
SDValue CombineConsecutiveLoads(SDNode *N, EVT VT);
619
619
SDValue foldBitcastedFPLogic(SDNode *N, SelectionDAG &DAG,
620
620
const TargetLowering &TLI);
621
+ SDValue foldPartialReduceMLAMulOp(SDNode *N);
622
+ SDValue foldPartialReduceMLANoMulOp(SDNode *N);
621
623
622
624
SDValue CombineExtLoad(SDNode *N);
623
625
SDValue CombineZExtLogicopShiftLoad(SDNode *N);
@@ -12612,13 +12614,21 @@ SDValue DAGCombiner::visitMHISTOGRAM(SDNode *N) {
12612
12614
return SDValue();
12613
12615
}
12614
12616
12617
+ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12618
+ if (SDValue Res = foldPartialReduceMLAMulOp(N))
12619
+ return Res;
12620
+ if (SDValue Res = foldPartialReduceMLANoMulOp(N))
12621
+ return Res;
12622
+ return SDValue();
12623
+ }
12624
+
12615
12625
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(ZEXT(LHSExtOp), ZEXT(RHSExtOp)),
12616
12626
// Splat(1)) into
12617
12627
// PARTIAL_REDUCE_UMLA(Acc, LHSExtOp, RHSExtOp).
12618
12628
// Makes PARTIAL_REDUCE_*MLA(Acc, MUL(SEXT(LHSExtOp), SEXT(RHSExtOp)),
12619
12629
// Splat(1)) into
12620
12630
// PARTIAL_REDUCE_SMLA(Acc, LHSExtOp, RHSExtOp).
12621
- SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA (SDNode *N) {
12631
+ SDValue DAGCombiner::foldPartialReduceMLAMulOp (SDNode *N) {
12622
12632
SDLoc DL(N);
12623
12633
12624
12634
SDValue Acc = N->getOperand(0);
@@ -12669,6 +12679,48 @@ SDValue DAGCombiner::visitPARTIAL_REDUCE_MLA(SDNode *N) {
12669
12679
RHSExtOp);
12670
12680
}
12671
12681
12682
+ // Makes PARTIAL_REDUCE_*MLA(Acc, ZEXT(UnextOp1), Splat(1)) into
12683
+ // PARTIAL_REDUCE_UMLA(Acc, Op, TRUNC(Splat(1)))
12684
+ // Makes PARTIAL_REDUCE_*MLA(Acc, SEXT(UnextOp1), Splat(1)) into
12685
+ // PARTIAL_REDUCE_SMLA(Acc, Op, TRUNC(Splat(1)))
12686
+ SDValue DAGCombiner::foldPartialReduceMLANoMulOp(SDNode *N) {
12687
+ SDLoc DL(N);
12688
+ SDValue Acc = N->getOperand(0);
12689
+ SDValue Op1 = N->getOperand(1);
12690
+ SDValue Op2 = N->getOperand(2);
12691
+
12692
+ APInt ConstantOne;
12693
+ if (!ISD::isConstantSplatVector(Op2.getNode(), ConstantOne) ||
12694
+ !ConstantOne.isOne())
12695
+ return SDValue();
12696
+
12697
+ unsigned Op1Opcode = Op1.getOpcode();
12698
+ if (!ISD::isExtOpcode(Op1Opcode))
12699
+ return SDValue();
12700
+
12701
+ SDValue UnextOp1 = Op1.getOperand(0);
12702
+ EVT UnextOp1VT = UnextOp1.getValueType();
12703
+
12704
+ if (!TLI.isPartialReduceMLALegalOrCustom(N->getValueType(0), UnextOp1VT))
12705
+ return SDValue();
12706
+
12707
+ SDValue TruncOp2 = DAG.getNode(ISD::TRUNCATE, DL, UnextOp1VT, Op2);
12708
+
12709
+ bool Op1IsSigned = Op1Opcode == ISD::SIGN_EXTEND;
12710
+
12711
+ bool NodeIsSigned = N->getOpcode() == ISD::PARTIAL_REDUCE_SMLA;
12712
+ EVT AccElemVT = Acc.getValueType().getVectorElementType();
12713
+ if (Op1IsSigned != NodeIsSigned &&
12714
+ (Op1.getValueType().getVectorElementType() != AccElemVT ||
12715
+ Op2.getValueType().getVectorElementType() != AccElemVT))
12716
+ return SDValue();
12717
+
12718
+ unsigned NewOpcode =
12719
+ Op1IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
12720
+ return DAG.getNode(NewOpcode, DL, N->getValueType(0), Acc, UnextOp1,
12721
+ TruncOp2);
12722
+ }
12723
+
12672
12724
SDValue DAGCombiner::visitVP_STRIDED_LOAD(SDNode *N) {
12673
12725
auto *SLD = cast<VPStridedLoadSDNode>(N);
12674
12726
EVT EltVT = SLD->getValueType(0).getVectorElementType();
0 commit comments