Skip to content

Commit 1ffdf53

Browse files
authored
[RISCV] Introduce add_vl combine for identity operand (#139742)
This is mostly a refactor of the recently added zvqdotq accumulation path so that I can try merging that with the vwmacc codepaths.
1 parent 91ea494 commit 1ffdf53

File tree

1 file changed

+23
-12
lines changed

1 file changed

+23
-12
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16619,6 +16619,25 @@ NodeExtensionHelper::getSupportedFoldings(const SDNode *Root) {
1661916619
}
1662016620
} // End anonymous namespace.
1662116621

16622+
static SDValue simplifyOp_VL(SDNode *N) {
16623+
// TODO: Extend this to other binops using generic identity logic
16624+
assert(N->getOpcode() == RISCVISD::ADD_VL);
16625+
SDValue A = N->getOperand(0);
16626+
SDValue B = N->getOperand(1);
16627+
SDValue Passthru = N->getOperand(2);
16628+
if (!Passthru.isUndef())
16629+
// TODO:This could be a vmerge instead
16630+
return SDValue();
16631+
;
16632+
if (ISD::isConstantSplatVectorAllZeros(B.getNode()))
16633+
return A;
16634+
// Peek through fixed to scalable
16635+
if (B.getOpcode() == ISD::INSERT_SUBVECTOR && B.getOperand(0).isUndef() &&
16636+
ISD::isConstantSplatVectorAllZeros(B.getOperand(1).getNode()))
16637+
return A;
16638+
return SDValue();
16639+
}
16640+
1662216641
/// Combine a binary or FMA operation to its equivalent VW or VW_W form.
1662316642
/// The supported combines are:
1662416643
/// add | add_vl | or disjoint | or_vl disjoint -> vwadd(u) | vwadd(u)_w
@@ -18515,20 +18534,10 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
1851518534
return SDValue();
1851618535

1851718536
SDValue AccumOp = DotOp.getOperand(2);
18518-
bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode());
18519-
// Peek through fixed to scalable
18520-
if (!IsNullAdd && AccumOp.getOpcode() == ISD::INSERT_SUBVECTOR &&
18521-
AccumOp.getOperand(0).isUndef())
18522-
IsNullAdd =
18523-
ISD::isConstantSplatVectorAllZeros(AccumOp.getOperand(1).getNode());
18524-
1852518537
SDLoc DL(N);
1852618538
EVT VT = N->getValueType(0);
18527-
// The manual constant folding is required, this case is not constant folded
18528-
// or combined.
18529-
if (!IsNullAdd)
18530-
Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, AccumOp, Addend,
18531-
DAG.getUNDEF(VT), AddMask, AddVL);
18539+
Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, Addend, AccumOp,
18540+
DAG.getUNDEF(VT), AddMask, AddVL);
1853218541

1853318542
SDValue Ops[] = {DotOp.getOperand(0), DotOp.getOperand(1), Addend,
1853418543
DotOp.getOperand(3), DotOp->getOperand(4)};
@@ -19657,6 +19666,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1965719666
break;
1965819667
}
1965919668
case RISCVISD::ADD_VL:
19669+
if (SDValue V = simplifyOp_VL(N))
19670+
return V;
1966019671
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
1966119672
return V;
1966219673
if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))

0 commit comments

Comments
 (0)