Skip to content

Commit fd3e7e3

Browse files
authored
[X86] Adding lowerings for vector ISD::LRINT and ISD::LLRINT (#90065)
- [V]CVTP[D,S]2DQ supports `f64/f32` -> `i32` conversions that can be mapped to `llvm.lrint.vNi32.vNf64/32` since SSE2. AVX and AVX512 added 256-bit and 512-bit support; - VCVTP[D,S]2QQ supports `f64/f32` -> `i64` conversions that can be mapped to `llvm.l[l]rint.vNi64.vNf64/32` since AVX512DQ. All 128-bit, 256-bit (require AVX512VL) and 512-bit are supported.
1 parent 1949856 commit fd3e7e3

File tree

5 files changed

+911
-695
lines changed

5 files changed

+911
-695
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1092,6 +1092,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
10921092
setOperationAction(ISD::FABS, MVT::v2f64, Custom);
10931093
setOperationAction(ISD::FCOPYSIGN, MVT::v2f64, Custom);
10941094

1095+
setOperationAction(ISD::LRINT, MVT::v4f32, Custom);
1096+
10951097
for (auto VT : { MVT::v16i8, MVT::v8i16, MVT::v4i32, MVT::v2i64 }) {
10961098
setOperationAction(ISD::SMAX, VT, VT == MVT::v8i16 ? Legal : Custom);
10971099
setOperationAction(ISD::SMIN, VT, VT == MVT::v8i16 ? Legal : Custom);
@@ -1431,6 +1433,9 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
14311433
setOperationAction(ISD::FMINIMUM, VT, Custom);
14321434
}
14331435

1436+
setOperationAction(ISD::LRINT, MVT::v8f32, Custom);
1437+
setOperationAction(ISD::LRINT, MVT::v4f64, Custom);
1438+
14341439
// (fp_to_int:v8i16 (v8f32 ..)) requires the result type to be promoted
14351440
// even though v8i16 is a legal type.
14361441
setOperationPromotedToType(ISD::FP_TO_SINT, MVT::v8i16, MVT::v8i32);
@@ -1731,6 +1736,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
17311736
for (auto VT : { MVT::v1i1, MVT::v2i1, MVT::v4i1, MVT::v8i1 })
17321737
setOperationAction(ISD::EXTRACT_SUBVECTOR, VT, Custom);
17331738
}
1739+
if (Subtarget.hasDQI() && Subtarget.hasVLX()) {
1740+
for (MVT VT : {MVT::v4f32, MVT::v8f32, MVT::v2f64, MVT::v4f64}) {
1741+
setOperationAction(ISD::LRINT, VT, Legal);
1742+
setOperationAction(ISD::LLRINT, VT, Legal);
1743+
}
1744+
}
17341745

17351746
// This block controls legalization for 512-bit operations with 8/16/32/64 bit
17361747
// elements. 512-bits can be disabled based on prefer-vector-width and
@@ -1765,6 +1776,12 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
17651776
setOperationAction(ISD::STRICT_FMA, VT, Legal);
17661777
setOperationAction(ISD::FCOPYSIGN, VT, Custom);
17671778
}
1779+
setOperationAction(ISD::LRINT, MVT::v16f32,
1780+
Subtarget.hasDQI() ? Legal : Custom);
1781+
setOperationAction(ISD::LRINT, MVT::v8f64,
1782+
Subtarget.hasDQI() ? Legal : Custom);
1783+
if (Subtarget.hasDQI())
1784+
setOperationAction(ISD::LLRINT, MVT::v8f64, Legal);
17681785

17691786
for (MVT VT : { MVT::v16i1, MVT::v16i8 }) {
17701787
setOperationPromotedToType(ISD::FP_TO_SINT , VT, MVT::v16i32);
@@ -2488,6 +2505,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24882505
ISD::FMAXNUM,
24892506
ISD::SUB,
24902507
ISD::LOAD,
2508+
ISD::LRINT,
2509+
ISD::LLRINT,
24912510
ISD::MLOAD,
24922511
ISD::STORE,
24932512
ISD::MSTORE,
@@ -21161,8 +21180,12 @@ SDValue X86TargetLowering::LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const {
2116121180
SDValue X86TargetLowering::LowerLRINT_LLRINT(SDValue Op,
2116221181
SelectionDAG &DAG) const {
2116321182
SDValue Src = Op.getOperand(0);
21183+
EVT DstVT = Op.getSimpleValueType();
2116421184
MVT SrcVT = Src.getSimpleValueType();
2116521185

21186+
if (SrcVT.isVector())
21187+
return DstVT.getScalarType() == MVT::i32 ? Op : SDValue();
21188+
2116621189
if (SrcVT == MVT::f16)
2116721190
return SDValue();
2116821191

@@ -51542,6 +51565,22 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
5154251565
return SDValue();
5154351566
}
5154451567

51568+
static SDValue combineLRINT_LLRINT(SDNode *N, SelectionDAG &DAG,
51569+
const X86Subtarget &Subtarget) {
51570+
EVT VT = N->getValueType(0);
51571+
SDValue Src = N->getOperand(0);
51572+
EVT SrcVT = Src.getValueType();
51573+
SDLoc DL(N);
51574+
51575+
if (!Subtarget.hasDQI() || !Subtarget.hasVLX() || VT != MVT::v2i64 ||
51576+
SrcVT != MVT::v2f32)
51577+
return SDValue();
51578+
51579+
return DAG.getNode(X86ISD::CVTP2SI, DL, VT,
51580+
DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f32, Src,
51581+
DAG.getUNDEF(SrcVT)));
51582+
}
51583+
5154551584
/// Attempt to pre-truncate inputs to arithmetic ops if it will simplify
5154651585
/// the codegen.
5154751586
/// e.g. TRUNC( BINOP( X, Y ) ) --> BINOP( TRUNC( X ), TRUNC( Y ) )
@@ -51888,6 +51927,11 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
5188851927
return DAG.getNode(X86ISD::MMX_MOVD2W, DL, MVT::i32, BCSrc);
5188951928
}
5189051929

