@@ -735,6 +735,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
735
735
setOperationAction(ISD::FCANONICALIZE, MVT::f16, Custom);
736
736
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom);
737
737
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom);
738
+ setOperationAction(ISD::LRINT, MVT::f16, Expand);
739
+ setOperationAction(ISD::LLRINT, MVT::f16, Expand);
738
740
739
741
setLibcallName(RTLIB::FPROUND_F32_F16, "__truncsfhf2");
740
742
setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
@@ -2312,6 +2314,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
2312
2314
setOperationAction(ISD::FMINIMUMNUM, MVT::f16, Custom);
2313
2315
setOperationAction(ISD::FP_EXTEND, MVT::f32, Legal);
2314
2316
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Legal);
2317
+ setOperationAction(ISD::LRINT, MVT::f16, Legal);
2318
+ setOperationAction(ISD::LLRINT, MVT::f16, Legal);
2315
2319
2316
2320
setCondCodeAction(ISD::SETOEQ, MVT::f16, Expand);
2317
2321
setCondCodeAction(ISD::SETUNE, MVT::f16, Expand);
@@ -2359,6 +2363,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
2359
2363
setOperationAction(ISD::FMAXIMUM, MVT::v32f16, Custom);
2360
2364
setOperationAction(ISD::FMINIMUMNUM, MVT::v32f16, Custom);
2361
2365
setOperationAction(ISD::FMAXIMUMNUM, MVT::v32f16, Custom);
2366
+ setOperationAction(ISD::LRINT, MVT::v32f16, Legal);
2367
+ setOperationAction(ISD::LLRINT, MVT::v8f16, Legal);
2362
2368
}
2363
2369
2364
2370
if (Subtarget.hasVLX()) {
@@ -2413,6 +2419,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
2413
2419
setOperationAction(ISD::FMAXIMUM, MVT::v16f16, Custom);
2414
2420
setOperationAction(ISD::FMINIMUMNUM, MVT::v16f16, Custom);
2415
2421
setOperationAction(ISD::FMAXIMUMNUM, MVT::v16f16, Custom);
2422
+ setOperationAction(ISD::LRINT, MVT::v8f16, Legal);
2423
+ setOperationAction(ISD::LRINT, MVT::v16f16, Legal);
2416
2424
}
2417
2425
}
2418
2426
@@ -34055,8 +34063,15 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
34055
34063
case ISD::LRINT:
34056
34064
if (N->getValueType(0) == MVT::v2i32) {
34057
34065
SDValue Src = N->getOperand(0);
34058
- if (Src.getValueType() == MVT::v2f64)
34059
- Results.push_back(DAG.getNode(X86ISD::CVTP2SI, dl, MVT::v4i32, Src));
34066
+ if (Subtarget.hasFP16() && Src.getValueType() == MVT::v2f16) {
34067
+ Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f16, Src,
34068
+ DAG.getUNDEF(MVT::v2f16));
34069
+ Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8f16, Src,
34070
+ DAG.getUNDEF(MVT::v4f16));
34071
+ } else if (Src.getValueType() != MVT::v2f64) {
34072
+ return;
34073
+ }
34074
+ Results.push_back(DAG.getNode(X86ISD::CVTP2SI, dl, MVT::v4i32, Src));
34060
34075
return;
34061
34076
}
34062
34077
[[fallthrough]];
@@ -53640,13 +53655,35 @@ static SDValue combineLRINT_LLRINT(SDNode *N, SelectionDAG &DAG,
53640
53655
EVT SrcVT = Src.getValueType();
53641
53656
SDLoc DL(N);
53642
53657
53643
- if (!Subtarget.hasDQI() || !Subtarget.hasVLX() || VT != MVT::v2i64 ||
53644
- SrcVT != MVT::v2f32)
53658
+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
53659
+
53660
+ // Let legalize expand this if it isn't a legal type yet.
53661
+ if (!TLI.isTypeLegal(VT))
53662
+ return SDValue();
53663
+
53664
+ if ((SrcVT.getScalarType() == MVT::f16 && !Subtarget.hasFP16()) ||
53665
+ (SrcVT.getScalarType() == MVT::f32 && !Subtarget.hasDQI()))
53645
53666
return SDValue();
53646
53667
53647
- return DAG.getNode(X86ISD::CVTP2SI, DL, VT,
53648
- DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f32, Src,
53649
- DAG.getUNDEF(SrcVT)));
53668
+ if (SrcVT == MVT::v2f16) {
53669
+ SrcVT = MVT::v4f16;
53670
+ Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src,
53671
+ DAG.getUNDEF(MVT::v2f16));
53672
+ }
53673
+
53674
+ if (SrcVT == MVT::v4f16) {
53675
+ SrcVT = MVT::v8f16;
53676
+ Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src,
53677
+ DAG.getUNDEF(MVT::v4f16));
53678
+ } else if (SrcVT == MVT::v2f32) {
53679
+ SrcVT = MVT::v4f32;
53680
+ Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src,
53681
+ DAG.getUNDEF(MVT::v2f32));
53682
+ } else {
53683
+ return SDValue();
53684
+ }
53685
+
53686
+ return DAG.getNode(X86ISD::CVTP2SI, DL, VT, Src);
53650
53687
}
53651
53688
53652
53689
/// Attempt to pre-truncate inputs to arithmetic ops if it will simplify
0 commit comments