Skip to content

Commit 91feb13

Browse files
authored
ISel/AArch64: custom lower vector ISD::[L]LRINT (#89035)
Since 98c90a1 (ISel: introduce vector ISD::LRINT, ISD::LLRINT; custom RISCV lowering), ISD::LRINT and ISD::LLRINT now have vector variants, that are custom lowered on RISCV, and scalarized on all other targets. Since 2302e4c (Reland "VectorUtils: mark xrint as trivially vectorizable"), lrint and llrint are trivially vectorizable, so all the vectorizers in-tree will produce vector variants when possible. Add a custom lowering for AArch64 to custom-lower the vector variants natively using a combination of frintx, fcvte, and fcvtzs.
1 parent d2d08ea commit 91feb13

File tree

8 files changed

+7740
-759
lines changed

8 files changed

+7740
-759
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1304,6 +1304,15 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
13041304
setOperationAction(Op, Ty, Legal);
13051305
}
13061306

1307+
// LRINT and LLRINT.
1308+
for (auto Op : {ISD::LRINT, ISD::LLRINT}) {
1309+
for (MVT Ty : {MVT::v2f32, MVT::v4f32, MVT::v2f64})
1310+
setOperationAction(Op, Ty, Custom);
1311+
if (Subtarget->hasFullFP16())
1312+
for (MVT Ty : {MVT::v4f16, MVT::v8f16})
1313+
setOperationAction(Op, Ty, Custom);
1314+
}
1315+
13071316
setTruncStoreAction(MVT::v4i16, MVT::v4i8, Custom);
13081317

13091318
setOperationAction(ISD::BITCAST, MVT::i2, Custom);
@@ -1525,6 +1534,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
15251534
setOperationAction(ISD::FFLOOR, VT, Custom);
15261535
setOperationAction(ISD::FNEARBYINT, VT, Custom);
15271536
setOperationAction(ISD::FRINT, VT, Custom);
1537+
setOperationAction(ISD::LRINT, VT, Custom);
1538+
setOperationAction(ISD::LLRINT, VT, Custom);
15281539
setOperationAction(ISD::FROUND, VT, Custom);
15291540
setOperationAction(ISD::FROUNDEVEN, VT, Custom);
15301541
setOperationAction(ISD::FTRUNC, VT, Custom);
@@ -1666,7 +1677,6 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
16661677
setOperationAction(ISD::MULHU, VT, Custom);
16671678
}
16681679

1669-
16701680
// Use SVE for vectors with more than 2 elements.
16711681
for (auto VT : {MVT::v4f16, MVT::v8f16, MVT::v4f32})
16721682
setOperationAction(ISD::VECREDUCE_FADD, VT, Custom);
@@ -1940,6 +1950,8 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT) {
19401950
setOperationAction(ISD::FP_TO_SINT, VT, Default);
19411951
setOperationAction(ISD::FP_TO_UINT, VT, Default);
19421952
setOperationAction(ISD::FRINT, VT, Default);
1953+
setOperationAction(ISD::LRINT, VT, Default);
1954+
setOperationAction(ISD::LLRINT, VT, Default);
19431955
setOperationAction(ISD::FROUND, VT, Default);
19441956
setOperationAction(ISD::FROUNDEVEN, VT, Default);
19451957
setOperationAction(ISD::FSQRT, VT, Default);
@@ -4363,6 +4375,26 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
43634375
return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Sat);
43644376
}
43654377

4378+
SDValue AArch64TargetLowering::LowerVectorXRINT(SDValue Op,
4379+
SelectionDAG &DAG) const {
4380+
EVT VT = Op.getValueType();
4381+
SDValue Src = Op.getOperand(0);
4382+
SDLoc DL(Op);
4383+
4384+
assert(VT.isVector() && "Expected vector type");
4385+
4386+
EVT CastVT =
4387+
VT.changeVectorElementType(Src.getValueType().getVectorElementType());
4388+
4389+
// Round the floating-point value into a floating-point register with the
4390+
// current rounding mode.
4391+
SDValue FOp = DAG.getNode(ISD::FRINT, DL, CastVT, Src);
4392+
4393+
// Truncate the rounded floating point to an integer.
4394+
return DAG.getNode(ISD::FP_TO_SINT_SAT, DL, VT, FOp,
4395+
DAG.getValueType(VT.getVectorElementType()));
4396+
}
4397+
43664398
SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
43674399
SelectionDAG &DAG) const {
43684400
// Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
@@ -6686,10 +6718,13 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
66866718
return LowerVECTOR_DEINTERLEAVE(Op, DAG);
66876719
case ISD::VECTOR_INTERLEAVE:
66886720
return LowerVECTOR_INTERLEAVE(Op, DAG);
6689-
case ISD::LROUND:
6690-
case ISD::LLROUND:
66916721
case ISD::LRINT:
6692-
case ISD::LLRINT: {
6722+
case ISD::LLRINT:
6723+
if (Op.getValueType().isVector())
6724+
return LowerVectorXRINT(Op, DAG);
6725+
[[fallthrough]];
6726+
case ISD::LROUND:
6727+
case ISD::LLROUND: {
66936728
assert((Op.getOperand(0).getValueType() == MVT::f16 ||
66946729
Op.getOperand(0).getValueType() == MVT::bf16) &&
66956730
"Expected custom lowering of rounding operations only for f16");

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1165,6 +1165,7 @@ class AArch64TargetLowering : public TargetLowering {
11651165
SDValue LowerVectorFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const;
11661166
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
11671167
SDValue LowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const;
1168+
SDValue LowerVectorXRINT(SDValue Op, SelectionDAG &DAG) const;
11681169
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
11691170
SDValue LowerVectorINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
11701171
SDValue LowerVectorOR(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)