Skip to content

Commit 15141cd

Browse files
committed
[RISCV] Add RVV insertelt/extractelt scalable-vector patterns
Original patch by @rogfer01. This patch adds support for insertelt and extractelt operations on scalable vectors. Special care must be taken on RV32 when dealing with i64 vectors as there are no straightforward ways to insert a 64-bit element without a register of that size. To that end, both are custom-lowered to different sequences. Authored-by: Roger Ferrer Ibanez <[email protected]> Co-Authored-by: Fraser Cormack <[email protected]> Reviewed By: craig.topper Differential Revision: https://reviews.llvm.org/D94615
1 parent 1ac36b3 commit 15141cd

11 files changed

+5597
-19
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 149 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -403,12 +403,20 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
403403
// 2. Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
404404
// nodes which truncate by one power of two at a time.
405405
setOperationAction(ISD::TRUNCATE, VT, Custom);
406+
407+
// Custom-lower insert/extract operations to simplify patterns.
408+
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
409+
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
406410
}
407411
}
408412

409-
// We must custom-lower SPLAT_VECTOR vXi64 on RV32
410-
if (!Subtarget.is64Bit())
413+
// We must custom-lower certain vXi64 operations on RV32 due to the vector
414+
// element type being illegal.
415+
if (!Subtarget.is64Bit()) {
411416
setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom);
417+
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom);
418+
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom);
419+
}
412420

413421
// Expand various CCs to best match the RVV ISA, which natively supports UNE
414422
// but no other unordered comparisons, and supports all ordered comparisons
@@ -423,33 +431,34 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
423431
ISD::SETGT, ISD::SETOGT, ISD::SETGE, ISD::SETOGE,
424432
};
425433

434+
// Sets common operation actions on RVV floating-point vector types.
435+
const auto SetCommonVFPActions = [&](MVT VT) {
436+
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
437+
// Custom-lower insert/extract operations to simplify patterns.
438+
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
439+
setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
440+
for (auto CC : VFPCCToExpand)
441+
setCondCodeAction(CC, VT, Expand);
442+
};
443+
426444
if (Subtarget.hasStdExtZfh()) {
427445
for (auto VT : {RISCVVMVTs::vfloat16mf4_t, RISCVVMVTs::vfloat16mf2_t,
428446
RISCVVMVTs::vfloat16m1_t, RISCVVMVTs::vfloat16m2_t,
429-
RISCVVMVTs::vfloat16m4_t, RISCVVMVTs::vfloat16m8_t}) {
430-
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
431-
for (auto CC : VFPCCToExpand)
432-
setCondCodeAction(CC, VT, Expand);
433-
}
447+
RISCVVMVTs::vfloat16m4_t, RISCVVMVTs::vfloat16m8_t})
448+
SetCommonVFPActions(VT);
434449
}
435450

436451
if (Subtarget.hasStdExtF()) {
437452
for (auto VT : {RISCVVMVTs::vfloat32mf2_t, RISCVVMVTs::vfloat32m1_t,
438453
RISCVVMVTs::vfloat32m2_t, RISCVVMVTs::vfloat32m4_t,
439-
RISCVVMVTs::vfloat32m8_t}) {
440-
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
441-
for (auto CC : VFPCCToExpand)
442-
setCondCodeAction(CC, VT, Expand);
443-
}
454+
RISCVVMVTs::vfloat32m8_t})
455+
SetCommonVFPActions(VT);
444456
}
445457

446458
if (Subtarget.hasStdExtD()) {
447459
for (auto VT : {RISCVVMVTs::vfloat64m1_t, RISCVVMVTs::vfloat64m2_t,
448-
RISCVVMVTs::vfloat64m4_t, RISCVVMVTs::vfloat64m8_t}) {
449-
setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
450-
for (auto CC : VFPCCToExpand)
451-
setCondCodeAction(CC, VT, Expand);
452-
}
460+
RISCVVMVTs::vfloat64m4_t, RISCVVMVTs::vfloat64m8_t})
461+
SetCommonVFPActions(VT);
453462
}
454463
}
455464

