@@ -18177,17 +18177,20 @@ static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
18177
18177
assert(VT == Op1.getSimpleValueType() &&
18178
18178
VT.getVectorElementType() == MVT::i32);
18179
18179
18180
- assert(VT.isFixedLengthVector());
18181
- MVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
18182
- SDValue Passthru = convertToScalableVector(
18183
- ContainerVT, DAG.getConstant(0, DL, VT), DAG, Subtarget);
18184
- Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
18185
- Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
18186
-
18180
+ SDValue Passthru = DAG.getConstant(0, DL, VT);
18181
+ MVT ContainerVT = VT;
18182
+ if (VT.isFixedLengthVector()) {
18183
+ ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
18184
+ Passthru = convertToScalableVector(ContainerVT, Passthru, DAG, Subtarget);
18185
+ Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
18186
+ Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
18187
+ }
18187
18188
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
18188
18189
SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
18189
18190
{Op0, Op1, Passthru, Mask, VL});
18190
- return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
18191
+ if (VT.isFixedLengthVector())
18192
+ return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
18193
+ return LocalAccum;
18191
18194
}
18192
18195
18193
18196
static MVT getQDOTXResultType(MVT OpVT) {
@@ -18207,7 +18210,7 @@ static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
18207
18210
EVT AVT = A.getValueType();
18208
18211
EVT BVT = B.getValueType();
18209
18212
assert(AVT.getVectorElementType() == BVT.getVectorElementType());
18210
- if (AVT.getVectorNumElements () > BVT.getVectorNumElements ()) {
18213
+ if (AVT.getVectorMinNumElements () > BVT.getVectorMinNumElements ()) {
18211
18214
std::swap(A, B);
18212
18215
std::swap(AVT, BVT);
18213
18216
}
@@ -18641,17 +18644,19 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
18641
18644
static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
18642
18645
const RISCVSubtarget &Subtarget) {
18643
18646
18644
- assert(N->getOpcode() == RISCVISD::ADD_VL);
18647
+ assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD );
18645
18648
18646
18649
if (!N->getValueType(0).isVector())
18647
18650
return SDValue();
18648
18651
18649
18652
SDValue Addend = N->getOperand(0);
18650
18653
SDValue DotOp = N->getOperand(1);
18651
18654
18652
- SDValue AddPassthruOp = N->getOperand(2);
18653
- if (!AddPassthruOp.isUndef())
18654
- return SDValue();
18655
+ if (N->getOpcode() == RISCVISD::ADD_VL) {
18656
+ SDValue AddPassthruOp = N->getOperand(2);
18657
+ if (!AddPassthruOp.isUndef())
18658
+ return SDValue();
18659
+ }
18655
18660
18656
18661
auto IsVqdotqOpc = [](unsigned Opc) {
18657
18662
switch (Opc) {
@@ -18670,8 +18675,15 @@ static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
18670
18675
if (!IsVqdotqOpc(DotOp.getOpcode()))
18671
18676
return SDValue();
18672
18677
18673
- SDValue AddMask = N->getOperand(3);
18674
- SDValue AddVL = N->getOperand(4);
18678
+ auto [AddMask, AddVL] = [](SDNode *N, SelectionDAG &DAG,
18679
+ const RISCVSubtarget &Subtarget) {
18680
+ if (N->getOpcode() == ISD::ADD) {
18681
+ SDLoc DL(N);
18682
+ return getDefaultScalableVLOps(N->getSimpleValueType(0), DL, DAG,
18683
+ Subtarget);
18684
+ }
18685
+ return std::make_pair(N->getOperand(3), N->getOperand(4));
18686
+ }(N, DAG, Subtarget);
18675
18687
18676
18688
SDValue MulVL = DotOp.getOperand(4);
18677
18689
if (AddVL != MulVL)
@@ -19309,6 +19321,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
19309
19321
return V;
19310
19322
if (SDValue V = combineToVWMACC(N, DAG, Subtarget))
19311
19323
return V;
19324
+ if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
19325
+ return V;
19312
19326
return performADDCombine(N, DCI, Subtarget);
19313
19327
}
19314
19328
case ISD::SUB: {
0 commit comments