@@ -18372,31 +18372,6 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
18372
18372
DAG.getBuildVector(VT, DL, RHSOps));
18373
18373
}
18374
18374
18375
- static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
18376
- const SDLoc &DL, SelectionDAG &DAG,
18377
- const RISCVSubtarget &Subtarget) {
18378
- assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
18379
- RISCVISD::VQDOTSU_VL == Opc);
18380
- MVT VT = Op0.getSimpleValueType();
18381
- assert(VT == Op1.getSimpleValueType() &&
18382
- VT.getVectorElementType() == MVT::i32);
18383
-
18384
- SDValue Passthru = DAG.getConstant(0, DL, VT);
18385
- MVT ContainerVT = VT;
18386
- if (VT.isFixedLengthVector()) {
18387
- ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
18388
- Passthru = convertToScalableVector(ContainerVT, Passthru, DAG, Subtarget);
18389
- Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
18390
- Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
18391
- }
18392
- auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
18393
- SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
18394
- {Op0, Op1, Passthru, Mask, VL});
18395
- if (VT.isFixedLengthVector())
18396
- return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
18397
- return LocalAccum;
18398
- }
18399
-
18400
18375
static MVT getQDOTXResultType(MVT OpVT) {
18401
18376
ElementCount OpEC = OpVT.getVectorElementCount();
18402
18377
assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
@@ -18455,61 +18430,62 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
18455
18430
}
18456
18431
}
18457
18432
18458
- // reduce ( zext a) <--> reduce (mul zext a. zext 1)
18459
- // reduce ( sext a) <--> reduce (mul sext a. sext 1)
18433
+ // zext a <--> partial_reduce_umla 0, a, 1
18434
+ // sext a <--> partial_reduce_smla 0, a, 1
18460
18435
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
18461
18436
InVec.getOpcode() == ISD::SIGN_EXTEND) {
18462
18437
SDValue A = InVec.getOperand(0);
18463
- if ( A.getValueType().getVectorElementType() != MVT::i8 ||
18464
- !TLI.isTypeLegal(A.getValueType() ))
18438
+ EVT OpVT = A.getValueType();
18439
+ if (OpVT.getVectorElementType() != MVT::i8 || !TLI.isTypeLegal(OpVT ))
18465
18440
return SDValue();
18466
18441
18467
18442
MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
18468
- A = DAG.getBitcast(ResVT, A);
18469
- SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
18470
-
18443
+ SDValue B = DAG.getConstant(0x1, DL, OpVT);
18471
18444
bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
18472
- unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
18473
- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
18445
+ unsigned Opc =
18446
+ IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
18447
+ return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
18474
18448
}
18475
18449
18476
- // mul (sext, sext) -> vqdot
18477
- // mul (zext, zext) -> vqdotu
18478
- // mul (sext, zext) -> vqdotsu
18479
- // mul (zext, sext) -> vqdotsu (swapped)
18480
- // TODO: Improve .vx handling - we end up with a sub-vector insert
18481
- // which confuses the splat pattern matching. Also, match vqdotus.vx
18450
+ // mul (sext a, sext b) -> partial_reduce_smla 0, a, b
18451
+ // mul (zext a, zext b) -> partial_reduce_umla 0, a, b
18452
+ // mul (sext a, zext b) -> partial_reduce_ssmla 0, a, b
18453
+ // mul (zext a, sext b) -> partial_reduce_smla 0, b, a (swapped)
18482
18454
if (InVec.getOpcode() != ISD::MUL)
18483
18455
return SDValue();
18484
18456
18485
18457
SDValue A = InVec.getOperand(0);
18486
18458
SDValue B = InVec.getOperand(1);
18487
- unsigned Opc = 0;
18488
- if (A.getOpcode() == B.getOpcode()) {
18489
- if (A.getOpcode() == ISD::SIGN_EXTEND)
18490
- Opc = RISCVISD::VQDOT_VL;
18491
- else if (A.getOpcode() == ISD::ZERO_EXTEND)
18492
- Opc = RISCVISD::VQDOTU_VL;
18493
- else
18494
- return SDValue();
18495
- } else {
18496
- if (B.getOpcode() != ISD::ZERO_EXTEND)
18497
- std::swap(A, B);
18498
- if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
18499
- return SDValue();
18500
- Opc = RISCVISD::VQDOTSU_VL;
18501
- }
18502
- assert(Opc);
18503
18459
18504
- if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
18505
- A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
18460
+ if (!ISD::isExtOpcode(A.getOpcode()))
18461
+ return SDValue();
18462
+
18463
+ EVT OpVT = A.getOperand(0).getValueType();
18464
+ if (OpVT.getVectorElementType() != MVT::i8 ||
18465
+ OpVT != B.getOperand(0).getValueType() ||
18506
18466
!TLI.isTypeLegal(A.getValueType()))
18507
18467
return SDValue();
18508
18468
18509
- MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
18510
- A = DAG.getBitcast(ResVT, A.getOperand(0));
18511
- B = DAG.getBitcast(ResVT, B.getOperand(0));
18512
- return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
18469
+ unsigned Opc;
18470
+ if (A.getOpcode() == ISD::SIGN_EXTEND && B.getOpcode() == ISD::SIGN_EXTEND)
18471
+ Opc = ISD::PARTIAL_REDUCE_SMLA;
18472
+ else if (A.getOpcode() == ISD::ZERO_EXTEND &&
18473
+ B.getOpcode() == ISD::ZERO_EXTEND)
18474
+ Opc = ISD::PARTIAL_REDUCE_UMLA;
18475
+ else if (A.getOpcode() == ISD::SIGN_EXTEND &&
18476
+ B.getOpcode() == ISD::ZERO_EXTEND)
18477
+ Opc = ISD::PARTIAL_REDUCE_SUMLA;
18478
+ else if (A.getOpcode() == ISD::ZERO_EXTEND &&
18479
+ B.getOpcode() == ISD::SIGN_EXTEND) {
18480
+ Opc = ISD::PARTIAL_REDUCE_SUMLA;
18481
+ std::swap(A, B);
18482
+ } else
18483
+ return SDValue();
18484
+
18485
+ MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
18486
+ return DAG.getNode(
18487
+ Opc, DL, ResVT,
18488
+ {DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
18513
18489
}
18514
18490
18515
18491
static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,
0 commit comments