51930+
// Try to combine (trunc (vNi64 (lrint x))) to (vNi32 (lrint x)).
51931+
if (Src.getOpcode() == ISD::LRINT && VT.getScalarType() == MVT::i32 &&
51932+
Src.hasOneUse())
51933+
return DAG.getNode(ISD::LRINT, DL, VT, Src.getOperand(0));
51934+
5189151935
return SDValue();
5189251936
}
5189351937

@@ -56834,6 +56878,8 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
5683456878
case ISD::UINT_TO_FP:
5683556879
case ISD::STRICT_UINT_TO_FP:
5683656880
return combineUIntToFP(N, DAG, Subtarget);
56881+
case ISD::LRINT:
56882+
case ISD::LLRINT: return combineLRINT_LLRINT(N, DAG, Subtarget);
5683756883
case ISD::FADD:
5683856884
case ISD::FSUB: return combineFaddFsub(N, DAG, Subtarget);
5683956885
case X86ISD::VFCMULC:

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8811,7 +8811,18 @@ let Predicates = [HasVLX] in {
88118811
def : Pat<(X86mcvttp2ui (v2f64 (X86VBroadcastld64 addr:$src)),
88128812
v4i32x_info.ImmAllZerosV, VK2WM:$mask),
88138813
(VCVTTPD2UDQZ128rmbkz VK2WM:$mask, addr:$src)>;
8814+
8815+
def : Pat<(v4i32 (lrint VR128X:$src)), (VCVTPS2DQZ128rr VR128X:$src)>;
8816+
def : Pat<(v4i32 (lrint (loadv4f32 addr:$src))), (VCVTPS2DQZ128rm addr:$src)>;
8817+
def : Pat<(v8i32 (lrint VR256X:$src)), (VCVTPS2DQZ256rr VR256X:$src)>;
8818+
def : Pat<(v8i32 (lrint (loadv8f32 addr:$src))), (VCVTPS2DQZ256rm addr:$src)>;
8819+
def : Pat<(v4i32 (lrint VR256X:$src)), (VCVTPD2DQZ256rr VR256X:$src)>;
8820+
def : Pat<(v4i32 (lrint (loadv4f64 addr:$src))), (VCVTPD2DQZ256rm addr:$src)>;
88148821
}
8822+
def : Pat<(v16i32 (lrint VR512:$src)), (VCVTPS2DQZrr VR512:$src)>;
8823+
def : Pat<(v16i32 (lrint (loadv16f32 addr:$src))), (VCVTPS2DQZrm addr:$src)>;
8824+
def : Pat<(v8i32 (lrint VR512:$src)), (VCVTPD2DQZrr VR512:$src)>;
8825+
def : Pat<(v8i32 (lrint (loadv8f64 addr:$src))), (VCVTPD2DQZrm addr:$src)>;
88158826

