Skip to content

Commit 68995b1

Browse files
authored
[RISCV] Support scalable vectors for the zvqdotq lowering paths (#140922)
This was an oversight in the original patch series. Without this change, the newly added tests fail assertions.
1 parent f65b35d commit 68995b1

File tree

2 files changed

+610
-15
lines changed

2 files changed

+610
-15
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18177,17 +18177,20 @@ static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
1817718177
assert(VT == Op1.getSimpleValueType() &&
1817818178
VT.getVectorElementType() == MVT::i32);
1817918179

18180-
assert(VT.isFixedLengthVector());
18181-
MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
18182-
SDValue Passthru = convertToScalableVector(
18183-
ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
18184-
Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
18185-
Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
18186-
18180+
SDValue Passthru = DAG.getConstant(0, DL, VT);
18181+
MVT ContainerVT = VT;
18182+
if (VT.isFixedLengthVector()) {
18183+
ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
18184+
Passthru = convertToScalableVector(ContainerVT, Passthru, DAG, Subtarget);
18185+
Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
18186+
Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
18187+
}
1818718188
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
1818818189
SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
1818918190
{Op0, Op1, Passthru, Mask, VL});
18190-
return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
18191+
if (VT.isFixedLengthVector())
18192+
return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
18193+
return LocalAccum;
1819118194
}
1819218195

1819318196
static MVT getQDOTXResultType(MVT OpVT) {
@@ -18207,7 +18210,7 @@ static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
1820718210
EVT AVT = A.getValueType();
1820818211
EVT BVT = B.getValueType();
1820918212
assert(AVT.getVectorElementType() == BVT.getVectorElementType());
18210-
if (AVT.getVectorNumElements() > BVT.getVectorNumElements()) {
18213+
if (AVT.getVectorMinNumElements() > BVT.getVectorMinNumElements()) {
1821118214
std::swap(A, B);
1821218215
std::swap(AVT, BVT);
1821318216
}
@@ -18641,17 +18644,19 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
1864118644
static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
1864218645
const RISCVSubtarget &Subtarget) {
1864318646

18644-
assert(N->getOpcode() == RISCVISD::ADD_VL);
18647+
assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD);
1864518648

1864618649
if (!N->getValueType(0).isVector())
1864718650
return SDValue();
1864818651

1864918652
SDValue Addend = N->getOperand(0);
1865018653
SDValue DotOp = N->getOperand(1);
1865118654

18652-
SDValue AddPassthruOp = N->getOperand(2);
18653-
if (!AddPassthruOp.isUndef())
18654-
return SDValue();
18655+
if (N->getOpcode() == RISCVISD::ADD_VL) {
18656+
SDValue AddPassthruOp = N->getOperand(2);
18657+
if (!AddPassthruOp.isUndef())
18658+
return SDValue();
18659+
}
1865518660

1865618661
auto IsVqdotqOpc = [](unsigned Opc) {
1865718662
switch (Opc) {
@@ -18670,8 +18675,15 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
1867018675
if (!IsVqdotqOpc(DotOp.getOpcode()))
1867118676
return SDValue();
1867218677

18673-
SDValue AddMask = N->getOperand(3);
18674-
SDValue AddVL = N->getOperand(4);
18678+
auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG,
18679+
const RISCVSubtarget &Subtarget) {
18680+
if (N->getOpcode() == ISD::ADD) {
18681+
SDLoc DL(N);
18682+
return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG,
18683+
Subtarget);
18684+
}
18685+
return std::make_pair(N->getOperand(3), N->getOperand(4));
18686+
}(N, DAG, Subtarget);
1867518687

1867618688
SDValue MulVL = DotOp.getOperand(4);
1867718689
if (AddVL != MulVL)
@@ -19309,6 +19321,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1930919321
return V;
1931019322
if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
1931119323
return V;
19324+
if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
19325+
return V;
1931219326
return performADDCombine(N, DCI, Subtarget);
1931319327
}
1931419328
case ISD::SUB: {

0 commit comments

Comments
 (0)