Skip to content

Commit 2b5e531

Browse files
committed
[RISCV] Add support for matching vwmul(u) and vwmacc(u) from fixed vectors.
This adds a DAG combine to detect sext/zext inputs and emit a new ISD opcode. The extends will either be removed or replaced with narrower extends. Isel patterns are used to match add and widening mul to vwmacc similar to the recently added vmacc patterns. There's still some work to be to match vmulsu. We should also rewrite splats that were extended as scalars and then splatted. Reviewed By: arcbbb Differential Revision: https://reviews.llvm.org/D104802
1 parent 846a530 commit 2b5e531

File tree

7 files changed

+2534
-0
lines changed

7 files changed

+2534
-0
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6140,6 +6140,47 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
61406140
}
61416141
break;
61426142
}
6143+
case RISCVISD::MUL_VL: {
6144+
// Try to form VWMUL or VWMULU.
6145+
// FIXME: Look for splat of extended scalar as well.
6146+
// FIXME: Support VWMULSU.
6147+
SDValue Op0 = N->getOperand(0);
6148+
SDValue Op1 = N->getOperand(1);
6149+
bool IsSignExt = Op0.getOpcode() == RISCVISD::VSEXT_VL;
6150+
bool IsZeroExt = Op0.getOpcode() == RISCVISD::VZEXT_VL;
6151+
if ((!IsSignExt && !IsZeroExt) || Op0.getOpcode() != Op1.getOpcode())
6152+
return SDValue();
6153+
6154+
// Make sure the extends have a single use.
6155+
if (!Op0.hasOneUse() || !Op1.hasOneUse())
6156+
return SDValue();
6157+
6158+
SDValue Mask = N->getOperand(2);
6159+
SDValue VL = N->getOperand(3);
6160+
if (Op0.getOperand(1) != Mask || Op1.getOperand(1) != Mask ||
6161+
Op0.getOperand(2) != VL || Op1.getOperand(2) != VL)
6162+
return SDValue();
6163+
6164+
Op0 = Op0.getOperand(0);
6165+
Op1 = Op1.getOperand(0);
6166+
6167+
MVT VT = N->getSimpleValueType(0);
6168+
MVT NarrowVT =
6169+
MVT::getVectorVT(MVT::getIntegerVT(VT.getScalarSizeInBits() / 2),
6170+
VT.getVectorElementCount());
6171+
6172+
SDLoc DL(N);
6173+
6174+
// Re-introduce narrower extends if needed.
6175+
unsigned ExtOpc = IsSignExt ? RISCVISD::VSEXT_VL : RISCVISD::VZEXT_VL;
6176+
if (Op0.getValueType() != NarrowVT)
6177+
Op0 = DAG.getNode(ExtOpc, DL, NarrowVT, Op0, Mask, VL);
6178+
if (Op1.getValueType() != NarrowVT)
6179+
Op1 = DAG.getNode(ExtOpc, DL, NarrowVT, Op1, Mask, VL);
6180+
6181+
unsigned WMulOpc = IsSignExt ? RISCVISD::VWMUL_VL : RISCVISD::VWMULU_VL;
6182+
return DAG.getNode(WMulOpc, DL, VT, Op0, Op1, Mask, VL);
6183+
}
61436184
}
61446185

