Skip to content

Commit 2680afb

Browse files
authored
[RISCV] Migrate zvqdotq reduce matching to use partial_reduce infrastructure (#142212)
This involves a codegen regression at the moment due to the issue described in 443cdd0, but this aligns the lowering paths for this case and makes it less likely future bugs go undetected.
1 parent f5733b0 commit 2680afb

File tree

3 files changed

+58
-82
lines changed

3 files changed

+58
-82
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 38 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -18372,31 +18372,6 @@ static SDValue performBUILD_VECTORCombine(SDNode *N, SelectionDAG &DAG,
1837218372
DAG.getBuildVector(VT, DL, RHSOps));
1837318373
}
1837418374

18375-
static SDValue lowerVQDOT(unsigned Opc, SDValue Op0, SDValue Op1,
18376-
const SDLoc &DL, SelectionDAG &DAG,
18377-
const RISCVSubtarget &Subtarget) {
18378-
assert(RISCVISD::VQDOT_VL == Opc || RISCVISD::VQDOTU_VL == Opc ||
18379-
RISCVISD::VQDOTSU_VL == Opc);
18380-
MVT VT = Op0.getSimpleValueType();
18381-
assert(VT == Op1.getSimpleValueType() &&
18382-
VT.getVectorElementType() == MVT::i32);
18383-
18384-
SDValue Passthru = DAG.getConstant(0, DL, VT);
18385-
MVT ContainerVT = VT;
18386-
if (VT.isFixedLengthVector()) {
18387-
ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
18388-
Passthru = convertToScalableVector(ContainerVT, Passthru, DAG, Subtarget);
18389-
Op0 = convertToScalableVector(ContainerVT, Op0, DAG, Subtarget);
18390-
Op1 = convertToScalableVector(ContainerVT, Op1, DAG, Subtarget);
18391-
}
18392-
auto [Mask, VL] = getDefaultVLOps(VT, ContainerVT, DL, DAG, Subtarget);
18393-
SDValue LocalAccum = DAG.getNode(Opc, DL, ContainerVT,
18394-
{Op0, Op1, Passthru, Mask, VL});
18395-
if (VT.isFixedLengthVector())
18396-
return convertFromScalableVector(VT, LocalAccum, DAG, Subtarget);
18397-
return LocalAccum;
18398-
}
18399-
1840018375
static MVT getQDOTXResultType(MVT OpVT) {
1840118376
ElementCount OpEC = OpVT.getVectorElementCount();
1840218377
assert(OpEC.isKnownMultipleOf(4) && OpVT.getVectorElementType() == MVT::i8);
@@ -18455,61 +18430,62 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
1845518430
}
1845618431
}
1845718432

18458-
// reduce (zext a) <--> reduce (mul zext a. zext 1)
18459-
// reduce (sext a) <--> reduce (mul sext a. sext 1)
18433+
// zext a <--> partial_reduce_umla 0, a, 1
18434+
// sext a <--> partial_reduce_smla 0, a, 1
1846018435
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||
1846118436
InVec.getOpcode() == ISD::SIGN_EXTEND) {
1846218437
SDValue A = InVec.getOperand(0);
18463-
if (A.getValueType().getVectorElementType() != MVT::i8 ||
18464-
!TLI.isTypeLegal(A.getValueType()))
18438+
EVT OpVT = A.getValueType();
18439+
if (OpVT.getVectorElementType() != MVT::i8 || !TLI.isTypeLegal(OpVT))
1846518440
return SDValue();
1846618441

1846718442
MVT ResVT = getQDOTXResultType(A.getSimpleValueType());
18468-
A = DAG.getBitcast(ResVT, A);
18469-
SDValue B = DAG.getConstant(0x01010101, DL, ResVT);
18470-
18443+
SDValue B = DAG.getConstant(0x1, DL, OpVT);
1847118444
bool IsSigned = InVec.getOpcode() == ISD::SIGN_EXTEND;
18472-
unsigned Opc = IsSigned ? RISCVISD::VQDOT_VL : RISCVISD::VQDOTU_VL;
18473-
return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
18445+
unsigned Opc =
18446+
IsSigned ? ISD::PARTIAL_REDUCE_SMLA : ISD::PARTIAL_REDUCE_UMLA;
18447+
return DAG.getNode(Opc, DL, ResVT, {DAG.getConstant(0, DL, ResVT), A, B});
1847418448
}
1847518449