88168827
let Predicates = [HasDQI, HasVLX] in {
88178828
def : Pat<(v2i64 (X86cvtp2Int (bc_v4f32 (v2f64 (X86vzload64 addr:$src))))),
@@ -8857,6 +8868,30 @@ let Predicates = [HasDQI, HasVLX] in {
88578868
(X86cvttp2ui (bc_v4f32 (v2f64 (X86vzload64 addr:$src)))),
88588869
v2i64x_info.ImmAllZerosV)),
88598870
(VCVTTPS2UQQZ128rmkz VK2WM:$mask, addr:$src)>;
8871+
8872+
def : Pat<(v4i64 (lrint VR128X:$src)), (VCVTPS2QQZ256rr VR128X:$src)>;
8873+
def : Pat<(v4i64 (lrint (loadv4f32 addr:$src))), (VCVTPS2QQZ256rm addr:$src)>;
8874+
def : Pat<(v4i64 (llrint VR128X:$src)), (VCVTPS2QQZ256rr VR128X:$src)>;
8875+
def : Pat<(v4i64 (llrint (loadv4f32 addr:$src))), (VCVTPS2QQZ256rm addr:$src)>;
8876+
def : Pat<(v2i64 (lrint VR128X:$src)), (VCVTPD2QQZ128rr VR128X:$src)>;
8877+
def : Pat<(v2i64 (lrint (loadv2f64 addr:$src))), (VCVTPD2QQZ128rm addr:$src)>;
8878+
def : Pat<(v4i64 (lrint VR256X:$src)), (VCVTPD2QQZ256rr VR256X:$src)>;
8879+
def : Pat<(v4i64 (lrint (loadv4f64 addr:$src))), (VCVTPD2QQZ256rm addr:$src)>;
8880+
def : Pat<(v2i64 (llrint VR128X:$src)), (VCVTPD2QQZ128rr VR128X:$src)>;
8881+
def : Pat<(v2i64 (llrint (loadv2f64 addr:$src))), (VCVTPD2QQZ128rm addr:$src)>;
8882+
def : Pat<(v4i64 (llrint VR256X:$src)), (VCVTPD2QQZ256rr VR256X:$src)>;
8883+
def : Pat<(v4i64 (llrint (loadv4f64 addr:$src))), (VCVTPD2QQZ256rm addr:$src)>;
8884+
}
8885+
8886+
let Predicates = [HasDQI] in {
8887+
def : Pat<(v8i64 (lrint VR256X:$src)), (VCVTPS2QQZrr VR256X:$src)>;
8888+
def : Pat<(v8i64 (lrint (loadv8f32 addr:$src))), (VCVTPS2QQZrm addr:$src)>;
8889+
def : Pat<(v8i64 (llrint VR256X:$src)), (VCVTPS2QQZrr VR256X:$src)>;
8890+
def : Pat<(v8i64 (llrint (loadv8f32 addr:$src))), (VCVTPS2QQZrm addr:$src)>;
8891+
def : Pat<(v8i64 (lrint VR512:$src)), (VCVTPD2QQZrr VR512:$src)>;
8892+
def : Pat<(v8i64 (lrint (loadv8f64 addr:$src))), (VCVTPD2QQZrm addr:$src)>;
8893+
def : Pat<(v8i64 (llrint VR512:$src)), (VCVTPD2QQZrr VR512:$src)>;
8894+
def : Pat<(v8i64 (llrint (loadv8f64 addr:$src))), (VCVTPD2QQZrm addr:$src)>;
88608895
}
88618896

88628897
let Predicates = [HasVLX] in {

llvm/lib/Target/X86/X86InstrSSE.td

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1554,7 +1554,6 @@ def CVTPS2DQrm : PDI<0x5B, MRMSrcMem, (outs VR128:$dst), (ins f128mem:$src),
15541554
(v4i32 (X86cvtp2Int (memopv4f32 addr:$src))))]>,
15551555
Sched<[WriteCvtPS2ILd]>, SIMD_EXC;
15561556

1557-
15581557
// Convert Packed Double FP to Packed DW Integers
15591558
let Predicates = [HasAVX, NoVLX], Uses = [MXCSR], mayRaiseFPException = 1 in {
15601559
// The assembler can recognize rr 256-bit instructions by seeing a ymm
@@ -1586,6 +1585,20 @@ def VCVTPD2DQYrm : SDI<0xE6, MRMSrcMem, (outs VR128:$dst), (ins f256mem:$src),
15861585
VEX, VEX_L, Sched<[WriteCvtPD2IYLd]>, WIG;
15871586
}
15881587

1588+
let Predicates = [HasAVX] in {
1589+
def : Pat<(v4i32 (lrint VR128:$src)), (VCVTPS2DQrr VR128:$src)>;
1590+
def : Pat<(v4i32 (lrint (loadv4f32 addr:$src))), (VCVTPS2DQrm addr:$src)>;
1591+
def : Pat<(v8i32 (lrint VR256:$src)), (VCVTPS2DQYrr VR256:$src)>;
1592+
def : Pat<(v8i32 (lrint (loadv8f32 addr:$src))), (VCVTPS2DQYrm addr:$src)>;
1593+
def : Pat<(v4i32 (lrint VR256:$src)), (VCVTPD2DQYrr VR256:$src)>;
1594+
def : Pat<(v4i32 (lrint (loadv4f64 addr:$src))), (VCVTPD2DQYrm addr:$src)>;
1595+
}
1596+
1597+
let Predicates = [UseSSE2] in {
1598+
def : Pat<(v4i32 (lrint VR128:$src)), (CVTPS2DQrr VR128:$src)>;
1599+
def : Pat<(v4i32 (lrint (loadv4f32 addr:$src))), (CVTPS2DQrm addr:$src)>;
1600+
}
1601+
15891602
def : InstAlias<"vcvtpd2dqx\t{$src, $dst|$dst, $src}",
15901603
(VCVTPD2DQrr VR128:$dst, VR128:$src), 0, "att">;
15911604
def : InstAlias<"vcvtpd2dqy\t{$src, $dst|$dst, $src}",

0 commit comments

Comments
 (0)