61456186
return SDValue();
@@ -8199,6 +8240,8 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
81998240
NODE_NAME_CASE(UINT_TO_FP_VL)
82008241
NODE_NAME_CASE(FP_EXTEND_VL)
82018242
NODE_NAME_CASE(FP_ROUND_VL)
8243+
NODE_NAME_CASE(VWMUL_VL)
8244+
NODE_NAME_CASE(VWMULU_VL)
82028245
NODE_NAME_CASE(SETCC_VL)
82038246
NODE_NAME_CASE(VSELECT_VL)
82048247
NODE_NAME_CASE(VMAND_VL)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,10 @@ enum NodeType : unsigned {
216216
FP_ROUND_VL,
217217
FP_EXTEND_VL,
218218

219+
// Widening instructions
220+
VWMUL_VL,
221+
VWMULU_VL,
222+
219223
// Vector compare producing a mask. Fourth operand is input mask. Fifth
220224
// operand is VL.
221225
SETCC_VL,
@@ -241,6 +245,7 @@ enum NodeType : unsigned {
241245
// Vector sign/zero extend with additional mask & VL operands.
242246
VSEXT_VL,
243247
VZEXT_VL,
248+
244249
// vpopc.m with additional mask and VL operands.
245250
VPOPC_VL,
246251

llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,15 @@ def riscv_trunc_vector_vl : SDNode<"RISCVISD::TRUNCATE_VECTOR_VL",
215215
SDTCVecEltisVT<2, i1>,
216216
SDTCisVT<3, XLenVT>]>>;
217217

218+
def SDT_RISCVVWMUL_VL : SDTypeProfile<1, 4, [SDTCisVec<0>,
219+
SDTCisSameNumEltsAs<0, 1>,
220+
SDTCisSameAs<1, 2>,
221+
SDTCisSameNumEltsAs<1, 3>,
222+
SDTCVecEltisVT<3, i1>,
223+
SDTCisVT<4, XLenVT>]>;
224+
def riscv_vwmul_vl : SDNode<"RISCVISD::VWMUL_VL", SDT_RISCVVWMUL_VL, [SDNPCommutative]>;
225+
def riscv_vwmulu_vl : SDNode<"RISCVISD::VWMULU_VL", SDT_RISCVVWMUL_VL, [SDNPCommutative]>;
226+
218227
def SDTRVVVecReduce : SDTypeProfile<1, 4, [
219228
SDTCisVec<0>, SDTCisVec<1>, SDTCisSameAs<0, 2>, SDTCVecEltisVT<3, i1>,
220229
SDTCisSameNumEltsAs<1, 3>, SDTCisVT<4, XLenVT>
@@ -226,6 +235,18 @@ def riscv_mul_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D),
226235
return N->hasOneUse();
227236
}]>;
228237

238+
def riscv_vwmul_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D),
239+
(riscv_vwmul_vl node:$A, node:$B, node:$C,
240+
node:$D), [{
241+
return N->hasOneUse();
242+
}]>;
243+
244+
def riscv_vwmulu_vl_oneuse : PatFrag<(ops node:$A, node:$B, node:$C, node:$D),
245+
(riscv_vwmulu_vl node:$A, node:$B, node:$C,
246+
node:$D), [{
247+
return N->hasOneUse();
248+
}]>;
249+
229250
foreach kind = ["ADD", "UMAX", "SMAX", "UMIN", "SMIN", "AND", "OR", "XOR",
230251
"FADD", "SEQ_FADD", "FMIN", "FMAX"] in
231252
def rvv_vecreduce_#kind#_vl : SDNode<"RISCVISD::VECREDUCE_"#kind#"_VL", SDTRVVVecReduce>;
@@ -326,6 +347,20 @@ multiclass VPatBinaryVL_VV_VX_VI<SDNode vop, string instruction_name,
326347
}
327348
}
328349

350+
multiclass VPatBinaryWVL_VV_VX<SDNode vop, string instruction_name> {
351+
foreach VtiToWti = AllWidenableIntVectors in {
352+
defvar vti = VtiToWti.Vti;
353+
defvar wti = VtiToWti.Wti;
354+
defm : VPatBinaryVL_VV<vop, instruction_name,
355+
wti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
356+
vti.LMul, wti.RegClass, vti.RegClass>;
357+
defm : VPatBinaryVL_XI<vop, instruction_name, "VX",
358+
wti.Vector, vti.Vector, vti.Mask, vti.Log2SEW,
359+
vti.LMul, wti.RegClass, vti.RegClass,
360+
SplatPat, GPR>;
361+
}
362+
}
363+
329364
class VPatBinaryVL_VF<SDNode vop,
330365
string instruction_name,
331366
ValueType result_type,
@@ -737,6 +772,10 @@ defm : VPatBinaryVL_VV_VX<riscv_sdiv_vl, "PseudoVDIV">;
737772
defm : VPatBinaryVL_VV_VX<riscv_urem_vl, "PseudoVREMU">;
738773
defm : VPatBinaryVL_VV_VX<riscv_srem_vl, "PseudoVREM">;
739774

