Skip to content

Commit 6408291

Browse files
authored
[RISCV] Fold add_vl into accumulator operand of vqdot* (#139484)
If we have an add_vl following a vqdot* instruction, we can move the add before the vqdot instead. For cases where the prior accumulator was zero, we can fold the add into the vqdot* instruction entirely. This directly parallels the folding we do for multiply-add variants.
1 parent 045fdda commit 6408291

File tree

2 files changed

+82
-21
lines changed

2 files changed

+82
-21
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18408,7 +18408,6 @@ static SDValue performVECTOR_SHUFFLECombine(SDNode *N, SelectionDAG &DAG,
1840818408

1840918409
static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
1841018410
const RISCVSubtarget &Subtarget) {
18411-
1841218411
assert(N->getOpcode() == RISCVISD::ADD_VL || N->getOpcode() == ISD::ADD);
1841318412

1841418413
if (N->getValueType(0).isFixedLengthVector())
@@ -18472,9 +18471,74 @@ static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
1847218471
return DAG.getNode(Opc, DL, VT, Ops);
1847318472
}
1847418473

18475-
static bool legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
18476-
ISD::MemIndexType &IndexType,
18477-
RISCVTargetLowering::DAGCombinerInfo &DCI) {
18474+
// Fold an ADD_VL whose other operand is a VQDOT*_VL node into the dot
// product's accumulator operand: (add_vl addend (vqdot* a b accum)) becomes
// (vqdot* a b (add_vl accum addend)), and when the prior accumulator is a
// zero splat the add folds away entirely, yielding (vqdot* a b addend).
// Returns the replacement node, or an empty SDValue if the pattern does not
// match. NOTE(review): Subtarget is unused here — presumably kept so the
// signature matches the other combine helpers in this file; confirm.
static SDValue combineVqdotAccum(SDNode *N, SelectionDAG &DAG,
                                 const RISCVSubtarget &Subtarget) {

  assert(N->getOpcode() == RISCVISD::ADD_VL);

  if (!N->getValueType(0).isVector())
    return SDValue();

  // ADD_VL operands: (0) lhs, (1) rhs, (2) passthru, (3) mask, (4) VL.
  SDValue Addend = N->getOperand(0);
  SDValue DotOp = N->getOperand(1);

  // A non-undef passthru would merge unselected lanes; bail out since the
  // rewrite below would lose that merge semantics.
  SDValue AddPassthruOp = N->getOperand(2);
  if (!AddPassthruOp.isUndef())
    return SDValue();

  // The three dot-product forms this combine recognizes.
  auto IsVqdotqOpc = [](unsigned Opc) {
    switch (Opc) {
    case RISCVISD::VQDOT_VL:
    case RISCVISD::VQDOTU_VL:
    case RISCVISD::VQDOTSU_VL:
      return true;
    default:
      return false;
    }
  };

  // The add is commutative: accept the vqdot* node in either operand slot.
  if (!IsVqdotqOpc(DotOp.getOpcode()))
    std::swap(Addend, DotOp);

  if (!IsVqdotqOpc(DotOp.getOpcode()))
    return SDValue();

  SDValue AddMask = N->getOperand(3);
  SDValue AddVL = N->getOperand(4);

  // Both nodes must agree on the vector length (operand 4 on the dot node),
  // otherwise moving the add across the dot product changes which lanes are
  // computed.
  SDValue MulVL = DotOp.getOperand(4);
  if (AddVL != MulVL)
    return SDValue();

  // Only combine when the add is unmasked, i.e. its mask is an all-ones
  // VMSET_VL with the same VL.
  if (AddMask.getOpcode() != RISCVISD::VMSET_VL ||
      AddMask.getOperand(0) != MulVL)
    return SDValue();

  // Operand 2 of the vqdot* node is its accumulator. If it is a zero splat
  // the addend can simply take its place.
  SDValue AccumOp = DotOp.getOperand(2);
  bool IsNullAdd = ISD::isConstantSplatVectorAllZeros(AccumOp.getNode());
  // Peek through fixed to scalable
  if (!IsNullAdd && AccumOp.getOpcode() == ISD::INSERT_SUBVECTOR &&
      AccumOp.getOperand(0).isUndef())
    IsNullAdd =
        ISD::isConstantSplatVectorAllZeros(AccumOp.getOperand(1).getNode());

  SDLoc DL(N);
  EVT VT = N->getValueType(0);
  // The manual constant folding is required, this case is not constant folded
  // or combined.
  if (!IsNullAdd)
    Addend = DAG.getNode(RISCVISD::ADD_VL, DL, VT, AccumOp, Addend,
                         DAG.getUNDEF(VT), AddMask, AddVL);

  // Rebuild the dot product with the (possibly pre-added) addend as its
  // accumulator, reusing the dot node's own mask and VL.
  SDValue Ops[] = {DotOp.getOperand(0), DotOp.getOperand(1), Addend,
                   DotOp.getOperand(3), DotOp->getOperand(4)};
  return DAG.getNode(DotOp->getOpcode(), DL, VT, Ops);
}
18537+
18538+
static bool
18539+
legalizeScatterGatherIndexType(SDLoc DL, SDValue &Index,
18540+
ISD::MemIndexType &IndexType,
18541+
RISCVTargetLowering::DAGCombinerInfo &DCI) {
1847818542
if (!DCI.isBeforeLegalize())
1847918543
return false;
1848018544

@@ -19595,6 +19659,8 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
1959519659
case RISCVISD::ADD_VL:
1959619660
if (SDValue V = combineOp_VLToVWOp_VL(N, DCI, Subtarget))
1959719661
return V;
19662+
if (SDValue V = combineVqdotAccum(N, DAG, Subtarget))
19663+
return V;
1959819664
return combineToVWMACC(N, DAG, Subtarget);
1959919665
case RISCVISD::VWADD_W_VL:
1960019666
case RISCVISD::VWADDU_W_VL:

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -314,11 +314,10 @@ define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
314314
; DOT-LABEL: vqdot_vv_accum:
315315
; DOT: # %bb.0: # %entry
316316
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
317-
; DOT-NEXT: vmv.v.i v10, 0
318-
; DOT-NEXT: vqdot.vv v10, v8, v9
319-
; DOT-NEXT: vadd.vv v8, v10, v12
317+
; DOT-NEXT: vmv1r.v v16, v12
318+
; DOT-NEXT: vqdot.vv v16, v8, v9
320319
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
321-
; DOT-NEXT: vmv.v.v v12, v8
320+
; DOT-NEXT: vmv.v.v v12, v16
322321
; DOT-NEXT: vmv.s.x v8, zero
323322
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
324323
; DOT-NEXT: vredsum.vs v8, v12, v8
@@ -349,11 +348,10 @@ define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
349348
; DOT-LABEL: vqdotu_vv_accum:
350349
; DOT: # %bb.0: # %entry
351350
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
352-
; DOT-NEXT: vmv.v.i v10, 0
353-
; DOT-NEXT: vqdotu.vv v10, v8, v9
354-
; DOT-NEXT: vadd.vv v8, v10, v12
351+
; DOT-NEXT: vmv1r.v v16, v12
352+
; DOT-NEXT: vqdotu.vv v16, v8, v9
355353
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
356-
; DOT-NEXT: vmv.v.v v12, v8
354+
; DOT-NEXT: vmv.v.v v12, v16
357355
; DOT-NEXT: vmv.s.x v8, zero
358356
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
359357
; DOT-NEXT: vredsum.vs v8, v12, v8
@@ -384,11 +382,10 @@ define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
384382
; DOT-LABEL: vqdotsu_vv_accum:
385383
; DOT: # %bb.0: # %entry
386384
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
387-
; DOT-NEXT: vmv.v.i v10, 0
388-
; DOT-NEXT: vqdotsu.vv v10, v8, v9
389-
; DOT-NEXT: vadd.vv v8, v10, v12
385+
; DOT-NEXT: vmv1r.v v16, v12
386+
; DOT-NEXT: vqdotsu.vv v16, v8, v9
390387
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
391-
; DOT-NEXT: vmv.v.v v12, v8
388+
; DOT-NEXT: vmv.v.v v12, v16
392389
; DOT-NEXT: vmv.s.x v8, zero
393390
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
394391
; DOT-NEXT: vredsum.vs v8, v12, v8
@@ -516,12 +513,10 @@ define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %
516513
; DOT: # %bb.0: # %entry
517514
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
518515
; DOT-NEXT: vmv.v.i v12, 0
519-
; DOT-NEXT: vmv.v.i v13, 0
520516
; DOT-NEXT: vqdot.vv v12, v8, v9
521-
; DOT-NEXT: vqdot.vv v13, v10, v11
522-
; DOT-NEXT: vadd.vv v8, v12, v13
523-
; DOT-NEXT: vmv.s.x v9, zero
524-
; DOT-NEXT: vredsum.vs v8, v8, v9
517+
; DOT-NEXT: vqdot.vv v12, v10, v11
518+
; DOT-NEXT: vmv.s.x v8, zero
519+
; DOT-NEXT: vredsum.vs v8, v12, v8
525520
; DOT-NEXT: vmv.x.s a0, v8
526521
; DOT-NEXT: ret
527522
entry:

0 commit comments

Comments
 (0)