@@ -761,6 +770,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
761770
return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
762771
case ISD::SPLAT_VECTOR:
763772
return lowerSPLATVECTOR(Op, DAG);
773+
case ISD::INSERT_VECTOR_ELT:
774+
return lowerINSERT_VECTOR_ELT(Op, DAG);
775+
case ISD::EXTRACT_VECTOR_ELT:
776+
return lowerEXTRACT_VECTOR_ELT(Op, DAG);
764777
case ISD::VSCALE: {
765778
MVT VT = Op.getSimpleValueType();
766779
SDLoc DL(Op);
@@ -1209,6 +1222,12 @@ SDValue RISCVTargetLowering::lowerSPLATVECTOR(SDValue Op,
12091222
DAG.getConstant(CVal->getSExtValue(), DL, MVT::i32));
12101223
}
12111224

1225+
if (SplatVal.getOpcode() == ISD::SIGN_EXTEND &&
1226+
SplatVal.getOperand(0).getValueType() == MVT::i32) {
1227+
return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT,
1228+
SplatVal.getOperand(0));
1229+
}
1230+
12121231
// Else, on RV32 we lower an i64-element SPLAT_VECTOR thus, being careful not
12131232
// to accidentally sign-extend the 32-bit halves to the e64 SEW:
12141233
// vmv.v.x vX, hi
@@ -1306,6 +1325,72 @@ SDValue RISCVTargetLowering::lowerVectorMaskTrunc(SDValue Op,
13061325
return DAG.getSetCC(DL, MaskVT, Trunc, SplatZero, ISD::SETNE);
13071326
}
13081327