18476-
// mul (sext, sext) -> vqdot
18477-
// mul (zext, zext) -> vqdotu
18478-
// mul (sext, zext) -> vqdotsu
18479-
// mul (zext, sext) -> vqdotsu (swapped)
18480-
// TODO: Improve .vx handling - we end up with a sub-vector insert
18481-
// which confuses the splat pattern matching. Also, match vqdotus.vx
18450+
// mul (sext a, sext b) -> partial_reduce_smla 0, a, b
18451+
// mul (zext a, zext b) -> partial_reduce_umla 0, a, b
18452+
// mul (sext a, zext b) -> partial_reduce_sumla 0, a, b
18453+
// mul (zext a, sext b) -> partial_reduce_sumla 0, b, a (swapped)
1848218454
if (InVec.getOpcode() != ISD::MUL)
1848318455
return SDValue();
1848418456

1848518457
SDValue A = InVec.getOperand(0);
1848618458
SDValue B = InVec.getOperand(1);
18487-
unsigned Opc = 0;
18488-
if (A.getOpcode() == B.getOpcode()) {
18489-
if (A.getOpcode() == ISD::SIGN_EXTEND)
18490-
Opc = RISCVISD::VQDOT_VL;
18491-
else if (A.getOpcode() == ISD::ZERO_EXTEND)
18492-
Opc = RISCVISD::VQDOTU_VL;
18493-
else
18494-
return SDValue();
18495-
} else {
18496-
if (B.getOpcode() != ISD::ZERO_EXTEND)
18497-
std::swap(A, B);
18498-
if (A.getOpcode() != ISD::SIGN_EXTEND || B.getOpcode() != ISD::ZERO_EXTEND)
18499-
return SDValue();
18500-
Opc = RISCVISD::VQDOTSU_VL;
18501-
}
18502-
assert(Opc);
1850318459

18504-
if (A.getOperand(0).getValueType().getVectorElementType() != MVT::i8 ||
18505-
A.getOperand(0).getValueType() != B.getOperand(0).getValueType() ||
18460+
if (!ISD::isExtOpcode(A.getOpcode()))
18461+
return SDValue();
18462+
18463+
EVT OpVT = A.getOperand(0).getValueType();
18464+
if (OpVT.getVectorElementType() != MVT::i8 ||
18465+
OpVT != B.getOperand(0).getValueType() ||
1850618466
!TLI.isTypeLegal(A.getValueType()))
1850718467
return SDValue();
1850818468

18509-
MVT ResVT = getQDOTXResultType(A.getOperand(0).getSimpleValueType());
18510-
A = DAG.getBitcast(ResVT, A.getOperand(0));
18511-
B = DAG.getBitcast(ResVT, B.getOperand(0));
18512-
return lowerVQDOT(Opc, A, B, DL, DAG, Subtarget);
18469+
unsigned Opc;
18470+
if (A.getOpcode() == ISD::SIGN_EXTEND && B.getOpcode() == ISD::SIGN_EXTEND)
18471+
Opc = ISD::PARTIAL_REDUCE_SMLA;
18472+
else if (A.getOpcode() == ISD::ZERO_EXTEND &&
18473+
B.getOpcode() == ISD::ZERO_EXTEND)
18474+
Opc = ISD::PARTIAL_REDUCE_UMLA;
18475+
else if (A.getOpcode() == ISD::SIGN_EXTEND &&
18476+
B.getOpcode() == ISD::ZERO_EXTEND)
18477+
Opc = ISD::PARTIAL_REDUCE_SUMLA;
18478+
else if (A.getOpcode() == ISD::ZERO_EXTEND &&
18479+
B.getOpcode() == ISD::SIGN_EXTEND) {
18480+
Opc = ISD::PARTIAL_REDUCE_SUMLA;
18481+
std::swap(A, B);
18482+
} else
18483+
return SDValue();
18484+
18485+
MVT ResVT = getQDOTXResultType(OpVT.getSimpleVT());
18486+
return DAG.getNode(
18487+
Opc, DL, ResVT,
18488+
{DAG.getConstant(0, DL, ResVT), A.getOperand(0), B.getOperand(0)});
1851318489
}
1851418490

