@@ -403,12 +403,20 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       // 2. Integer VTs are lowered as a series of "RISCVISD::TRUNCATE_VECTOR"
       // nodes which truncate by one power of two at a time.
       setOperationAction(ISD::TRUNCATE, VT, Custom);
+
+      // Custom-lower insert/extract operations to simplify patterns.
+      setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
+      setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
     }
   }

-  // We must custom-lower SPLAT_VECTOR vXi64 on RV32
-  if (!Subtarget.is64Bit())
+  // We must custom-lower certain vXi64 operations on RV32 due to the vector
+  // element type being illegal.
+  if (!Subtarget.is64Bit()) {
     setOperationAction(ISD::SPLAT_VECTOR, MVT::i64, Custom);
+    setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::i64, Custom);
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::i64, Custom);
+  }

   // Expand various CCs to best match the RVV ISA, which natively supports UNE
   // but no other unordered comparisons, and supports all ordered comparisons
@@ -423,33 +431,34 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
       ISD::SETGT, ISD::SETOGT, ISD::SETGE, ISD::SETOGE,
   };

+  // Sets common operation actions on RVV floating-point vector types.
+  const auto SetCommonVFPActions = [&](MVT VT) {
+    setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
+    // Custom-lower insert/extract operations to simplify patterns.
+    setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Custom);
+    setOperationAction(ISD::EXTRACT_VECTOR_ELT, VT, Custom);
+    for (auto CC : VFPCCToExpand)
+      setCondCodeAction(CC, VT, Expand);
+  };
+
   if (Subtarget.hasStdExtZfh()) {
     for (auto VT : {RISCVVMVTs::vfloat16mf4_t, RISCVVMVTs::vfloat16mf2_t,
                     RISCVVMVTs::vfloat16m1_t, RISCVVMVTs::vfloat16m2_t,
-                    RISCVVMVTs::vfloat16m4_t, RISCVVMVTs::vfloat16m8_t}) {
-      setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
-      for (auto CC : VFPCCToExpand)
-        setCondCodeAction(CC, VT, Expand);
-    }
+                    RISCVVMVTs::vfloat16m4_t, RISCVVMVTs::vfloat16m8_t})
+      SetCommonVFPActions(VT);
   }

   if (Subtarget.hasStdExtF()) {
     for (auto VT : {RISCVVMVTs::vfloat32mf2_t, RISCVVMVTs::vfloat32m1_t,
                     RISCVVMVTs::vfloat32m2_t, RISCVVMVTs::vfloat32m4_t,
-                    RISCVVMVTs::vfloat32m8_t}) {
-      setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
-      for (auto CC : VFPCCToExpand)
-        setCondCodeAction(CC, VT, Expand);
-    }
+                    RISCVVMVTs::vfloat32m8_t})
+      SetCommonVFPActions(VT);
   }

   if (Subtarget.hasStdExtD()) {
     for (auto VT : {RISCVVMVTs::vfloat64m1_t, RISCVVMVTs::vfloat64m2_t,
-                    RISCVVMVTs::vfloat64m4_t, RISCVVMVTs::vfloat64m8_t}) {
-      setOperationAction(ISD::SPLAT_VECTOR, VT, Legal);
-      for (auto CC : VFPCCToExpand)
-        setCondCodeAction(CC, VT, Expand);
-    }
+                    RISCVVMVTs::vfloat64m4_t, RISCVVMVTs::vfloat64m8_t})
+      SetCommonVFPActions(VT);
   }
 }
@@ -761,6 +770,10 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
     return lowerVectorMaskExt(Op, DAG, /*ExtVal*/ -1);
   case ISD::SPLAT_VECTOR:
     return lowerSPLATVECTOR(Op, DAG);
+  case ISD::INSERT_VECTOR_ELT:
+    return lowerINSERT_VECTOR_ELT(Op, DAG);
+  case ISD::EXTRACT_VECTOR_ELT:
+    return lowerEXTRACT_VECTOR_ELT(Op, DAG);
   case ISD::VSCALE: {
     MVT VT = Op.getSimpleValueType();
     SDLoc DL(Op);
@@ -1209,6 +1222,12 @@ SDValue RISCVTargetLowering::lowerSPLATVECTOR(SDValue Op,
                        DAG.getConstant(CVal->getSExtValue(), DL, MVT::i32));
   }

+  if (SplatVal.getOpcode() == ISD::SIGN_EXTEND &&
+      SplatVal.getOperand(0).getValueType() == MVT::i32) {
+    return DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT,
+                       SplatVal.getOperand(0));
+  }
+
   // Else, on RV32 we lower an i64-element SPLAT_VECTOR thus, being careful not
   // to accidentally sign-extend the 32-bit halves to the e64 SEW:
   // vmv.v.x vX, hi
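A minimal scalar model of the new early-out, with hypothetical names (not LLVM API): when the splatted i64 value is already a sign extension of an i32, a single vmv.v.x of the narrow operand suffices, because vmv.v.x sign-extends its XLEN-bit scalar to SEW=64 bits, which is exactly the hi/lo pair the longer sequence below would build.

    #include <cassert>
    #include <cstdint>

    // Per-element value produced by vmv.v.x on RV32 with SEW=64: the
    // 32-bit scalar is sign-extended to the element width.
    static int64_t vmvVXElement(int32_t X) { return static_cast<int64_t>(X); }

    int main() {
      for (int32_t X : {0, 1, -1, -123456}) {
        int64_t Splatted = static_cast<int64_t>(X); // (sext i32 X) per element
        assert(vmvVXElement(X) == Splatted);        // the single-vmv fold agrees
      }
    }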
@@ -1306,6 +1325,72 @@ SDValue RISCVTargetLowering::lowerVectorMaskTrunc(SDValue Op,
   return DAG.getSetCC(DL, MaskVT, Trunc, SplatZero, ISD::SETNE);
 }

+SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
+                                                    SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  EVT VecVT = Op.getValueType();
+  SDValue Vec = Op.getOperand(0);
+  SDValue Val = Op.getOperand(1);
+  SDValue Idx = Op.getOperand(2);
+
+  // Custom-legalize INSERT_VECTOR_ELT where XLEN>=SEW, so that the vector is
+  // first slid down into position, the value is inserted into the first
+  // position, and the vector is slid back up. We do this to simplify patterns.
+  //   (slideup vec, (insertelt (slidedown impdef, vec, idx), val, 0), idx)
+  if (Subtarget.is64Bit() || VecVT.getVectorElementType() != MVT::i64) {
+    if (isNullConstant(Idx))
+      return Op;
+    SDValue Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN, DL, VecVT,
+                                    DAG.getUNDEF(VecVT), Vec, Idx);
+    SDValue InsertElt0 =
+        DAG.getNode(ISD::INSERT_VECTOR_ELT, DL, VecVT, Slidedown, Val,
+                    DAG.getConstant(0, DL, Subtarget.getXLenVT()));
+
+    return DAG.getNode(RISCVISD::VSLIDEUP, DL, VecVT, Vec, InsertElt0, Idx);
+  }
+
+  // Custom-legalize INSERT_VECTOR_ELT where XLEN<SEW, as the SEW element type
+  // is illegal (currently only vXi64 RV32).
+  // Since there is no easy way of getting a single element into a vector when
+  // XLEN<SEW, we lower the operation to the following sequence:
+  //   splat      vVal, rVal
+  //   vid.v      vVid
+  //   vmseq.vx   mMask, vVid, rIdx
+  //   vmerge.vvm vDest, vSrc, vVal, mMask
+  // This essentially merges the original vector with the inserted element by
+  // using a mask whose only set bit is that corresponding to the insert
+  // index.
+  SDValue SplattedVal = DAG.getSplatVector(VecVT, DL, Val);
+  SDValue SplattedIdx = DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT, Idx);
+
+  SDValue VID = DAG.getNode(RISCVISD::VID, DL, VecVT);
+  auto SetCCVT =
+      getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VecVT);
+  SDValue Mask = DAG.getSetCC(DL, SetCCVT, VID, SplattedIdx, ISD::SETEQ);
+
+  return DAG.getNode(ISD::VSELECT, DL, VecVT, Mask, SplattedVal, Vec);
+}
+
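As a reading aid, here is a scalar model of the two insert strategies above; the container, element count, and function names are hypothetical stand-ins for the scalable vector, not LLVM API. It also shows why the slide composition preserves the elements above the insert index: after the slidedown, they already sit at the positions the slideup reads from.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    using Vec = std::vector<uint64_t>;

    // vslidedown.vx: element I of the result reads element I+Amt of Src.
    // Elements slid in from past the end are modeled as 0 here; the DAG
    // node above supplies an undef merge operand for them.
    static Vec slidedown(const Vec &Src, unsigned Amt) {
      Vec R(Src.size(), 0);
      for (unsigned I = 0; I + Amt < Src.size(); ++I)
        R[I] = Src[I + Amt];
      return R;
    }

    // vslideup.vx with Dest as merge operand: elements below Amt keep
    // their Dest values; element I >= Amt reads element I-Amt of Src.
    static Vec slideup(const Vec &Dest, const Vec &Src, unsigned Amt) {
      Vec R = Dest;
      for (unsigned I = Amt; I < R.size(); ++I)
        R[I] = Src[I - Amt];
      return R;
    }

    // XLEN < SEW path: vid.v/vmseq.vx build a one-hot mask at Idx and
    // vmerge.vvm takes the splatted value only at that position.
    static Vec insertViaMerge(const Vec &Src, uint64_t Val, unsigned Idx) {
      Vec R = Src;
      for (unsigned I = 0; I < R.size(); ++I)
        if (I == Idx) // (vid == splat(Idx)) holds for exactly one element
          R[I] = Val; // select from splat(Val) there, from Src elsewhere
      return R;
    }

    int main() {
      Vec V = {10, 11, 12, 13};
      // XLEN >= SEW path: slide down, insert at element 0, slide back up.
      Vec Down = slidedown(V, 2);
      Down[0] = 99; // insertelt at index 0
      Vec R = slideup(V, Down, 2);
      assert((R == Vec{10, 11, 99, 13}));
      // R[3] survives because Down[1] already held V[3] after the slidedown.
      assert((insertViaMerge(V, 99, 2) == Vec{10, 11, 99, 13}));
    }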
+// Custom-lower EXTRACT_VECTOR_ELT operations to slide the vector down, then
+// extract the first element: (extractelt (slidedown vec, idx), 0). This is
+// done to maintain parity with the RV32 vXi64 legalization.
+SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
+                                                     SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  SDValue Idx = Op.getOperand(1);
+  if (isNullConstant(Idx))
+    return Op;
+
+  SDValue Vec = Op.getOperand(0);
+  EVT EltVT = Op.getValueType();
+  EVT VecVT = Vec.getValueType();
+  SDValue Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN, DL, VecVT,
+                                  DAG.getUNDEF(VecVT), Vec, Idx);
+
+  return DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Slidedown,
+                     DAG.getConstant(0, DL, Subtarget.getXLenVT()));
+}
+
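And a matching scalar model of the extract path, under the same hypothetical slide semantics (not LLVM API): sliding down by the index moves the requested element to position 0, so a fixed extract of element 0 suffices for every index.

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // Scalar model of (extractelt (slidedown vec, idx), 0).
    static uint64_t extractViaSlidedown(const std::vector<uint64_t> &V,
                                        unsigned Idx) {
      std::vector<uint64_t> Down(V.size(), 0); // slid-in tail modeled as 0
      for (unsigned I = 0; I + Idx < V.size(); ++I)
        Down[I] = V[I + Idx];
      return Down[0];
    }

    int main() {
      assert(extractViaSlidedown({5, 6, 7, 8}, 2) == 7);
    }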
 SDValue RISCVTargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
                                                      SelectionDAG &DAG) const {
   unsigned IntNo = cast<ConstantSDNode>(Op.getOperand(0))->getZExtValue();
@@ -1640,6 +1725,44 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
     Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i32, NewOp));
     break;
   }
+  case ISD::EXTRACT_VECTOR_ELT: {
+    // Custom-legalize an EXTRACT_VECTOR_ELT where XLEN<SEW, as the SEW element
+    // type is illegal (currently only vXi64 RV32).
+    // With vmv.x.s, when SEW > XLEN, only the least-significant XLEN bits are
+    // transferred to the destination register. We issue two of these from the
+    // upper and lower halves of the SEW-bit vector element, slid down to the
+    // first element.
+    SDLoc DL(N);
+    SDValue Vec = N->getOperand(0);
+    SDValue Idx = N->getOperand(1);
+    EVT VecVT = Vec.getValueType();
+    assert(!Subtarget.is64Bit() && N->getValueType(0) == MVT::i64 &&
+           VecVT.getVectorElementType() == MVT::i64 &&
+           "Unexpected EXTRACT_VECTOR_ELT legalization");
+
+    SDValue Slidedown = Vec;
+    // Unless the index is known to be 0, we must slide the vector down to get
+    // the desired element into index 0.
+    if (!isNullConstant(Idx))
+      Slidedown = DAG.getNode(RISCVISD::VSLIDEDOWN, DL, VecVT,
+                              DAG.getUNDEF(VecVT), Vec, Idx);
+
+    MVT XLenVT = Subtarget.getXLenVT();
+    // Extract the lower XLEN bits of the correct vector element.
+    SDValue EltLo = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, Slidedown, Idx);
+
+    // To extract the upper XLEN bits of the vector element, shift the first
+    // element right by 32 bits and re-extract the lower XLEN bits.
+    SDValue ThirtyTwoV =
+        DAG.getNode(RISCVISD::SPLAT_VECTOR_I64, DL, VecVT,
+                    DAG.getConstant(32, DL, Subtarget.getXLenVT()));
+    SDValue LShr32 = DAG.getNode(ISD::SRL, DL, VecVT, Slidedown, ThirtyTwoV);
+
+    SDValue EltHi = DAG.getNode(RISCVISD::VMV_X_S, DL, XLenVT, LShr32, Idx);
+
+    Results.push_back(DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64, EltLo, EltHi));
+    break;
+  }
   case ISD::INTRINSIC_WO_CHAIN: {
     unsigned IntNo = cast<ConstantSDNode>(N->getOperand(0))->getZExtValue();
     switch (IntNo) {
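A scalar model of the case above, with hypothetical names (not LLVM API): with XLEN=32 and SEW=64, vmv.x.s transfers only the low 32 bits of element 0, so the i64 result is rebuilt from two transfers around a 32-bit right shift, then paired back together as BUILD_PAIR does.

    #include <cassert>
    #include <cstdint>

    // Elt models element 0 of the vector after the slidedown.
    static uint64_t extractElt64OnRV32(uint64_t Elt) {
      uint32_t Lo = static_cast<uint32_t>(Elt);       // first vmv.x.s
      uint32_t Hi = static_cast<uint32_t>(Elt >> 32); // shift by 32, vmv.x.s
      return (static_cast<uint64_t>(Hi) << 32) | Lo;  // ISD::BUILD_PAIR
    }

    int main() {
      assert(extractElt64OnRV32(0x123456789abcdef0ULL) == 0x123456789abcdef0ULL);
    }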
@@ -2231,8 +2354,12 @@ unsigned RISCVTargetLowering::ComputeNumSignBitsForTargetNode(
     return 33;
   case RISCVISD::VMV_X_S:
     // The number of sign bits of the scalar result is computed by obtaining
-    // the element type of the input vector operand, substracting its width
-    // from the XLEN, and then adding one (sign bit within the element type).
+    // the element type of the input vector operand, subtracting its width
+    // from the XLEN, and then adding one (sign bit within the element type).
+    // If the element type is wider than XLen, the least-significant XLEN bits
+    // are taken.
+    if (Op.getOperand(0).getScalarValueSizeInBits() > Subtarget.getXLen())
+      return 1;
     return Subtarget.getXLen() - Op.getOperand(0).getScalarValueSizeInBits() + 1;
   }
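A scalar check of the formula above, with hypothetical names (not the LLVM hook itself): sign-extension pins every bit above the element width to the sign bit, while an over-wide element contributes no known bits beyond the usual single sign bit.

    #include <cassert>

    static unsigned numSignBitsVmvXS(unsigned XLen, unsigned EltBits) {
      if (EltBits > XLen)
        return 1; // SEW > XLEN: only the low XLEN bits are transferred
      return XLen - EltBits + 1; // sign-extension pins the upper bits
    }

    int main() {
      assert(numSignBitsVmvXS(64, 8) == 57); // i8 element on RV64
      assert(numSignBitsVmvXS(32, 64) == 1); // i64 element on RV32
    }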
@@ -3893,6 +4020,9 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
   NODE_NAME_CASE(VLEFF)
   NODE_NAME_CASE(VLEFF_MASK)
   NODE_NAME_CASE(READ_VL)
+  NODE_NAME_CASE(VSLIDEUP)
+  NODE_NAME_CASE(VSLIDEDOWN)
+  NODE_NAME_CASE(VID)
   }
   // clang-format on
   return nullptr;