@@ -1526,6 +1526,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
1526
1526
setOperationAction(ISD::FNEARBYINT, VT, Custom);
1527
1527
setOperationAction(ISD::FRINT, VT, Custom);
1528
1528
setOperationAction(ISD::FROUND, VT, Custom);
1529
+ setOperationAction(ISD::LRINT, VT, Custom);
1530
+ setOperationAction(ISD::LLRINT, VT, Custom);
1529
1531
setOperationAction(ISD::FROUNDEVEN, VT, Custom);
1530
1532
setOperationAction(ISD::FTRUNC, VT, Custom);
1531
1533
setOperationAction(ISD::FSQRT, VT, Custom);
@@ -1940,6 +1942,8 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
1940
1942
setOperationAction(ISD::FP_TO_UINT, VT, Default);
1941
1943
setOperationAction(ISD::FRINT, VT, Default);
1942
1944
setOperationAction(ISD::FROUND, VT, Default);
1945
+ setOperationAction(ISD::LRINT, VT, Default);
1946
+ setOperationAction(ISD::LLRINT, VT, Default);
1943
1947
setOperationAction(ISD::FROUNDEVEN, VT, Default);
1944
1948
setOperationAction(ISD::FSQRT, VT, Default);
1945
1949
setOperationAction(ISD::FSUB, VT, Default);
@@ -4362,6 +4366,59 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
4362
4366
return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Sat);
4363
4367
}
4364
4368
4369
+ SDValue AArch64TargetLowering::LowerVectorXRINT(SDValue Op,
4370
+ SelectionDAG &DAG) const {
4371
+ EVT VT = Op.getValueType();
4372
+ SDValue Src = Op.getOperand(0);
4373
+ SDLoc DL(Op);
4374
+
4375
+ assert(VT.isVector() && "Expected vector type");
4376
+
4377
+ // We can't custom-lower ISD::[L]LRINT without SVE, since it requires
4378
+ // AArch64ISD::FCVTZS_MERGE_PASSTHRU.
4379
+ if (!Subtarget->isSVEAvailable())
4380
+ return SDValue();
4381
+
4382
+ EVT ContainerVT = VT;
4383
+ EVT SrcVT = Src.getValueType();
4384
+ EVT CastVT =
4385
+ ContainerVT.changeVectorElementType(SrcVT.getVectorElementType());
4386
+
4387
+ if (VT.isFixedLengthVector()) {
4388
+ ContainerVT = getContainerForFixedLengthVector(DAG, VT);
4389
+ CastVT = ContainerVT.changeVectorElementType(SrcVT.getVectorElementType());
4390
+ Src = convertToScalableVector(DAG, CastVT, Src);
4391
+ }
4392
+
4393
+ // First, round the floating-point value into a floating-point register with
4394
+ // the current rounding mode.
4395
+ SDValue FOp = DAG.getNode(ISD::FRINT, DL, CastVT, Src);
4396
+
4397
+ // In the case of vector filled with f32, ftrunc will convert it to an i32,
4398
+ // but a vector filled with i32 isn't legal. So, FP_EXTEND the f32 into the
4399
+ // required size.
4400
+ size_t SrcSz = SrcVT.getScalarSizeInBits();
4401
+ size_t ContainerSz = ContainerVT.getScalarSizeInBits();
4402
+ if (ContainerSz > SrcSz) {
4403
+ EVT SizedVT = MVT::getVectorVT(MVT::getFloatingPointVT(ContainerSz),
4404
+ ContainerVT.getVectorElementCount());
4405
+ FOp = DAG.getNode(ISD::FP_EXTEND, DL, SizedVT, FOp.getOperand(0));
4406
+ }
4407
+
4408
+ // Finally, truncate the rounded floating point to an integer, rounding to
4409
+ // zero.
4410
+ SDValue Pred = getPredicateForVector(DAG, DL, ContainerVT);
4411
+ SDValue Undef = DAG.getUNDEF(ContainerVT);
4412
+ SDValue Truncated =
4413
+ DAG.getNode(AArch64ISD::FCVTZS_MERGE_PASSTHRU, DL, ContainerVT,
4414
+ {Pred, FOp.getOperand(0), Undef}, FOp->getFlags());
4415
+
4416
+ if (VT.isScalableVector())
4417
+ return Truncated;
4418
+
4419
+ return convertFromScalableVector(DAG, VT, Truncated);
4420
+ }
4421
+
4365
4422
SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
4366
4423
SelectionDAG &DAG) const {
4367
4424
// Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
@@ -6685,10 +6742,13 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
6685
6742
return LowerVECTOR_DEINTERLEAVE(Op, DAG);
6686
6743
case ISD::VECTOR_INTERLEAVE:
6687
6744
return LowerVECTOR_INTERLEAVE(Op, DAG);
6688
- case ISD::LROUND:
6689
- case ISD::LLROUND:
6690
6745
case ISD::LRINT:
6691
- case ISD::LLRINT: {
6746
+ case ISD::LLRINT:
6747
+ if (Op.getValueType().isVector())
6748
+ return LowerVectorXRINT(Op, DAG);
6749
+ [[fallthrough]];
6750
+ case ISD::LROUND:
6751
+ case ISD::LLROUND: {
6692
6752
assert((Op.getOperand(0).getValueType() == MVT::f16 ||
6693
6753
Op.getOperand(0).getValueType() == MVT::bf16) &&
6694
6754
"Expected custom lowering of rounding operations only for f16");
0 commit comments