1328+
SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
1329+
SelectionDAG &DAG) const {
1330+
SDLoc DL(Op);
1331+
EVT VecVT = Op.getValueType();
1332+
SDValue Vec = Op.getOperand(0);
1333+
SDValue Val = Op.getOperand(1);
1334+
SDValue Idx = Op.getOperand(2);
1335+
1336+
// Custom-legalize INSERT_VECTOR_ELT where XLEN>=SEW, so that the vector is
1337+
// first slid down into position, the value is inserted into the first
1338+
// position, and the vector is slid back up. We do this to simplify patterns.
1339+
// (slideup vec, (insertelt (slidedown impdef, vec, idx), val, 0), idx),
1340+
if (Subtarget.is64Bit() || VecVT.getVectorElementType() != MVT::i64) {
1341+
if (isNullConstant(Idx))
1342+
return Op;
1343+
SDValue Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN, DL, VecVT,
1344+
DAG.getUNDEF(VecVT), Vec, Idx);
1345+
SDValue InsertElt0 =
1346+
DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, Slidedown, Val,
1347+
DAG.getConstant(0, DL, Subtarget.getXLenVT()));
1348+
1349+
return DAG.getNode(RISCVISD::VSLIDEUP, DL, VecVT, Vec, InsertElt0, Idx);
1350+
}
1351+
1352+
// Custom-legalize INSERT_VECTOR_ELT where XLEN<SEW, as the SEW element type
1353+
// is illegal (currently only vXi64 RV32).
1354+
// Since there is no easy way of getting a single element into a vector when
1355+
// XLEN<SEW, we lower the operation to the following sequence:
1356+
// splat vVal, rVal
1357+
// vid.v vVid
1358+
// vmseq.vx mMask, vVid, rIdx
1359+
// vmerge.vvm vDest, vSrc, vVal, mMask
1360+
// This essentially merges the original vector with the inserted element by
1361+
// using a mask whose only set bit is that corresponding to the insert
1362+
// index.
1363+
SDValue SplattedVal = DAG.getSplatVector(VecVT, DL, Val);
1364+
SDValue SplattedIdx = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Idx);
1365+
1366+
SDValue VID = DAG.getNode(RISCVISD::VID, DL, VecVT);
1367+
auto SetCCVT =
1368+
getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VecVT);
1369+
SDValue Mask = DAG.getSetCC(DL, SetCCVT, VID, SplattedIdx, ISD::SETEQ);
1370+
1371+
return DAG.getNode(ISD::VSELECT, DL, VecVT, Mask, SplattedVal, Vec);
1372+
}
1373+
1374+
// Custom-lower EXTRACT_VECTOR_ELT operations to slide the vector down, then
1375+
// extract the first element: (extractelt (slidedown vec, idx), 0). This is
1376+
done to maintain parity with the RV32 vXi64 legalization.
1377+
SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
1378+
SelectionDAG &DAG) const {
1379+
SDLoc DL(Op);
1380+
SDValue Idx = Op.getOperand(1);
1381+
if (isNullConstant(Idx))
1382+
return Op;
1383+
1384+
SDValue Vec = Op.getOperand(0);
1385+
EVT EltVT = Op.getValueType();
1386+
EVT VecVT = Vec.getValueType();
1387+
SDValue Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN, DL, VecVT,
1388+
DAG.getUNDEF(VecVT), Vec, Idx);
1389+
1390+
return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Slidedown,
1391+
DAG.getConstant(0, DL, Subtarget.getXLenVT()));
1392+
}
1393+
13091394
SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
13101395
SelectionDAG &DAG) const {
13111396
unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
@@ -1640,6 +1725,44 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
16401725
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewOp));
16411726
break;
16421727
}
1728+
case ISD::EXTRACT_VECTOR_ELT: {
1729+
// Custom-legalize an EXTRACT_VECTOR_ELT where XLEN<SEW, as the SEW element
1730+
// type is illegal (currently only vXi64 RV32).
1731+
// With vmv.x.s, when SEW > XLEN, only the least-significant XLEN bits are
1732+
// transferred to the destination register. We issue two of these from the
1733+
upper and lower halves of the SEW-bit vector element, slid down to the
1734+
// first element.
1735+
SDLoc DL(N);
1736+
SDValue Vec = N->getOperand(0);
1737+
SDValue Idx = N->getOperand(1);
1738+
EVT VecVT = Vec.getValueType();
1739+
assert(!Subtarget.is64Bit() && N->getValueType(0) == MVT::i64 &&
1740+
VecVT.getVectorElementType() == MVT::i64 &&
1741+
"Unexpected EXTRACT_VECTOR_ELT legalization");
1742+
1743+
SDValue Slidedown = Vec;
1744+
// Unless the index is known to be 0, we must slide the vector down to get
1745+
// the desired element into index 0.
1746+
if (!isNullConstant(Idx))
1747+
Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN, DL, VecVT,
1748+
DAG.getUNDEF(VecVT), Vec, Idx);
1749+
1750+
MVT XLenVT = Subtarget.getXLenVT();
1751+
// Extract the lower XLEN bits of the correct vector element.
1752+
SDValue EltLo = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Slidedown, Idx);
1753+
1754+
// To extract the upper XLEN bits of the vector element, shift the first
1755+
// element right by 32 bits and re-extract the lower XLEN bits.
1756+
SDValue ThirtyTwoV =
1757+
DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT,
1758+
DAG.getConstant(32, DL, Subtarget.getXLenVT()));
1759+
SDValue LShr32 = DAG.getNode(ISD::SRL, DL, VecVT, Slidedown, ThirtyTwoV);
1760+
1761+
SDValue EltHi = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, LShr32, Idx);
1762+
1763+
Results.push_back(DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, EltLo, EltHi));
1764+
break;
1765+
}
16431766
case ISD::INTRINSIC_WO_CHAIN: {
16441767
unsigned IntNo = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
16451768
switch (IntNo) {
@@ -2231,8 +2354,12 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
22312354
return 33;
22322355
case RISCVISD::VMV_X_S:
22332356
// The number of sign bits of the scalar result is computed by obtaining the
2234-
// element type of the input vector operand, substracting its width from the
2235-
// XLEN, and then adding one (sign bit within the element type).
2357+
// element type of the input vector operand, subtracting its width from the
2358+
// XLEN, and then adding one (sign bit within the element type). If the
2359+
// element type is wider than XLen, the least-significant XLEN bits are
2360+
// taken.
2361+
if (Op.getOperand(0).getScalarValueSizeInBits() > Subtarget.getXLen())
2362+
return 1;
22362363
return Subtarget.getXLen() - Op.getOperand(0).getScalarValueSizeInBits() + 1;
22372364
}
22382365

@@ -3893,6 +4020,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
38934020
NODE_NAME_CASE(VLEFF)
38944021
NODE_NAME_CASE(VLEFF_MASK)
38954022
NODE_NAME_CASE(READ_VL)
4023+
NODE_NAME_CASE(VSLIDEUP)
4024+
NODE_NAME_CASE(VSLIDEDOWN)
4025+
NODE_NAME_CASE(VID)
38964026
}
38974027
// clang-format on
38984028
return nullptr;

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,13 @@ enum NodeType : unsigned {
101101
VLEFF_MASK,
102102
// read vl CSR
103103
READ_VL,
104+
// Matches the semantics of vslideup/vslidedown. The first operand is the
105+
// pass-thru operand, the second is the source vector, and the third is the
106+
// XLenVT index (either constant or non-constant).
107+
VSLIDEUP,
108+
VSLIDEDOWN,
109+
// Matches the semantics of the unmasked vid.v instruction.
110+
VID,
104111
};
105112
} // namespace RISCVISD
106113

