Commit 6f17613

[RISCV][VP] Lower VP ISD nodes to RVV instructions
This patch supports all of the current set of VP integer binary intrinsics by lowering them to RVV instructions. It does so by using the existing RISCVISD *_VL custom nodes as an intermediate layer. Both scalable and fixed-length vectors are supported using this method.

One notable change to the existing vector codegen strategy is that scalable all-ones and all-zeros mask SPLAT_VECTORs are now lowered to RISCVISD VMSET_VL and VMCLR_VL nodes to match their fixed-length BUILD_VECTOR counterparts. This allows them to reuse the existing "all-ones" VL patterns.

To reduce the size of the Phabricator diff, some tests are intentionally left out; they will be added later if the patch is accepted.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D101826
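As an illustration of what this enables (a hedged sketch in the spirit of the tests this patch references, not an excerpt from them; function and register names are illustrative), a masked llvm.vp.add on a scalable type should now select to a single masked vadd.vv whose VL comes from the EVL argument:

declare <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32>, <vscale x 2 x i32>, <vscale x 2 x i1>, i32)

define <vscale x 2 x i32> @vadd_vv_nxv2i32(<vscale x 2 x i32> %va, <vscale x 2 x i32> %vb, <vscale x 2 x i1> %m, i32 %evl) {
  %v = call <vscale x 2 x i32> @llvm.vp.add.nxv2i32(<vscale x 2 x i32> %va, <vscale x 2 x i32> %vb, <vscale x 2 x i1> %m, i32 %evl)
  ret <vscale x 2 x i32> %v
}

; Roughly expected RV64 codegen (assumed, not verified against this commit):
;   vsetvli zero, a0, e32, m1, ta, mu
;   vadd.vv v8, v8, v9, v0.t
;   ret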
1 parent 2865d11 commit 6f17613

13 files changed (+10872 −74 lines)

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 99 additions & 4 deletions
@@ -411,6 +411,11 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
 
   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
 
+  static unsigned IntegerVPOps[] = {
+      ISD::VP_ADD, ISD::VP_SUB, ISD::VP_MUL, ISD::VP_SDIV, ISD::VP_UDIV,
+      ISD::VP_SREM, ISD::VP_UREM, ISD::VP_AND, ISD::VP_OR, ISD::VP_XOR,
+      ISD::VP_ASHR, ISD::VP_LSHR, ISD::VP_SHL};
+
   if (!Subtarget.is64Bit()) {
     // We must custom-lower certain vXi64 operations on RV32 due to the vector
     // element type being illegal.
@@ -496,6 +501,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
       setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
 
+      for (unsigned VPOpc : IntegerVPOps) {
+        setOperationAction(VPOpc, VT, Custom);
+        // RV64 must custom-legalize the i32 EVL parameter.
+        if (Subtarget.is64Bit())
+          setOperationAction(VPOpc, MVT::i32, Custom);
+      }
+
       setOperationAction(ISD::MLOAD, VT, Custom);
       setOperationAction(ISD::MSTORE, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
@@ -695,6 +707,13 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
         setOperationAction(ISD::VECREDUCE_SMIN, VT, Custom);
         setOperationAction(ISD::VECREDUCE_UMAX, VT, Custom);
         setOperationAction(ISD::VECREDUCE_UMIN, VT, Custom);
+
+        for (unsigned VPOpc : IntegerVPOps) {
+          setOperationAction(VPOpc, VT, Custom);
+          // RV64 must custom-legalize the i32 EVL parameter.
+          if (Subtarget.is64Bit())
+            setOperationAction(VPOpc, MVT::i32, Custom);
+        }
       }
 
       for (MVT VT : MVT::fp_fixedlen_vector_valuetypes()) {
@@ -2367,6 +2386,32 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return lowerGET_ROUNDING(Op, DAG);
   case ISD::SET_ROUNDING:
     return lowerSET_ROUNDING(Op, DAG);
+  case ISD::VP_ADD:
+    return lowerVPOp(Op, DAG, RISCVISD::ADD_VL);
+  case ISD::VP_SUB:
+    return lowerVPOp(Op, DAG, RISCVISD::SUB_VL);
+  case ISD::VP_MUL:
+    return lowerVPOp(Op, DAG, RISCVISD::MUL_VL);
+  case ISD::VP_SDIV:
+    return lowerVPOp(Op, DAG, RISCVISD::SDIV_VL);
+  case ISD::VP_UDIV:
+    return lowerVPOp(Op, DAG, RISCVISD::UDIV_VL);
+  case ISD::VP_SREM:
+    return lowerVPOp(Op, DAG, RISCVISD::SREM_VL);
+  case ISD::VP_UREM:
+    return lowerVPOp(Op, DAG, RISCVISD::UREM_VL);
+  case ISD::VP_AND:
+    return lowerVPOp(Op, DAG, RISCVISD::AND_VL);
+  case ISD::VP_OR:
+    return lowerVPOp(Op, DAG, RISCVISD::OR_VL);
+  case ISD::VP_XOR:
+    return lowerVPOp(Op, DAG, RISCVISD::XOR_VL);
+  case ISD::VP_ASHR:
+    return lowerVPOp(Op, DAG, RISCVISD::SRA_VL);
+  case ISD::VP_LSHR:
+    return lowerVPOp(Op, DAG, RISCVISD::SRL_VL);
+  case ISD::VP_SHL:
+    return lowerVPOp(Op, DAG, RISCVISD::SHL_VL);
   }
 }
 
@@ -2828,12 +2873,18 @@ SDValue RISCVTargetLowering::lowerShiftRightParts(SDValue Op, SelectionDAG &DAG,
 // legal equivalently-sized i8 type, so we can use that as a go-between.
 SDValue RISCVTargetLowering::lowerVectorMaskSplat(SDValue Op,
                                                   SelectionDAG &DAG) const {
-  SDValue SplatVal = Op.getOperand(0);
-  // All-zeros or all-ones splats are handled specially.
-  if (isa<ConstantSDNode>(SplatVal))
-    return Op;
   SDLoc DL(Op);
   MVT VT = Op.getSimpleValueType();
+  SDValue SplatVal = Op.getOperand(0);
+  // All-zeros or all-ones splats are handled specially.
+  if (ISD::isConstantSplatVectorAllOnes(Op.getNode())) {
+    SDValue VL = getDefaultScalableVLOps(VT, DL, DAG, Subtarget).second;
+    return DAG.getNode(RISCVISD::VMSET_VL, DL, VT, VL);
+  }
+  if (ISD::isConstantSplatVectorAllZeros(Op.getNode())) {
+    SDValue VL = getDefaultScalableVLOps(VT, DL, DAG, Subtarget).second;
+    return DAG.getNode(RISCVISD::VMCLR_VL, DL, VT, VL);
+  }
   MVT XLenVT = Subtarget.getXLenVT();
   assert(SplatVal.getValueType() == XLenVT &&
          "Unexpected type for i1 splat value");
@@ -4215,6 +4266,50 @@ SDValue RISCVTargetLowering::lowerToScalableOp(SDValue Op, SelectionDAG &DAG,
   return convertFromScalableVector(VT, ScalableRes, DAG, Subtarget);
 }
 
+// Lower a VP_* ISD node to the corresponding RISCVISD::*_VL node:
+// * Operands of each node are assumed to be in the same order.
+// * The EVL operand is promoted from i32 to i64 on RV64.
+// * Fixed-length vectors are converted to their scalable-vector container
+//   types.
+SDValue RISCVTargetLowering::lowerVPOp(SDValue Op, SelectionDAG &DAG,
+                                       unsigned RISCVISDOpc) const {
+  SDLoc DL(Op);
+  MVT VT = Op.getSimpleValueType();
+  Optional<unsigned> EVLIdx = ISD::getVPExplicitVectorLengthIdx(Op.getOpcode());
+
+  SmallVector<SDValue, 4> Ops;
+  MVT XLenVT = Subtarget.getXLenVT();
+
+  for (const auto &OpIdx : enumerate(Op->ops())) {
+    SDValue V = OpIdx.value();
+    if ((unsigned)OpIdx.index() == EVLIdx) {
+      Ops.push_back(DAG.getZExtOrTrunc(V, DL, XLenVT));
+      continue;
+    }
+    assert(!isa<VTSDNode>(V) && "Unexpected VTSDNode node!");
+    // Pass through operands which aren't fixed-length vectors.
+    if (!V.getValueType().isFixedLengthVector()) {
+      Ops.push_back(V);
+      continue;
+    }
+    // "cast" fixed length vector to a scalable vector.
+    MVT OpVT = V.getSimpleValueType();
+    MVT ContainerVT = getContainerForFixedLengthVector(OpVT);
+    assert(useRVVForFixedLengthVectorVT(OpVT) &&
+           "Only fixed length vectors are supported!");
+    Ops.push_back(convertToScalableVector(ContainerVT, V, DAG, Subtarget));
+  }
+
+  if (!VT.isFixedLengthVector())
+    return DAG.getNode(RISCVISDOpc, DL, VT, Ops);
+
+  MVT ContainerVT = getContainerForFixedLengthVector(VT);
+
+  SDValue VPOp = DAG.getNode(RISCVISDOpc, DL, ContainerVT, Ops);
+
+  return convertFromScalableVector(VT, VPOp, DAG, Subtarget);
+}
+
 // Custom lower MGATHER to a legalized form for RVV. It will then be matched to
 // a RVV indexed load. The RVV indexed load instructions only support the
 // "unsigned unscaled" addressing mode; indices are implicitly zero-extended or
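The bullets in lowerVPOp's comment can be seen in a fixed-length example (again a hedged sketch, not a test from this commit; names are illustrative): each <4 x i32> operand is "cast" into its scalable container type via convertToScalableVector, and on RV64 the i32 %evl operand is zero-extended to i64 by the getZExtOrTrunc call before it becomes the VL operand of RISCVISD::ADD_VL.

declare <4 x i32> @llvm.vp.add.v4i32(<4 x i32>, <4 x i32>, <4 x i1>, i32)

define <4 x i32> @vadd_vv_v4i32(<4 x i32> %va, <4 x i32> %vb, <4 x i1> %m, i32 %evl) {
  %v = call <4 x i32> @llvm.vp.add.v4i32(<4 x i32> %va, <4 x i32> %vb, <4 x i1> %m, i32 %evl)
  ret <4 x i32> %v
}

; The add is computed in a scalable container type (e.g. nxv2i32 under the
; default container assumptions) and converted back to <4 x i32> afterwards.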

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 1 addition & 0 deletions
@@ -543,6 +543,7 @@ class RISCVTargetLowering : public TargetLowering {
                                SelectionDAG &DAG) const;
   SDValue lowerToScalableOp(SDValue Op, SelectionDAG &DAG, unsigned NewOpc,
                             bool HasMask = true) const;
+  SDValue lowerVPOp(SDValue Op, SelectionDAG &DAG, unsigned RISCVISDOpc) const;
   SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
                                             unsigned ExtendOpc) const;
   SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Lines changed: 15 additions & 14 deletions
@@ -28,6 +28,14 @@ def SDTSplatI64 : SDTypeProfile<1, 1, [
 
 def rv32_splat_i64 : SDNode<"RISCVISD::SPLAT_VECTOR_I64", SDTSplatI64>;
 
+def SDT_RISCVVMSETCLR_VL : SDTypeProfile<1, 1, [SDTCVecEltisVT<0, i1>,
+                                                SDTCisVT<1, XLenVT>]>;
+def riscv_vmclr_vl : SDNode<"RISCVISD::VMCLR_VL", SDT_RISCVVMSETCLR_VL>;
+def riscv_vmset_vl : SDNode<"RISCVISD::VMSET_VL", SDT_RISCVVMSETCLR_VL>;
+
+def rvv_vnot : PatFrag<(ops node:$in),
+                       (xor node:$in, (riscv_vmset_vl (XLenVT srcvalue)))>;
+
 // Give explicit Complexity to prefer simm5/uimm5.
 def SplatPat : ComplexPattern<vAny, 1, "selectVSplat", [splat_vector, rv32_splat_i64], [], 1>;
 def SplatPat_simm5 : ComplexPattern<vAny, 1, "selectVSplatSimm5", [splat_vector, rv32_splat_i64], [], 2>;
@@ -503,25 +511,25 @@ foreach mti = AllMasks in {
             (!cast<Instruction>("PseudoVMXOR_MM_"#mti.LMul.MX)
                  VR:$rs1, VR:$rs2, mti.AVL, mti.Log2SEW)>;
 
-  def : Pat<(mti.Mask (vnot (and VR:$rs1, VR:$rs2))),
+  def : Pat<(mti.Mask (rvv_vnot (and VR:$rs1, VR:$rs2))),
             (!cast<Instruction>("PseudoVMNAND_MM_"#mti.LMul.MX)
                  VR:$rs1, VR:$rs2, mti.AVL, mti.Log2SEW)>;
-  def : Pat<(mti.Mask (vnot (or VR:$rs1, VR:$rs2))),
+  def : Pat<(mti.Mask (rvv_vnot (or VR:$rs1, VR:$rs2))),
             (!cast<Instruction>("PseudoVMNOR_MM_"#mti.LMul.MX)
                  VR:$rs1, VR:$rs2, mti.AVL, mti.Log2SEW)>;
-  def : Pat<(mti.Mask (vnot (xor VR:$rs1, VR:$rs2))),
+  def : Pat<(mti.Mask (rvv_vnot (xor VR:$rs1, VR:$rs2))),
             (!cast<Instruction>("PseudoVMXNOR_MM_"#mti.LMul.MX)
                  VR:$rs1, VR:$rs2, mti.AVL, mti.Log2SEW)>;
 
-  def : Pat<(mti.Mask (and VR:$rs1, (vnot VR:$rs2))),
+  def : Pat<(mti.Mask (and VR:$rs1, (rvv_vnot VR:$rs2))),
             (!cast<Instruction>("PseudoVMANDNOT_MM_"#mti.LMul.MX)
                  VR:$rs1, VR:$rs2, mti.AVL, mti.Log2SEW)>;
-  def : Pat<(mti.Mask (or VR:$rs1, (vnot VR:$rs2))),
+  def : Pat<(mti.Mask (or VR:$rs1, (rvv_vnot VR:$rs2))),
             (!cast<Instruction>("PseudoVMORNOT_MM_"#mti.LMul.MX)
                  VR:$rs1, VR:$rs2, mti.AVL, mti.Log2SEW)>;
 
-  // Handle vnot the same as the vnot.mm pseudoinstruction.
-  def : Pat<(mti.Mask (vnot VR:$rs)),
+  // Handle rvv_vnot the same as the vnot.mm pseudoinstruction.
+  def : Pat<(mti.Mask (rvv_vnot VR:$rs)),
             (!cast<Instruction>("PseudoVMNAND_MM_"#mti.LMul.MX)
                  VR:$rs, VR:$rs, mti.AVL, mti.Log2SEW)>;
 }
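The rvv_vnot PatFrag keeps these mask-negation patterns working now that an all-ones mask splat reaches instruction selection as VMSET_VL rather than a generic immAllOnesV. A hedged sketch of the kind of IR this covers (assumed codegen, not a test from this commit):

define <vscale x 2 x i1> @vmnot(<vscale x 2 x i1> %m) {
  %head = insertelement <vscale x 2 x i1> undef, i1 true, i32 0
  %ones = shufflevector <vscale x 2 x i1> %head, <vscale x 2 x i1> undef, <vscale x 2 x i32> zeroinitializer
  ; The xor below becomes (xor $m, (riscv_vmset_vl ...)) in the DAG, which
  ; matches rvv_vnot and should select to a single vmnand.mm v0, v0, v0.
  %not = xor <vscale x 2 x i1> %m, %ones
  ret <vscale x 2 x i1> %not
}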
@@ -725,13 +733,6 @@ foreach vti = AllIntegerVectors in {
             (!cast<Instruction>("PseudoVMV_V_I_" # vti.LMul.MX)
                  simm5:$rs1, vti.AVL, vti.Log2SEW)>;
 }
-
-foreach mti = AllMasks in {
-  def : Pat<(mti.Mask immAllOnesV),
-            (!cast<Instruction>("PseudoVMSET_M_"#mti.BX) mti.AVL, mti.Log2SEW)>;
-  def : Pat<(mti.Mask immAllZerosV),
-            (!cast<Instruction>("PseudoVMCLR_M_"#mti.BX) mti.AVL, mti.Log2SEW)>;
-}
 } // Predicates = [HasStdExtV]
 
 let Predicates = [HasStdExtV, HasStdExtF] in {
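The immAllOnesV/immAllZerosV patterns removed above are not lost: lowerVectorMaskSplat now emits VMSET_VL/VMCLR_VL directly, and those nodes presumably select to the same vmset.m/vmclr.m pseudos in one of the changed files not reproduced here. A hedged sketch of an all-ones scalable mask splat (illustrative function, assumed codegen):

define <vscale x 8 x i1> @mask_splat_ones() {
  %head = insertelement <vscale x 8 x i1> undef, i1 true, i32 0
  %ones = shufflevector <vscale x 8 x i1> %head, <vscale x 8 x i1> undef, <vscale x 8 x i32> zeroinitializer
  ret <vscale x 8 x i1> %ones
}

; Expected to lower through RISCVISD::VMSET_VL and select to vmset.m under
; a VLMAX vsetvli (assumed; not verified against this commit's tests).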