775+
// 12.12. Vector Widening Integer Multiply Instructions
776+
defm : VPatBinaryWVL_VV_VX<riscv_vwmul_vl, "PseudoVWMUL">;
777+
defm : VPatBinaryWVL_VV_VX<riscv_vwmulu_vl, "PseudoVWMULU">;
778+
740779
// 12.13 Vector Single-Width Integer Multiply-Add Instructions
741780
foreach vti = AllIntegerVectors in {
742781
// NOTE: We choose VMADD because it has the most commuting freedom. So it
@@ -784,6 +823,49 @@ foreach vti = AllIntegerVectors in {
784823
GPR:$vl, vti.Log2SEW)>;
785824
}
786825

826+
// 12.14. Vector Widening Integer Multiply-Add Instructions
827+
foreach vtiTowti = AllWidenableIntVectors in {
828+
defvar vti = vtiTowti.Vti;
829+
defvar wti = vtiTowti.Wti;
830+
def : Pat<(wti.Vector
831+
(riscv_add_vl wti.RegClass:$rd,
832+
(riscv_vwmul_vl_oneuse vti.RegClass:$rs1,
833+
(vti.Vector vti.RegClass:$rs2),
834+
(vti.Mask true_mask), VLOpFrag),
835+
(vti.Mask true_mask), VLOpFrag)),
836+
(!cast<Instruction>("PseudoVWMACC_VV_"# vti.LMul.MX)
837+
wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
838+
GPR:$vl, vti.Log2SEW)>;
839+
def : Pat<(wti.Vector
840+
(riscv_add_vl wti.RegClass:$rd,
841+
(riscv_vwmulu_vl_oneuse vti.RegClass:$rs1,
842+
(vti.Vector vti.RegClass:$rs2),
843+
(vti.Mask true_mask), VLOpFrag),
844+
(vti.Mask true_mask), VLOpFrag)),
845+
(!cast<Instruction>("PseudoVWMACCU_VV_"# vti.LMul.MX)
846+
wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
847+
GPR:$vl, vti.Log2SEW)>;
848+
849+
def : Pat<(wti.Vector
850+
(riscv_add_vl wti.RegClass:$rd,
851+
(riscv_vwmul_vl_oneuse (SplatPat XLenVT:$rs1),
852+
(vti.Vector vti.RegClass:$rs2),
853+
(vti.Mask true_mask), VLOpFrag),
854+
(vti.Mask true_mask), VLOpFrag)),
855+
(!cast<Instruction>("PseudoVWMACC_VX_" # vti.LMul.MX)
856+
wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
857+
GPR:$vl, vti.Log2SEW)>;
858+
def : Pat<(wti.Vector
859+
(riscv_add_vl wti.RegClass:$rd,
860+
(riscv_vwmulu_vl_oneuse (SplatPat XLenVT:$rs1),
861+
(vti.Vector vti.RegClass:$rs2),
862+
(vti.Mask true_mask), VLOpFrag),
863+
(vti.Mask true_mask), VLOpFrag)),
864+
(!cast<Instruction>("PseudoVWMACCU_VX_" # vti.LMul.MX)
865+
wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
866+
GPR:$vl, vti.Log2SEW)>;
867+
}
868+
787869
// 12.15. Vector Integer Merge Instructions
788870
foreach vti = AllIntegerVectors in {
789871
def : Pat<(vti.Vector (riscv_vselect_vl (vti.Mask VMV0:$vm),

0 commit comments

Comments
 (0)