@@ -16619,6 +16619,25 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
16619
16619
}
16620
16620
} // End anonymous namespace.
16621
16621
16622
+ static SDValue simplifyOp_VL(SDNode *N) {
16623
+ // TODO: Extend this to other binops using generic identity logic
16624
+ assert(N->getOpcode() == RISCVISD::ADD_VL);
16625
+ SDValue A = N->getOperand(0);
16626
+ SDValue B = N->getOperand(1);
16627
+ SDValue Passthru = N->getOperand(2);
16628
+ if (!Passthru.isUndef())
16629
+ // TODO:This could be a vmerge instead
16630
+ return SDValue();
16631
+ ;
16632
+ if (ISD::isConstantSplatVectorAllZeros(B.getNode()))
16633
+ return A;
16634
+ // Peek through fixed to scalable
16635
+ if (B.getOpcode() == ISD::INSERT_SUBVECTOR && B.getOperand(0).isUndef() &&
16636
+ ISD::isConstantSplatVectorAllZeros(B.getOperand(1).getNode()))
16637
+ return A;
16638
+ return SDValue();
16639
+ }
16640
+
16622
16641
/// Combine a binary or FMA operation to its equivalent VW or VW_W form.
16623
16642
/// The supported combines are:
16624
16643
/// add | add_vl | or disjoint | or_vl disjoint -> vwadd(u) | vwadd(u)_w
@@ -18515,20 +18534,10 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
18515
18534
return SDValue();
18516
18535
18517
18536
SDValue AccumOp = DotOp.getOperand(2);
18518
- bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode());
18519
- // Peek through fixed to scalable
18520
- if (!IsNullAdd && AccumOp.getOpcode() == ISD::INSERT_SUBVECTOR &&
18521
- AccumOp.getOperand(0).isUndef())
18522
- IsNullAdd =
18523
- ISD::isConstantSplatVectorAllZeros(AccumOp.getOperand(1).getNode());
18524
-
18525
18537
SDLoc DL(N);
18526
18538
EVT VT = N->getValueType(0);
18527
- // The manual constant folding is required, this case is not constant folded
18528
- // or combined.
18529
- if (!IsNullAdd)
18530
- Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, AccumOp, Addend,
18531
- DAG.getUNDEF(VT), AddMask, AddVL);
18539
+ Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, Addend, AccumOp,
18540
+ DAG.getUNDEF(VT), AddMask, AddVL);
18532
18541
18533
18542
SDValue Ops[] = {DotOp.getOperand(0), DotOp.getOperand(1), Addend,
18534
18543
DotOp.getOperand(3), DotOp->getOperand(4)};
@@ -19657,6 +19666,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
19657
19666
break;
19658
19667
}
19659
19668
case RISCVISD::ADD_VL:
19669
+ if (SDValue V = simplifyOp_VL(N))
19670
+ return V;
19660
19671
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
19661
19672
return V;
19662
19673
if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
0 commit comments