1851518491
static SDValue performVECREDUCECombine(SDNode *N, SelectionDAG &DAG,

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -232,13 +232,13 @@ define i32 @reduce_of_sext(<16 x i8> %a) {
232232
;
233233
; DOT-LABEL: reduce_of_sext:
234234
; DOT: # %bb.0: # %entry
235+
; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
236+
; DOT-NEXT: vmv.v.i v9, 1
235237
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
236-
; DOT-NEXT: vmv.v.i v9, 0
237-
; DOT-NEXT: lui a0, 4112
238-
; DOT-NEXT: addi a0, a0, 257
239-
; DOT-NEXT: vqdot.vx v9, v8, a0
238+
; DOT-NEXT: vmv.v.i v10, 0
239+
; DOT-NEXT: vqdot.vv v10, v8, v9
240240
; DOT-NEXT: vmv.s.x v8, zero
241-
; DOT-NEXT: vredsum.vs v8, v9, v8
241+
; DOT-NEXT: vredsum.vs v8, v10, v8
242242
; DOT-NEXT: vmv.x.s a0, v8
243243
; DOT-NEXT: ret
244244
entry:
@@ -259,13 +259,13 @@ define i32 @reduce_of_zext(<16 x i8> %a) {
259259
;
260260
; DOT-LABEL: reduce_of_zext:
261261
; DOT: # %bb.0: # %entry
262+
; DOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
263+
; DOT-NEXT: vmv.v.i v9, 1
262264
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
263-
; DOT-NEXT: vmv.v.i v9, 0
264-
; DOT-NEXT: lui a0, 4112
265-
; DOT-NEXT: addi a0, a0, 257
266-
; DOT-NEXT: vqdotu.vx v9, v8, a0
265+
; DOT-NEXT: vmv.v.i v10, 0
266+
; DOT-NEXT: vqdotu.vv v10, v8, v9
267267
; DOT-NEXT: vmv.s.x v8, zero
268-
; DOT-NEXT: vredsum.vs v8, v9, v8
268+
; DOT-NEXT: vredsum.vs v8, v10, v8
269269
; DOT-NEXT: vmv.x.s a0, v8
270270
; DOT-NEXT: ret
271271
entry:

llvm/test/CodeGen/RISCV/rvv/zvqdotq-sdnode.ll

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -232,13 +232,13 @@ define i32 @reduce_of_sext(<vscale x 16 x i8> %a) {
232232
;
233233
; DOT-LABEL: reduce_of_sext:
234234
; DOT: # %bb.0: # %entry
235+
; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
236+
; DOT-NEXT: vmv.v.i v10, 1
235237
; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
236-
; DOT-NEXT: vmv.v.i v10, 0
237-
; DOT-NEXT: lui a0, 4112
238-
; DOT-NEXT: addi a0, a0, 257
239-
; DOT-NEXT: vqdot.vx v10, v8, a0
238+
; DOT-NEXT: vmv.v.i v12, 0
239+
; DOT-NEXT: vqdot.vv v12, v8, v10
240240
; DOT-NEXT: vmv.s.x v8, zero
241-
; DOT-NEXT: vredsum.vs v8, v10, v8
241+
; DOT-NEXT: vredsum.vs v8, v12, v8
242242
; DOT-NEXT: vmv.x.s a0, v8
243243
; DOT-NEXT: ret
244244
entry:
@@ -259,13 +259,13 @@ define i32 @reduce_of_zext(<vscale x 16 x i8> %a) {
259259
;
260260
; DOT-LABEL: reduce_of_zext:
261261
; DOT: # %bb.0: # %entry
262+
; DOT-NEXT: vsetvli a0, zero, e8, m2, ta, ma
263+
; DOT-NEXT: vmv.v.i v10, 1
262264
; DOT-NEXT: vsetvli a0, zero, e32, m2, ta, ma
263-
; DOT-NEXT: vmv.v.i v10, 0
264-
; DOT-NEXT: lui a0, 4112
265-
; DOT-NEXT: addi a0, a0, 257
266-
; DOT-NEXT: vqdotu.vx v10, v8, a0
265+
; DOT-NEXT: vmv.v.i v12, 0
266+
; DOT-NEXT: vqdotu.vv v12, v8, v10
267267
; DOT-NEXT: vmv.s.x v8, zero
268-
; DOT-NEXT: vredsum.vs v8, v10, v8
268+
; DOT-NEXT: vredsum.vs v8, v12, v8
269269
; DOT-NEXT: vmv.x.s a0, v8
270270
; DOT-NEXT: ret
271271
entry:

0 commit comments

Comments
 (0)