@@ -298,6 +305,8 @@ class RISCVTargetLowering : public TargetLowering {
298305
SDValue lowerVectorMaskExt(SDValue Op, SelectionDAG &DAG,
299306
int64_t ExtTrueVal) const;
300307
SDValue lowerVectorMaskTrunc(SDValue Op, SelectionDAG &DAG) const;
308+
SDValue lowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
309+
SDValue lowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
301310
SDValue LowerINTRINSIC_WO_CHAIN(SDValue Op, SelectionDAG &DAG) const;
302311
SDValue LowerINTRINSIC_W_CHAIN(SDValue Op, SelectionDAG &DAG) const;
303312

llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ def riscv_trunc_vector : SDNode<"RISCVISD::TRUNCATE_VECTOR",
3232
SDTypeProfile<1, 1,
3333
[SDTCisVec<0>, SDTCisVec<1>]>>;
3434

35+
class FromFPR32<DAGOperand operand, dag input_dag> {
36+
dag ret = !cond(!eq(!cast<string>(operand), !cast<string>(FPR64)):
37+
(INSERT_SUBREG (IMPLICIT_DEF), input_dag, sub_32),
38+
!eq(!cast<string>(operand), !cast<string>(FPR16)):
39+
(EXTRACT_SUBREG input_dag, sub_16),
40+
!eq(1, 1):
41+
input_dag);
42+
}
43+
3544
// Penalize the generic form with Complexity=1 to give the simm5/uimm5 variants
3645
// precedence
3746
def SplatPat : ComplexPattern<vAny, 1, "selectVSplat", [], [], 1>;
@@ -538,3 +547,101 @@ foreach fvti = AllFloatVectors in {
538547
0, fvti.AVL, fvti.SEW)>;
539548
}
540549
} // Predicates = [HasStdExtV, HasStdExtF]
550+
551+
//===----------------------------------------------------------------------===//
552+
// Vector Element Inserts/Extracts
553+
//===----------------------------------------------------------------------===//
554+
555+
// The built-in TableGen 'extractelt' and 'insertelt' nodes must return the
556+
// same type as the vector element type. On RISC-V, XLenVT is the only legal
557+
// integer type, so for integer inserts/extracts we use a custom node which
558+
// returns XLenVT.
559+
def riscv_insert_vector_elt
560+
: SDNode<"ISD::INSERT_VECTOR_ELT",
561+
SDTypeProfile<1, 3, [SDTCisSameAs<0, 1>, SDTCisVT<2, XLenVT>,
562+
SDTCisPtrTy<3>]>, []>;
563+
def riscv_extract_vector_elt
564+
: SDNode<"ISD::EXTRACT_VECTOR_ELT",
565+
SDTypeProfile<1, 2, [SDTCisVT<0, XLenVT>, SDTCisPtrTy<2>]>, []>;
566+
567+
multiclass VPatInsertExtractElt_XI_Idx<bit IsFloat> {
568+
defvar vtilist = !if(IsFloat, AllFloatVectors, AllIntegerVectors);
569+
defvar insertelt_node = !if(IsFloat, insertelt, riscv_insert_vector_elt);
570+
defvar extractelt_node = !if(IsFloat, extractelt, riscv_extract_vector_elt);
571+
foreach vti = vtilist in {
572+
defvar MX = vti.LMul.MX;
573+
defvar vmv_xf_s_inst = !cast<Instruction>(!if(IsFloat, "PseudoVFMV_F_S_",
574+
"PseudoVMV_X_S_")#MX);
575+
defvar vmv_s_xf_inst = !cast<Instruction>(!if(IsFloat, "PseudoVFMV_S_F_",
576+
"PseudoVMV_S_X_")#MX);
577+
// Only pattern-match insert/extract-element operations where the index is
578+
// 0. Any other index will have been custom-lowered to slide the vector
579+
// correctly into place (and, in the case of insert, slide it back again
580+
// afterwards).
581+
def : Pat<(vti.Scalar (extractelt_node (vti.Vector vti.RegClass:$rs2), 0)),
582+
FromFPR32<vti.ScalarRegClass,
583+
(vmv_xf_s_inst vti.RegClass:$rs2, vti.SEW)>.ret>;
584+
585+
def : Pat<(vti.Vector (insertelt_node (vti.Vector vti.RegClass:$merge),
586+
vti.ScalarRegClass:$rs1, 0)),
587+
(vmv_s_xf_inst vti.RegClass:$merge,
588+
ToFPR32<vti.Scalar, vti.ScalarRegClass, "rs1">.ret,
589+
vti.AVL, vti.SEW)>;
590+
}
591+
}
592+
593+
let Predicates = [HasStdExtV] in
594+
defm "" : VPatInsertExtractElt_XI_Idx</*IsFloat*/0>;
595+
let Predicates = [HasStdExtV, HasStdExtF] in
596+
defm "" : VPatInsertExtractElt_XI_Idx</*IsFloat*/1>;
597+
598+
//===----------------------------------------------------------------------===//
599+
// Miscellaneous RISCVISD SDNodes
600+
//===----------------------------------------------------------------------===//
601+
602+
def riscv_vid
603+
: SDNode<"RISCVISD::VID", SDTypeProfile<1, 0, [SDTCisVec<0>]>, []>;
604+
605+
def SDTRVVSlide : SDTypeProfile<1, 3, [
606+
SDTCisVec<0>, SDTCisSameAs<1, 0>, SDTCisSameAs<2, 0>, SDTCisVT<3, XLenVT>
607+
]>;
608+
609+
def riscv_slideup : SDNode<"RISCVISD::VSLIDEUP", SDTRVVSlide, []>;
610+
def riscv_slidedown : SDNode<"RISCVISD::VSLIDEDOWN", SDTRVVSlide, []>;
611+
612+
let Predicates = [HasStdExtV] in {
613+
614+
foreach vti = AllIntegerVectors in
615+
def : Pat<(vti.Vector riscv_vid),
616+
(!cast<Instruction>("PseudoVID_V_"#vti.LMul.MX) vti.AVL, vti.SEW)>;
617+
618+
foreach vti = !listconcat(AllIntegerVectors, AllFloatVectors) in {
619+
def : Pat<(vti.Vector (riscv_slideup (vti.Vector vti.RegClass:$rs3),
620+
(vti.Vector vti.RegClass:$rs1),
621+
uimm5:$rs2)),
622+
(!cast<Instruction>("PseudoVSLIDEUP_VI_"#vti.LMul.MX)
623+
vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2,
624+
vti.AVL, vti.SEW)>;
625+
626+
def : Pat<(vti.Vector (riscv_slideup (vti.Vector vti.RegClass:$rs3),
627+
(vti.Vector vti.RegClass:$rs1),
628+
GPR:$rs2)),
629+
(!cast<Instruction>("PseudoVSLIDEUP_VX_"#vti.LMul.MX)
630+
vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2,
631+
vti.AVL, vti.SEW)>;
632+
633+
def : Pat<(vti.Vector (riscv_slidedown (vti.Vector vti.RegClass:$rs3),
634+
(vti.Vector vti.RegClass:$rs1),
635+
uimm5:$rs2)),
636+
(!cast<Instruction>("PseudoVSLIDEDOWN_VI_"#vti.LMul.MX)
637+
vti.RegClass:$rs3, vti.RegClass:$rs1, uimm5:$rs2,
638+
vti.AVL, vti.SEW)>;
639+
640+
def : Pat<(vti.Vector (riscv_slidedown (vti.Vector vti.RegClass:$rs3),
641+
(vti.Vector vti.RegClass:$rs1),
642+
GPR:$rs2)),
643+
(!cast<Instruction>("PseudoVSLIDEDOWN_VX_"#vti.LMul.MX)
644+
vti.RegClass:$rs3, vti.RegClass:$rs1, GPR:$rs2,
645+
vti.AVL, vti.SEW)>;
646+
}
647+
} // Predicates = [HasStdExtV]

0 commit comments

Comments
 (0)