ISel/AArch64: custom lower vector ISD::[L]LRINT #89035


Merged
merged 6 commits into llvm:main from isel-xrint-aarch64 on May 10, 2024

Conversation

artagnon
Contributor

@artagnon artagnon commented Apr 17, 2024

Since 98c90a1 (ISel: introduce vector ISD::LRINT, ISD::LLRINT; custom RISCV lowering), ISD::LRINT and ISD::LLRINT have vector variants that are custom-lowered on RISCV and scalarized on all other targets. Since 2302e4c (Reland "VectorUtils: mark xrint as trivially vectorizable"), lrint and llrint are trivially vectorizable, so all the in-tree vectorizers will produce vector variants when possible. Add a custom lowering for AArch64 that lowers the vector variants natively using a combination of frintx, fcvte, and fcvtzs.
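As a concrete illustration of the IR this patch targets (the exact types here are chosen only for illustration; the added tests cover many more element types and widths), the vectorizers can now emit calls such as:

; A vector lrint: round each lane using the current rounding mode, then
; convert the rounded values to 64-bit integers.
declare <4 x i64> @llvm.lrint.v4i64.v4f32(<4 x float>)

define <4 x i64> @lrint_v4i64_v4f32(<4 x float> %x) {
  %a = call <4 x i64> @llvm.lrint.v4i64.v4f32(<4 x float> %x)
  ret <4 x i64> %a
}

Before this patch, such a call was scalarized into per-lane frintx/fcvtzs pairs; with the custom lowering it maps onto the vector forms of those instructions when SVE is available.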

@llvmbot
Member

llvmbot commented Apr 17, 2024

@llvm/pr-subscribers-backend-aarch64

Author: Ramkumar Ramachandra (artagnon)


Patch is 140.93 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89035.diff

6 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.cpp (+68-9)
  • (modified) llvm/lib/Target/AArch64/AArch64ISelLowering.h (+1)
  • (added) llvm/test/CodeGen/AArch64/fixed-vector-llrint.ll (+733)
  • (added) llvm/test/CodeGen/AArch64/fixed-vector-lrint.ll (+747)
  • (modified) llvm/test/CodeGen/AArch64/vector-llrint.ll (+416-545)
  • (modified) llvm/test/CodeGen/AArch64/vector-lrint.ll (+419-555)
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 80181a77c9d238..29d8ac65a7566c 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -790,7 +790,7 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     setOperationAction(ISD::FROUND,      V8Narrow, Legal);
     setOperationAction(ISD::FROUNDEVEN,  V8Narrow, Legal);
     setOperationAction(ISD::FRINT,       V8Narrow, Legal);
-    setOperationAction(ISD::FSQRT,       V8Narrow, Expand);
+    setOperationAction(ISD::FSQRT, V8Narrow, Expand);
     setOperationAction(ISD::FSUB,        V8Narrow, Legal);
     setOperationAction(ISD::FTRUNC,      V8Narrow, Legal);
     setOperationAction(ISD::SETCC,       V8Narrow, Expand);
@@ -1147,8 +1147,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
 
     for (auto Op :
          {ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::SINT_TO_FP, ISD::UINT_TO_FP,
-          ISD::FP_ROUND, ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT, ISD::MUL,
-          ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT,
+          ISD::FP_ROUND, ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT, ISD::LRINT,
+          ISD::LLRINT, ISD::MUL, ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT,
           ISD::STRICT_SINT_TO_FP, ISD::STRICT_UINT_TO_FP, ISD::STRICT_FP_ROUND})
       setOperationAction(Op, MVT::v1i64, Expand);
 
@@ -1355,6 +1355,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::SINT_TO_FP, VT, Custom);
       setOperationAction(ISD::FP_TO_UINT, VT, Custom);
       setOperationAction(ISD::FP_TO_SINT, VT, Custom);
+      setOperationAction(ISD::LRINT, VT, Custom);
+      setOperationAction(ISD::LLRINT, VT, Custom);
       setOperationAction(ISD::MGATHER, VT, Custom);
       setOperationAction(ISD::MSCATTER, VT, Custom);
       setOperationAction(ISD::MLOAD, VT, Custom);
@@ -1420,6 +1422,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
     for (auto VT : {MVT::nxv8i8, MVT::nxv4i16, MVT::nxv2i32}) {
       setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
       setOperationAction(ISD::INSERT_SUBVECTOR, VT, Custom);
+      setOperationAction(ISD::LRINT, VT, Custom);
+      setOperationAction(ISD::LLRINT, VT, Custom);
     }
 
     // Legalize unpacked bitcasts to REINTERPRET_CAST.
@@ -1522,6 +1526,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
       setOperationAction(ISD::FFLOOR, VT, Custom);
       setOperationAction(ISD::FNEARBYINT, VT, Custom);
       setOperationAction(ISD::FRINT, VT, Custom);
+      setOperationAction(ISD::LRINT, VT, Custom);
+      setOperationAction(ISD::LLRINT, VT, Custom);
       setOperationAction(ISD::FROUND, VT, Custom);
       setOperationAction(ISD::FROUNDEVEN, VT, Custom);
       setOperationAction(ISD::FTRUNC, VT, Custom);
@@ -1785,9 +1791,9 @@ void AArch64TargetLowering::addTypeForNEON(MVT VT) {
   setOperationAction(ISD::SREM, VT, Expand);
   setOperationAction(ISD::FREM, VT, Expand);
 
-  for (unsigned Opcode :
-       {ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
-        ISD::FP_TO_UINT_SAT, ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT})
+  for (unsigned Opcode : {ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::FP_TO_SINT_SAT,
+                          ISD::FP_TO_UINT_SAT, ISD::LRINT, ISD::LLRINT,
+                          ISD::STRICT_FP_TO_SINT, ISD::STRICT_FP_TO_UINT})
     setOperationAction(Opcode, VT, Custom);
 
   if (!VT.isFloatingPoint())
@@ -1947,6 +1953,8 @@ void AArch64TargetLowering::addTypeForFixedLengthSVE(MVT VT,
   setOperationAction(ISD::FP_TO_SINT, VT, Custom);
   setOperationAction(ISD::FP_TO_UINT, VT, Custom);
   setOperationAction(ISD::FRINT, VT, Custom);
+  setOperationAction(ISD::LRINT, VT, Custom);
+  setOperationAction(ISD::LLRINT, VT, Custom);
   setOperationAction(ISD::FROUND, VT, Custom);
   setOperationAction(ISD::FROUNDEVEN, VT, Custom);
   setOperationAction(ISD::FSQRT, VT, Custom);
@@ -4371,6 +4379,54 @@ SDValue AArch64TargetLowering::LowerFP_TO_INT_SAT(SDValue Op,
   return DAG.getNode(ISD::TRUNCATE, DL, DstVT, Sat);
 }
 
+SDValue AArch64TargetLowering::LowerVectorXRINT(SDValue Op,
+                                                SelectionDAG &DAG) const {
+  EVT VT = Op.getValueType();
+  SDValue Src = Op.getOperand(0);
+  SDLoc DL(Op);
+
+  assert(VT.isVector() && "Expected vector type");
+
+  EVT ContainerVT = VT;
+  EVT SrcVT = Src.getValueType();
+  EVT CastVT =
+      ContainerVT.changeVectorElementType(SrcVT.getVectorElementType());
+
+  if (VT.isFixedLengthVector()) {
+    ContainerVT = getContainerForFixedLengthVector(DAG, VT);
+    CastVT = ContainerVT.changeVectorElementType(SrcVT.getVectorElementType());
+    Src = convertToScalableVector(DAG, CastVT, Src);
+  }
+
+  // First, round the floating-point value into a floating-point register with
+  // the current rounding mode.
+  SDValue FOp = DAG.getNode(ISD::FRINT, DL, CastVT, Src);
+
+  // In the case of vector filled with f32, ftrunc will convert it to an i32,
+  // but a vector filled with i32 isn't legal. So, FP_EXTEND the f32 into the
+  // required size.
+  size_t SrcSz = SrcVT.getScalarSizeInBits();
+  size_t ContainerSz = ContainerVT.getScalarSizeInBits();
+  if (ContainerSz > SrcSz) {
+    EVT WidenedVT = MVT::getVectorVT(MVT::getFloatingPointVT(ContainerSz),
+                                     ContainerVT.getVectorElementCount());
+    FOp = DAG.getNode(ISD::FP_EXTEND, DL, WidenedVT, FOp.getOperand(0));
+  }
+
+  // Finally, truncate the rounded floating point to an integer, rounding to
+  // zero.
+  SDValue Pred = getPredicateForVector(DAG, DL, ContainerVT);
+  SDValue Undef = DAG.getUNDEF(ContainerVT);
+  SDValue Truncated =
+      DAG.getNode(AArch64ISD::FCVTZS_MERGE_PASSTHRU, DL, ContainerVT,
+                  {Pred, FOp.getOperand(0), Undef}, FOp->getFlags());
+
+  if (!VT.isFixedLengthVector())
+    return Truncated;
+
+  return convertFromScalableVector(DAG, VT, Truncated);
+}
+
 SDValue AArch64TargetLowering::LowerVectorINT_TO_FP(SDValue Op,
                                                     SelectionDAG &DAG) const {
   // Warning: We maintain cost tables in AArch64TargetTransformInfo.cpp.
@@ -6628,10 +6684,13 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
     return LowerVECTOR_DEINTERLEAVE(Op, DAG);
   case ISD::VECTOR_INTERLEAVE:
     return LowerVECTOR_INTERLEAVE(Op, DAG);
-  case ISD::LROUND:
-  case ISD::LLROUND:
   case ISD::LRINT:
-  case ISD::LLRINT: {
+  case ISD::LLRINT:
+    if (Op.getValueType().isVector())
+      return LowerVectorXRINT(Op, DAG);
+    [[fallthrough]];
+  case ISD::LROUND:
+  case ISD::LLROUND: {
     assert((Op.getOperand(0).getValueType() == MVT::f16 ||
             Op.getOperand(0).getValueType() == MVT::bf16) &&
            "Expected custom lowering of rounding operations only for f16");
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
index 18439dc7f01020..65277a09320705 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -1155,6 +1155,7 @@ class AArch64TargetLowering : public TargetLowering {
   SDValue LowerVectorFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVectorXRINT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVectorINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVectorOR(SDValue Op, SelectionDAG &DAG) const;
diff --git a/llvm/test/CodeGen/AArch64/fixed-vector-llrint.ll b/llvm/test/CodeGen/AArch64/fixed-vector-llrint.ll
new file mode 100644
index 00000000000000..772d767380a848
--- /dev/null
+++ b/llvm/test/CodeGen/AArch64/fixed-vector-llrint.ll
@@ -0,0 +1,733 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=aarch64 -mattr=+sve | FileCheck %s
+
+define <1 x i64> @llrint_v1i64_v1f16(<1 x half> %x) {
+; CHECK-LABEL: llrint_v1i64_v1f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    frintx h0, h0
+; CHECK-NEXT:    fcvtzs x8, h0
+; CHECK-NEXT:    fmov d0, x8
+; CHECK-NEXT:    ret
+  %a = call <1 x i64> @llvm.llrint.v1i64.v1f16(<1 x half> %x)
+  ret <1 x i64> %a
+}
+declare <1 x i64> @llvm.llrint.v1i64.v1f16(<1 x half>)
+
+define <2 x i64> @llrint_v1i64_v2f16(<2 x half> %x) {
+; CHECK-LABEL: llrint_v1i64_v2f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    mov h1, v0.h[1]
+; CHECK-NEXT:    frintx h0, h0
+; CHECK-NEXT:    frintx h1, h1
+; CHECK-NEXT:    fcvtzs x8, h0
+; CHECK-NEXT:    fcvtzs x9, h1
+; CHECK-NEXT:    fmov d0, x8
+; CHECK-NEXT:    mov v0.d[1], x9
+; CHECK-NEXT:    ret
+  %a = call <2 x i64> @llvm.llrint.v2i64.v2f16(<2 x half> %x)
+  ret <2 x i64> %a
+}
+declare <2 x i64> @llvm.llrint.v2i64.v2f16(<2 x half>)
+
+define <4 x i64> @llrint_v4i64_v4f16(<4 x half> %x) {
+; CHECK-LABEL: llrint_v4i64_v4f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    mov h1, v0.h[2]
+; CHECK-NEXT:    mov h2, v0.h[1]
+; CHECK-NEXT:    mov h3, v0.h[3]
+; CHECK-NEXT:    frintx h0, h0
+; CHECK-NEXT:    frintx h1, h1
+; CHECK-NEXT:    frintx h2, h2
+; CHECK-NEXT:    frintx h3, h3
+; CHECK-NEXT:    fcvtzs x8, h0
+; CHECK-NEXT:    fcvtzs x9, h1
+; CHECK-NEXT:    fcvtzs x10, h2
+; CHECK-NEXT:    fcvtzs x11, h3
+; CHECK-NEXT:    fmov d0, x8
+; CHECK-NEXT:    fmov d1, x9
+; CHECK-NEXT:    mov v0.d[1], x10
+; CHECK-NEXT:    mov v1.d[1], x11
+; CHECK-NEXT:    ret
+  %a = call <4 x i64> @llvm.llrint.v4i64.v4f16(<4 x half> %x)
+  ret <4 x i64> %a
+}
+declare <4 x i64> @llvm.llrint.v4i64.v4f16(<4 x half>)
+
+define <8 x i64> @llrint_v8i64_v8f16(<8 x half> %x) {
+; CHECK-LABEL: llrint_v8i64_v8f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
+; CHECK-NEXT:    mov h4, v0.h[2]
+; CHECK-NEXT:    mov h3, v0.h[1]
+; CHECK-NEXT:    mov h7, v0.h[3]
+; CHECK-NEXT:    frintx h0, h0
+; CHECK-NEXT:    mov h2, v1.h[2]
+; CHECK-NEXT:    mov h5, v1.h[1]
+; CHECK-NEXT:    mov h6, v1.h[3]
+; CHECK-NEXT:    frintx h1, h1
+; CHECK-NEXT:    frintx h4, h4
+; CHECK-NEXT:    frintx h3, h3
+; CHECK-NEXT:    frintx h7, h7
+; CHECK-NEXT:    fcvtzs x9, h0
+; CHECK-NEXT:    frintx h2, h2
+; CHECK-NEXT:    frintx h5, h5
+; CHECK-NEXT:    frintx h6, h6
+; CHECK-NEXT:    fcvtzs x8, h1
+; CHECK-NEXT:    fcvtzs x12, h4
+; CHECK-NEXT:    fcvtzs x11, h3
+; CHECK-NEXT:    fcvtzs x15, h7
+; CHECK-NEXT:    fmov d0, x9
+; CHECK-NEXT:    fcvtzs x10, h2
+; CHECK-NEXT:    fcvtzs x13, h5
+; CHECK-NEXT:    fcvtzs x14, h6
+; CHECK-NEXT:    fmov d2, x8
+; CHECK-NEXT:    fmov d1, x12
+; CHECK-NEXT:    mov v0.d[1], x11
+; CHECK-NEXT:    fmov d3, x10
+; CHECK-NEXT:    mov v2.d[1], x13
+; CHECK-NEXT:    mov v1.d[1], x15
+; CHECK-NEXT:    mov v3.d[1], x14
+; CHECK-NEXT:    ret
+  %a = call <8 x i64> @llvm.llrint.v8i64.v8f16(<8 x half> %x)
+  ret <8 x i64> %a
+}
+declare <8 x i64> @llvm.llrint.v8i64.v8f16(<8 x half>)
+
+define <16 x i64> @llrint_v16i64_v16f16(<16 x half> %x) {
+; CHECK-LABEL: llrint_v16i64_v16f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ext v2.16b, v0.16b, v0.16b, #8
+; CHECK-NEXT:    ext v3.16b, v1.16b, v1.16b, #8
+; CHECK-NEXT:    mov h4, v0.h[1]
+; CHECK-NEXT:    frintx h5, h0
+; CHECK-NEXT:    mov h18, v0.h[2]
+; CHECK-NEXT:    mov h0, v0.h[3]
+; CHECK-NEXT:    frintx h6, h2
+; CHECK-NEXT:    mov h7, v2.h[1]
+; CHECK-NEXT:    mov h16, v2.h[2]
+; CHECK-NEXT:    mov h17, v3.h[2]
+; CHECK-NEXT:    frintx h19, h3
+; CHECK-NEXT:    frintx h4, h4
+; CHECK-NEXT:    fcvtzs x8, h5
+; CHECK-NEXT:    mov h5, v1.h[1]
+; CHECK-NEXT:    mov h2, v2.h[3]
+; CHECK-NEXT:    frintx h18, h18
+; CHECK-NEXT:    frintx h0, h0
+; CHECK-NEXT:    fcvtzs x9, h6
+; CHECK-NEXT:    frintx h6, h7
+; CHECK-NEXT:    frintx h7, h16
+; CHECK-NEXT:    mov h16, v1.h[2]
+; CHECK-NEXT:    frintx h17, h17
+; CHECK-NEXT:    fcvtzs x10, h19
+; CHECK-NEXT:    mov h19, v3.h[1]
+; CHECK-NEXT:    fcvtzs x11, h4
+; CHECK-NEXT:    mov h4, v1.h[3]
+; CHECK-NEXT:    mov h3, v3.h[3]
+; CHECK-NEXT:    frintx h1, h1
+; CHECK-NEXT:    frintx h5, h5
+; CHECK-NEXT:    fcvtzs x13, h7
+; CHECK-NEXT:    fcvtzs x12, h6
+; CHECK-NEXT:    fcvtzs x15, h18
+; CHECK-NEXT:    frintx h7, h16
+; CHECK-NEXT:    fcvtzs x14, h17
+; CHECK-NEXT:    frintx h16, h2
+; CHECK-NEXT:    frintx h17, h19
+; CHECK-NEXT:    frintx h4, h4
+; CHECK-NEXT:    fmov d2, x9
+; CHECK-NEXT:    frintx h19, h3
+; CHECK-NEXT:    fcvtzs x9, h1
+; CHECK-NEXT:    fmov d6, x10
+; CHECK-NEXT:    fmov d3, x13
+; CHECK-NEXT:    fcvtzs x13, h0
+; CHECK-NEXT:    fcvtzs x16, h5
+; CHECK-NEXT:    fcvtzs x10, h7
+; CHECK-NEXT:    fmov d7, x14
+; CHECK-NEXT:    fcvtzs x14, h16
+; CHECK-NEXT:    fcvtzs x17, h17
+; CHECK-NEXT:    fcvtzs x0, h4
+; CHECK-NEXT:    fmov d0, x8
+; CHECK-NEXT:    fcvtzs x18, h19
+; CHECK-NEXT:    fmov d1, x15
+; CHECK-NEXT:    fmov d4, x9
+; CHECK-NEXT:    mov v2.d[1], x12
+; CHECK-NEXT:    fmov d5, x10
+; CHECK-NEXT:    mov v0.d[1], x11
+; CHECK-NEXT:    mov v3.d[1], x14
+; CHECK-NEXT:    mov v1.d[1], x13
+; CHECK-NEXT:    mov v4.d[1], x16
+; CHECK-NEXT:    mov v6.d[1], x17
+; CHECK-NEXT:    mov v7.d[1], x18
+; CHECK-NEXT:    mov v5.d[1], x0
+; CHECK-NEXT:    ret
+  %a = call <16 x i64> @llvm.llrint.v16i64.v16f16(<16 x half> %x)
+  ret <16 x i64> %a
+}
+declare <16 x i64> @llvm.llrint.v16i64.v16f16(<16 x half>)
+
+define <32 x i64> @llrint_v32i64_v32f16(<32 x half> %x) {
+; CHECK-LABEL: llrint_v32i64_v32f16:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ext v4.16b, v1.16b, v1.16b, #8
+; CHECK-NEXT:    ext v5.16b, v2.16b, v2.16b, #8
+; CHECK-NEXT:    ext v6.16b, v3.16b, v3.16b, #8
+; CHECK-NEXT:    ext v7.16b, v0.16b, v0.16b, #8
+; CHECK-NEXT:    frintx h21, h1
+; CHECK-NEXT:    frintx h22, h2
+; CHECK-NEXT:    mov h26, v2.h[2]
+; CHECK-NEXT:    frintx h19, h0
+; CHECK-NEXT:    mov h27, v3.h[2]
+; CHECK-NEXT:    mov h20, v2.h[1]
+; CHECK-NEXT:    mov h18, v1.h[1]
+; CHECK-NEXT:    mov h16, v4.h[2]
+; CHECK-NEXT:    mov h17, v5.h[2]
+; CHECK-NEXT:    frintx h23, h5
+; CHECK-NEXT:    frintx h24, h6
+; CHECK-NEXT:    mov h25, v6.h[2]
+; CHECK-NEXT:    fcvtzs x9, h21
+; CHECK-NEXT:    fcvtzs x11, h22
+; CHECK-NEXT:    frintx h22, h7
+; CHECK-NEXT:    mov h21, v3.h[3]
+; CHECK-NEXT:    fcvtzs x10, h19
+; CHECK-NEXT:    frintx h27, h27
+; CHECK-NEXT:    frintx h20, h20
+; CHECK-NEXT:    frintx h16, h16
+; CHECK-NEXT:    frintx h17, h17
+; CHECK-NEXT:    fcvtzs x12, h23
+; CHECK-NEXT:    fcvtzs x13, h24
+; CHECK-NEXT:    frintx h23, h25
+; CHECK-NEXT:    frintx h25, h26
+; CHECK-NEXT:    mov h26, v3.h[1]
+; CHECK-NEXT:    mov h24, v2.h[3]
+; CHECK-NEXT:    fmov d19, x9
+; CHECK-NEXT:    fcvtzs x9, h22
+; CHECK-NEXT:    frintx h22, h3
+; CHECK-NEXT:    frintx h21, h21
+; CHECK-NEXT:    fcvtzs x14, h16
+; CHECK-NEXT:    fcvtzs x15, h17
+; CHECK-NEXT:    fmov d2, x12
+; CHECK-NEXT:    fmov d16, x13
+; CHECK-NEXT:    fcvtzs x12, h23
+; CHECK-NEXT:    fcvtzs x13, h25
+; CHECK-NEXT:    mov h23, v1.h[2]
+; CHECK-NEXT:    frintx h25, h26
+; CHECK-NEXT:    frintx h24, h24
+; CHECK-NEXT:    mov h1, v1.h[3]
+; CHECK-NEXT:    fmov d26, x11
+; CHECK-NEXT:    fcvtzs x11, h21
+; CHECK-NEXT:    fmov d3, x14
+; CHECK-NEXT:    fmov d17, x15
+; CHECK-NEXT:    fcvtzs x14, h22
+; CHECK-NEXT:    fcvtzs x15, h27
+; CHECK-NEXT:    mov h22, v0.h[2]
+; CHECK-NEXT:    frintx h18, h18
+; CHECK-NEXT:    frintx h21, h23
+; CHECK-NEXT:    fmov d23, x13
+; CHECK-NEXT:    fcvtzs x13, h25
+; CHECK-NEXT:    frintx h1, h1
+; CHECK-NEXT:    fmov d25, x14
+; CHECK-NEXT:    fcvtzs x14, h24
+; CHECK-NEXT:    fmov d24, x15
+; CHECK-NEXT:    frintx h22, h22
+; CHECK-NEXT:    fcvtzs x15, h18
+; CHECK-NEXT:    mov h18, v7.h[1]
+; CHECK-NEXT:    mov v25.d[1], x13
+; CHECK-NEXT:    fcvtzs x13, h21
+; CHECK-NEXT:    mov h21, v7.h[2]
+; CHECK-NEXT:    mov v24.d[1], x11
+; CHECK-NEXT:    fcvtzs x11, h20
+; CHECK-NEXT:    mov h20, v0.h[1]
+; CHECK-NEXT:    mov h0, v0.h[3]
+; CHECK-NEXT:    mov v23.d[1], x14
+; CHECK-NEXT:    fcvtzs x14, h1
+; CHECK-NEXT:    mov h1, v6.h[3]
+; CHECK-NEXT:    mov h6, v6.h[1]
+; CHECK-NEXT:    mov v19.d[1], x15
+; CHECK-NEXT:    mov h7, v7.h[3]
+; CHECK-NEXT:    stp q25, q24, [x8, #192]
+; CHECK-NEXT:    fmov d24, x13
+; CHECK-NEXT:    frintx h20, h20
+; CHECK-NEXT:    mov v26.d[1], x11
+; CHECK-NEXT:    fcvtzs x11, h22
+; CHECK-NEXT:    mov h22, v5.h[1]
+; CHECK-NEXT:    mov h5, v5.h[3]
+; CHECK-NEXT:    frintx h0, h0
+; CHECK-NEXT:    frintx h1, h1
+; CHECK-NEXT:    mov v24.d[1], x14
+; CHECK-NEXT:    mov h25, v4.h[3]
+; CHECK-NEXT:    frintx h6, h6
+; CHECK-NEXT:    stp q26, q23, [x8, #128]
+; CHECK-NEXT:    fmov d23, x12
+; CHECK-NEXT:    fcvtzs x12, h20
+; CHECK-NEXT:    mov h20, v4.h[1]
+; CHECK-NEXT:    frintx h5, h5
+; CHECK-NEXT:    fcvtzs x13, h0
+; CHECK-NEXT:    stp q19, q24, [x8, #64]
+; CHECK-NEXT:    frintx h22, h22
+; CHECK-NEXT:    fmov d0, x10
+; CHECK-NEXT:    fmov d19, x11
+; CHECK-NEXT:    frintx h4, h4
+; CHECK-NEXT:    fcvtzs x10, h1
+; CHECK-NEXT:    frintx h1, h21
+; CHECK-NEXT:    frintx h24, h25
+; CHECK-NEXT:    fcvtzs x11, h6
+; CHECK-NEXT:    frintx h20, h20
+; CHECK-NEXT:    frintx h6, h7
+; CHECK-NEXT:    fcvtzs x14, h5
+; CHECK-NEXT:    mov v19.d[1], x13
+; CHECK-NEXT:    frintx h5, h18
+; CHECK-NEXT:    fcvtzs x13, h22
+; CHECK-NEXT:    mov v0.d[1], x12
+; CHECK-NEXT:    fcvtzs x12, h4
+; CHECK-NEXT:    mov v23.d[1], x10
+; CHECK-NEXT:    fcvtzs x10, h1
+; CHECK-NEXT:    fcvtzs x15, h24
+; CHECK-NEXT:    mov v16.d[1], x11
+; CHECK-NEXT:    fcvtzs x11, h20
+; CHECK-NEXT:    mov v17.d[1], x14
+; CHECK-NEXT:    fcvtzs x14, h6
+; CHECK-NEXT:    mov v2.d[1], x13
+; CHECK-NEXT:    fcvtzs x13, h5
+; CHECK-NEXT:    fmov d4, x9
+; CHECK-NEXT:    stp q0, q19, [x8]
+; CHECK-NEXT:    fmov d0, x12
+; CHECK-NEXT:    stp q16, q23, [x8, #224]
+; CHECK-NEXT:    fmov d1, x10
+; CHECK-NEXT:    mov v3.d[1], x15
+; CHECK-NEXT:    stp q2, q17, [x8, #160]
+; CHECK-NEXT:    mov v0.d[1], x11
+; CHECK-NEXT:    mov v4.d[1], x13
+; CHECK-NEXT:    mov v1.d[1], x14
+; CHECK-NEXT:    stp q0, q3, [x8, #96]
+; CHECK-NEXT:    stp q4, q1, [x8, #32]
+; CHECK-NEXT:    ret
+  %a = call <32 x i64> @llvm.llrint.v32i64.v32f16(<32 x half> %x)
+  ret <32 x i64> %a
+}
+declare <32 x i64> @llvm.llrint.v32i64.v32f16(<32 x half>)
+
+define <1 x i64> @llrint_v1i64_v1f32(<1 x float> %x) {
+; CHECK-LABEL: llrint_v1i64_v1f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    // kill: def $d0 killed $d0 def $q0
+; CHECK-NEXT:    frintx s0, s0
+; CHECK-NEXT:    fcvtzs x8, s0
+; CHECK-NEXT:    fmov d0, x8
+; CHECK-NEXT:    ret
+  %a = call <1 x i64> @llvm.llrint.v1i64.v1f32(<1 x float> %x)
+  ret <1 x i64> %a
+}
+declare <1 x i64> @llvm.llrint.v1i64.v1f32(<1 x float>)
+
+define <2 x i64> @llrint_v2i64_v2f32(<2 x float> %x) {
+; CHECK-LABEL: llrint_v2i64_v2f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    addpl x8, sp, #4
+; CHECK-NEXT:    str d0, [x8]
+; CHECK-NEXT:    ld1w { z0.d }, p0/z, [sp, #1, mul vl]
+; CHECK-NEXT:    fcvtzs z0.d, p0/m, z0.s
+; CHECK-NEXT:    // kill: def $q0 killed $q0 killed $z0
+; CHECK-NEXT:    addvl sp, sp, #1
+; CHECK-NEXT:    ldr x29, [sp], #16 // 8-byte Folded Reload
+; CHECK-NEXT:    ret
+  %a = call <2 x i64> @llvm.llrint.v2i64.v2f32(<2 x float> %x)
+  ret <2 x i64> %a
+}
+declare <2 x i64> @llvm.llrint.v2i64.v2f32(<2 x float>)
+
+define <4 x i64> @llrint_v4i64_v4f32(<4 x float> %x) {
+; CHECK-LABEL: llrint_v4i64_v4f32:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    str x29, [sp, #-16]! // 8-byte Folded Spill
+; CHECK-NEXT:    addvl sp, sp, #-1
+; CHECK-NEXT:    .cfi_escape 0x0f, 0x0c, 0x8f, 0x00, 0x11, 0x10, 0x22, 0x11, 0x08, 0x92, 0x2e, 0x00, 0x1e, 0x22 // sp + 16 + 8 * VG
+; CHECK-NEXT:    .cfi_offset w29, -16
+; CHECK-NEXT:    ptrue p0.d
+; CHECK-NEXT:    ext v1.16b, v0.16b, v0.16b, #8
+; CHECK-NEXT:    addpl x8, sp, #4
+; CHECK-NEXT:    str d0, [sp]
+; CHECK-NEXT:    str d1, [x8]
+; CHECK-NEXT:    ld1w { z0.d }, p0/z, [sp]
+; CHECK-NEXT:    ld1w { z1.d }, ...
[truncated]

@davemgreen
Collaborator

Can you split the Neon and SVE parts into separate patches, just adding the hopefully simpler Neon versions first? SVE tests are usually in files with names like sve-lrint.ll.

@artagnon artagnon force-pushed the isel-xrint-aarch64 branch from 741b358 to f19a306 on May 7, 2024 17:45
@artagnon artagnon changed the title from "ISel/AArch64: custom lower vector ISD::LRINT, ISD::LLRINT" to "ISel/AArch64/SVE: custom lower vector ISD::[L]LRINT" on May 7, 2024
@artagnon
Contributor Author

artagnon commented May 7, 2024

I have investigated the possibility of lowering ISD::[L]LRINT when only Neon is present, and have concluded that it is impossible. I have updated the patch and renamed the tests appropriately. On Neon there is no custom lowering, and the custom lowering on SVE is complete.

@davemgreen
Collaborator

Hi - Can we expand LRINT/LLRINT to just do frintx + [fcvtl] + fcvtzs? It seems that we scalarize the vector operations at the moment, and the scalar operations work that way.

Could we just be lowering LRINT/LLRINT to ISD::FRINT+ISD::FP_TO_SINT and let the rest of legalization deal with the results? It should then hopefully be able to handle any vector types without much extra work.
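At the IR level, that suggestion corresponds roughly to the following sketch (function name and types are illustrative only):

; Rough IR equivalent of lowering LRINT as ISD::FRINT + ISD::FP_TO_SINT.
declare <4 x float> @llvm.rint.v4f32(<4 x float>)

define <4 x i64> @lrint_via_rint_fptosi(<4 x float> %x) {
  %rounded = call <4 x float> @llvm.rint.v4f32(<4 x float> %x) ; frintx
  %conv = fptosi <4 x float> %rounded to <4 x i64>             ; fcvtl + fcvtzs after legalization
  ret <4 x i64> %conv
}

Legalization would then widen or split the fptosi as needed for each vector type.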

@artagnon artagnon force-pushed the isel-xrint-aarch64 branch from f19a306 to 5af6d4e on May 8, 2024 17:29

github-actions bot commented May 8, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@artagnon
Contributor Author

artagnon commented May 8, 2024

Hi,

I looked at the code for FP_TO_SINT before deciding that it was impossible to custom-lower on Neon (what I mean is that, on Neon, the vector gets unrolled and the corresponding scalar operations are generated). I've now updated the patch to use FP_TO_SINT, but there are no test updates. Since the change only adds extra SelectionDAG overhead without any benefit, I'd be tempted to go with my original version; however, if you think the code for LowerVectorXRINT should be simpler, let me know.

Collaborator

@davemgreen davemgreen left a comment


Thanks, I think this sounds good. I'm not quite sure about the fixed-length SVE vectors, but if that works and we add the Custom Neon legalization, it should hopefully improve the code quite a bit.

artagnon added 4 commits May 9, 2024 13:43
Since 98c90a1 (ISel: introduce vector ISD::LRINT, ISD::LLRINT; custom
RISCV lowering), ISD::LRINT and ISD::LLRINT now have vector variants,
that are custom lowered on RISCV, and scalarized on all other targets.
Since 2302e4c (Reland "VectorUtils: mark xrint as trivially
vectorizable"), lrint and llrint are trivially vectorizable, so all the
vectorizers in-tree will produce vector variants when possible. Add a
custom lowering for AArch64 to custom-lower the vector variants natively
using a combination of frintx, fcvte, and fcvtzs, when SVE is present.
@artagnon artagnon force-pushed the isel-xrint-aarch64 branch from a73d606 to 9a92adf on May 9, 2024 14:28
@artagnon artagnon changed the title from "ISel/AArch64/SVE: custom lower vector ISD::[L]LRINT" to "ISel/AArch64: custom lower vector ISD::[L]LRINT" on May 9, 2024
@artagnon
Contributor Author

artagnon commented May 9, 2024

Thanks for all the reviews! I've now fixed all the issues, and the patch is hopefully ready.

Collaborator

@davemgreen davemgreen left a comment


Thanks, this does look good to me. We usually add Neon codegen before adding SVE and fixed-length SVE on top to help keep things simple, but as the code in this case is the same for both, I think this should be OK all at once.

I would recommend moving where/what gets marked as Custom.

The other thing to consider is that, as the C type long can be i32 on some platforms, there could be i32 variants as well as the i64 versions. Since the code isn't specific to a certain size, I believe that should be OK; see the sketch below.
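For reference, an i32-result variant would look like the following (illustrative only; the tests shown above use i64 results, matching a 64-bit long):

; lrint returning 32-bit lanes, as produced on a target where long is i32.
declare <4 x i32> @llvm.lrint.v4i32.v4f32(<4 x float>)

define <4 x i32> @lrint_v4i32_v4f32(<4 x float> %x) {
  %a = call <4 x i32> @llvm.lrint.v4i32.v4f32(<4 x float> %x)
  ret <4 x i32> %a
}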

@artagnon
Contributor Author

artagnon commented May 9, 2024

Thanks for the reviews and the explanations! I have now fixed all the issues and regenerated the tests. There is one final TODO item left, which is somewhat orthogonal to this patch: i32 results don't work in sve-lrint (scalable vectors with i32 cause problems on SVE for some as-yet-unexplained reason).

Collaborator

@davemgreen davemgreen left a comment


It is a shame about using FP_TO_SI_SAT and the scalarization that results, but we will have to fix that separately in any case. Some of the tests are missing check lines, but otherwise I think this looks OK.

Collaborator

@davemgreen davemgreen left a comment


Thanks. LGTM

@artagnon artagnon merged commit 91feb13 into llvm:main May 10, 2024
3 of 4 checks passed
@artagnon artagnon deleted the isel-xrint-aarch64 branch May 10, 2024 19:51