@@ -1304,6 +1304,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1304
1304
setOperationAction(Op, Ty, Legal);
1305
1305
}
1306
1306
1307
+ // LRINT and LLRINT.
1308
+ for (auto Op : {ISD::LRINT, ISD::LLRINT}) {
1309
+ for (MVT Ty : {MVT::v2f32, MVT::v4f32, MVT::v2f64})
1310
+ setOperationAction(Op, Ty, Custom);
1311
+ if (Subtarget->hasFullFP16())
1312
+ for (MVT Ty : {MVT::v4f16, MVT::v8f16})
1313
+ setOperationAction(Op, Ty, Custom);
1314
+ }
1315
+
1307
1316
setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom);
1308
1317
1309
1318
setOperationAction(ISD::BITCAST, MVT::i2, Custom);
@@ -1525,6 +1534,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1525
1534
setOperationAction(ISD::FFLOOR, VT, Custom);
1526
1535
setOperationAction(ISD::FNEARBYINT, VT, Custom);
1527
1536
setOperationAction(ISD::FRINT, VT, Custom);
1537
+ setOperationAction(ISD::LRINT, VT, Custom);
1538
+ setOperationAction(ISD::LLRINT, VT, Custom);
1528
1539
setOperationAction(ISD::FROUND, VT, Custom);
1529
1540
setOperationAction(ISD::FROUNDEVEN, VT, Custom);
1530
1541
setOperationAction(ISD::FTRUNC, VT, Custom);
@@ -1666,7 +1677,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1666
1677
setOperationAction(ISD::MULHU, VT, Custom);
1667
1678
}
1668
1679
1669
-
1670
1680
// Use SVE for vectors with more than 2 elements.
1671
1681
for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32})
1672
1682
setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
@@ -1940,6 +1950,8 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
1940
1950
setOperationAction(ISD::FP_TO_SINT, VT, Default);
1941
1951
setOperationAction(ISD::FP_TO_UINT, VT, Default);
1942
1952
setOperationAction(ISD::FRINT, VT, Default);
1953
+ setOperationAction(ISD::LRINT, VT, Default);
1954
+ setOperationAction(ISD::LLRINT, VT, Default);
1943
1955
setOperationAction(ISD::FROUND, VT, Default);
1944
1956
setOperationAction(ISD::FROUNDEVEN, VT, Default);
1945
1957
setOperationAction(ISD::FSQRT, VT, Default);
@@ -4363,6 +4375,26 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
4363
4375
return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Sat);
4364
4376
}
4365
4377
4378
+ SDValue AArch64TargetLowering::LowerVectorXRINT(SDValue Op,
4379
+ SelectionDAG &DAG) const {
4380
+ EVT VT = Op.getValueType();
4381
+ SDValue Src = Op.getOperand(0);
4382
+ SDLoc DL(Op);
4383
+
4384
+ assert(VT.isVector() && "Expected vector type");
4385
+
4386
+ EVT CastVT =
4387
+ VT.changeVectorElementType(Src.getValueType().getVectorElementType());
4388
+
4389
+ // Round the floating-point value into a floating-point register with the
4390
+ // current rounding mode.
4391
+ SDValue FOp = DAG.getNode(ISD::FRINT, DL, CastVT, Src);
4392
+
4393
+ // Truncate the rounded floating point to an integer.
4394
+ return DAG.getNode(ISD::FP_TO_SINT_SAT, DL, VT, FOp,
4395
+ DAG.getValueType(VT.getVectorElementType()));
4396
+ }
4397
+
4366
4398
SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
4367
4399
SelectionDAG &DAG) const {
4368
4400
// Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
@@ -6686,10 +6718,13 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
6686
6718
return LowerVECTOR_DEINTERLEAVE(Op, DAG);
6687
6719
case ISD::VECTOR_INTERLEAVE:
6688
6720
return LowerVECTOR_INTERLEAVE(Op, DAG);
6689
- case ISD::LROUND:
6690
- case ISD::LLROUND:
6691
6721
case ISD::LRINT:
6692
- case ISD::LLRINT: {
6722
+ case ISD::LLRINT:
6723
+ if (Op.getValueType().isVector())
6724
+ return LowerVectorXRINT(Op, DAG);
6725
+ [[fallthrough]];
6726
+ case ISD::LROUND:
6727
+ case ISD::LLROUND: {
6693
6728
assert((Op.getOperand(0).getValueType() == MVT::f16 ||
6694
6729
Op.getOperand(0).getValueType() == MVT::bf16) &&
6695
6730
"Expected custom lowering of rounding operations only for f16");
0 commit comments