Skip to content

Commit f19a306

Browse files
committed
ISel/AArch64/SVE: custom lower vector ISD::[L]LRINT
Since 98c90a1 (ISel: introduce vector ISD::LRINT, ISD::LLRINT; custom RISCV lowering), ISD::LRINT and ISD::LLRINT now have vector variants, that are custom lowered on RISCV, and scalarized on all other targets. Since 2302e4c (Reland "VectorUtils: mark xrint as trivially vectorizable"), lrint and llrint are trivially vectorizable, so all the vectorizers in-tree will produce vector variants when possible. Add a custom lowering for AArch64 to custom-lower the vector variants natively using a combination of frintx, fcvte, and fcvtzs, when SVE is present.
1 parent de9b386 commit f19a306

File tree

8 files changed

+3512
-32
lines changed

8 files changed

+3512
-32
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 63 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1526,6 +1526,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15261526
setOperationAction(ISD::FNEARBYINT, VT, Custom);
15271527
setOperationAction(ISD::FRINT, VT, Custom);
15281528
setOperationAction(ISD::FROUND, VT, Custom);
1529+
setOperationAction(ISD::LRINT, VT, Custom);
1530+
setOperationAction(ISD::LLRINT, VT, Custom);
15291531
setOperationAction(ISD::FROUNDEVEN, VT, Custom);
15301532
setOperationAction(ISD::FTRUNC, VT, Custom);
15311533
setOperationAction(ISD::FSQRT, VT, Custom);
@@ -1940,6 +1942,8 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
19401942
setOperationAction(ISD::FP_TO_UINT, VT, Default);
19411943
setOperationAction(ISD::FRINT, VT, Default);
19421944
setOperationAction(ISD::FROUND, VT, Default);
1945+
setOperationAction(ISD::LRINT, VT, Default);
1946+
setOperationAction(ISD::LLRINT, VT, Default);
19431947
setOperationAction(ISD::FROUNDEVEN, VT, Default);
19441948
setOperationAction(ISD::FSQRT, VT, Default);
19451949
setOperationAction(ISD::FSUB, VT, Default);
@@ -4362,6 +4366,59 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
43624366
return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Sat);
43634367
}
43644368

4369+
SDValue AArch64TargetLowering::LowerVectorXRINT(SDValue Op,
4370+
SelectionDAG &DAG) const {
4371+
EVT VT = Op.getValueType();
4372+
SDValue Src = Op.getOperand(0);
4373+
SDLoc DL(Op);
4374+
4375+
assert(VT.isVector() && "Expected vector type");
4376+
4377+
// We can't custom-lower ISD::[L]LRINT without SVE, since it requires
4378+
// AArch64ISD::FCVTZS_MERGE_PASSTHRU.
4379+
if (!Subtarget->isSVEAvailable())
4380+
return SDValue();
4381+
4382+
EVT ContainerVT = VT;
4383+
EVT SrcVT = Src.getValueType();
4384+
EVT CastVT =
4385+
ContainerVT.changeVectorElementType(SrcVT.getVectorElementType());
4386+
4387+
if (VT.isFixedLengthVector()) {
4388+
ContainerVT = getContainerForFixedLengthVector(DAG, VT);
4389+
CastVT = ContainerVT.changeVectorElementType(SrcVT.getVectorElementType());
4390+
Src = convertToScalableVector(DAG, CastVT, Src);
4391+
}
4392+
4393+
// First, round the floating-point value into a floating-point register with
4394+
// the current rounding mode.
4395+
SDValue FOp = DAG.getNode(ISD::FRINT, DL, CastVT, Src);
4396+
4397+
// In the case of vector filled with f32, ftrunc will convert it to an i32,
4398+
// but a vector filled with i32 isn't legal. So, FP_EXTEND the f32 into the
4399+
// required size.
4400+
size_t SrcSz = SrcVT.getScalarSizeInBits();
4401+
size_t ContainerSz = ContainerVT.getScalarSizeInBits();
4402+
if (ContainerSz > SrcSz) {
4403+
EVT SizedVT = MVT::getVectorVT(MVT::getFloatingPointVT(ContainerSz),
4404+
ContainerVT.getVectorElementCount());
4405+
FOp = DAG.getNode(ISD::FP_EXTEND, DL, SizedVT, FOp.getOperand(0));
4406+
}
4407+
4408+
// Finally, truncate the rounded floating point to an integer, rounding to
4409+
// zero.
4410+
SDValue Pred = getPredicateForVector(DAG, DL, ContainerVT);
4411+
SDValue Undef = DAG.getUNDEF(ContainerVT);
4412+
SDValue Truncated =
4413+
DAG.getNode(AArch64ISD::FCVTZS_MERGE_PASSTHRU, DL, ContainerVT,
4414+
{Pred, FOp.getOperand(0), Undef}, FOp->getFlags());
4415+
4416+
if (VT.isScalableVector())
4417+
return Truncated;
4418+
4419+
return convertFromScalableVector(DAG, VT, Truncated);
4420+
}
4421+
43654422
SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
43664423
SelectionDAG &DAG) const {
43674424
// Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
@@ -6685,10 +6742,13 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
66856742
return LowerVECTOR_DEINTERLEAVE(Op, DAG);
66866743
case ISD::VECTOR_INTERLEAVE:
66876744
return LowerVECTOR_INTERLEAVE(Op, DAG);
6688-
case ISD::LROUND:
6689-
case ISD::LLROUND:
66906745
case ISD::LRINT:
6691-
case ISD::LLRINT: {
6746+
case ISD::LLRINT:
6747+
if (Op.getValueType().isVector())
6748+
return LowerVectorXRINT(Op, DAG);
6749+
[[fallthrough]];
6750+
case ISD::LROUND:
6751+
case ISD::LLROUND: {
66926752
assert((Op.getOperand(0).getValueType() == MVT::f16 ||
66936753
Op.getOperand(0).getValueType() == MVT::bf16) &&
66946754
"Expected custom lowering of rounding operations only for f16");

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,6 +1165,7 @@ class AArch64TargetLowering : public TargetLowering {
11651165
SDValue LowerVectorFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const;
11661166
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
11671167
SDValue LowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const;
1168+
SDValue LowerVectorXRINT(SDValue Op, SelectionDAG &DAG) const;
11681169
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
11691170
SDValue LowerVectorINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
11701171
SDValue LowerVectorOR(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)