Skip to content

Commit 26ff681

Browse files
committed
[X86][FP16] Adding lowerings for FP16 ISD::LRINT and ISD::LLRINT
Address comment in llvm#126477
1 parent 3e3af86 commit 26ff681

File tree

4 files changed

+2839
-7
lines changed

4 files changed

+2839
-7
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -735,6 +735,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
735735
setOperationAction(ISD::FCANONICALIZE, MVT::f16, Custom);
736736
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Custom);
737737
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f64, Custom);
738+
setOperationAction(ISD::LRINT, MVT::f16, Expand);
739+
setOperationAction(ISD::LLRINT, MVT::f16, Expand);
738740

739741
setLibcallName(RTLIB::FPROUND_F32_F16, "__truncsfhf2");
740742
setLibcallName(RTLIB::FPEXT_F16_F32, "__extendhfsf2");
@@ -2312,6 +2314,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
23122314
setOperationAction(ISD::FMINIMUMNUM, MVT::f16, Custom);
23132315
setOperationAction(ISD::FP_EXTEND, MVT::f32, Legal);
23142316
setOperationAction(ISD::STRICT_FP_EXTEND, MVT::f32, Legal);
2317+
setOperationAction(ISD::LRINT, MVT::f16, Legal);
2318+
setOperationAction(ISD::LLRINT, MVT::f16, Legal);
23152319

23162320
setCondCodeAction(ISD::SETOEQ, MVT::f16, Expand);
23172321
setCondCodeAction(ISD::SETUNE, MVT::f16, Expand);
@@ -2359,6 +2363,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
23592363
setOperationAction(ISD::FMAXIMUM, MVT::v32f16, Custom);
23602364
setOperationAction(ISD::FMINIMUMNUM, MVT::v32f16, Custom);
23612365
setOperationAction(ISD::FMAXIMUMNUM, MVT::v32f16, Custom);
2366+
setOperationAction(ISD::LRINT, MVT::v32f16, Legal);
2367+
setOperationAction(ISD::LLRINT, MVT::v8f16, Legal);
23622368
}
23632369

23642370
if (Subtarget.hasVLX()) {
@@ -2413,6 +2419,8 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
24132419
setOperationAction(ISD::FMAXIMUM, MVT::v16f16, Custom);
24142420
setOperationAction(ISD::FMINIMUMNUM, MVT::v16f16, Custom);
24152421
setOperationAction(ISD::FMAXIMUMNUM, MVT::v16f16, Custom);
2422+
setOperationAction(ISD::LRINT, MVT::v8f16, Legal);
2423+
setOperationAction(ISD::LRINT, MVT::v16f16, Legal);
24162424
}
24172425
}
24182426

@@ -34055,8 +34063,15 @@ void X86TargetLowering::ReplaceNodeResults(SDNode *N,
3405534063
case ISD::LRINT:
3405634064
if (N->getValueType(0) == MVT::v2i32) {
3405734065
SDValue Src = N->getOperand(0);
34058-
if (Src.getValueType() == MVT::v2f64)
34059-
Results.push_back(DAG.getNode(X86ISD::CVTP2SI, dl, MVT::v4i32, Src));
34066+
if (Subtarget.hasFP16() && Src.getValueType() == MVT::v2f16) {
34067+
Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v4f16, Src,
34068+
DAG.getUNDEF(MVT::v2f16));
34069+
Src = DAG.getNode(ISD::CONCAT_VECTORS, dl, MVT::v8f16, Src,
34070+
DAG.getUNDEF(MVT::v4f16));
34071+
} else if (Src.getValueType() != MVT::v2f64) {
34072+
return;
34073+
}
34074+
Results.push_back(DAG.getNode(X86ISD::CVTP2SI, dl, MVT::v4i32, Src));
3406034075
return;
3406134076
}
3406234077
[[fallthrough]];
@@ -53640,13 +53655,35 @@ static SDValue combineLRINT_LLRINT(SDNode *N, SelectionDAG &DAG,
5364053655
EVT SrcVT = Src.getValueType();
5364153656
SDLoc DL(N);
5364253657

53643-
if (!Subtarget.hasDQI() || !Subtarget.hasVLX() || VT != MVT::v2i64 ||
53644-
SrcVT != MVT::v2f32)
53658+
const TargetLowering &TLI = DAG.getTargetLoweringInfo();
53659+
53660+
// Let legalize expand this if it isn't a legal type yet.
53661+
if (!TLI.isTypeLegal(VT))
53662+
return SDValue();
53663+
53664+
if ((SrcVT.getScalarType() == MVT::f16 && !Subtarget.hasFP16()) ||
53665+
(SrcVT.getScalarType() == MVT::f32 && !Subtarget.hasDQI()))
5364553666
return SDValue();
5364653667

53647-
return DAG.getNode(X86ISD::CVTP2SI, DL, VT,
53648-
DAG.getNode(ISD::CONCAT_VECTORS, DL, MVT::v4f32, Src,
53649-
DAG.getUNDEF(SrcVT)));
53668+
if (SrcVT == MVT::v2f16) {
53669+
SrcVT = MVT::v4f16;
53670+
Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src,
53671+
DAG.getUNDEF(MVT::v2f16));
53672+
}
53673+
53674+
if (SrcVT == MVT::v4f16) {
53675+
SrcVT = MVT::v8f16;
53676+
Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src,
53677+
DAG.getUNDEF(MVT::v4f16));
53678+
} else if (SrcVT == MVT::v2f32) {
53679+
SrcVT = MVT::v4f32;
53680+
Src = DAG.getNode(ISD::CONCAT_VECTORS, DL, SrcVT, Src,
53681+
DAG.getUNDEF(MVT::v2f32));
53682+
} else {
53683+
return SDValue();
53684+
}
53685+
53686+
return DAG.getNode(X86ISD::CVTP2SI, DL, VT, Src);
5365053687
}
5365153688

5365253689
/// Attempt to pre-truncate inputs to arithmetic ops if it will simplify

llvm/lib/Target/X86/X86InstrAVX512.td

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13143,6 +13143,26 @@ defm VCVTTPH2UQQ : avx512_cvttph2qq<0x78, "vcvttph2uqq", X86any_cvttp2ui,
1314313143
SchedWriteCvtPS2DQ>, T_MAP5, PD,
1314413144
EVEX_CD8<16, CD8VQ>;
1314513145

13146+
// Selection patterns for packed FP16 `lrint` producing 128/256-bit results.
// Guarded by both HasFP16 and HasVLX.
//   v8f16  -> v8i16  : VCVTPH2WZ128
//   v16f16 -> v16i16 : VCVTPH2WZ256
//   v8f16  -> v8i32  : VCVTPH2DQZ256 (128-bit source, 256-bit result)
// Every register (rr) pattern is paired with a load-folding (rm) pattern.
let Predicates = [HasFP16, HasVLX] in {
13147+
def : Pat<(v8i16 (lrint (v8f16 VR128X:$src))), (VCVTPH2WZ128rr VR128X:$src)>;
13148+
def : Pat<(v8i16 (lrint (loadv8f16 addr:$src))), (VCVTPH2WZ128rm addr:$src)>;
13149+
def : Pat<(v16i16 (lrint (v16f16 VR256X:$src))), (VCVTPH2WZ256rr VR256X:$src)>;
13150+
def : Pat<(v16i16 (lrint (loadv16f16 addr:$src))), (VCVTPH2WZ256rm addr:$src)>;
13151+
def : Pat<(v8i32 (lrint (v8f16 VR128X:$src))), (VCVTPH2DQZ256rr VR128X:$src)>;
13152+
def : Pat<(v8i32 (lrint (loadv8f16 addr:$src))), (VCVTPH2DQZ256rm addr:$src)>;
13153+
}
13154+
13155+
// Selection patterns for packed FP16 `lrint`/`llrint` producing 512-bit
// results. Guarded by HasFP16 only (no VLX requirement).
//   v32f16 -> v32i16 : VCVTPH2WZ
//   v16f16 -> v16i32 : VCVTPH2DQZ
//   v8f16  -> v8i64  : VCVTPH2QQZ (used for both lrint and llrint)
// Every register (rr) pattern is paired with a load-folding (rm) pattern.
let Predicates = [HasFP16] in {
13156+
def : Pat<(v32i16 (lrint (v32f16 VR512:$src))), (VCVTPH2WZrr VR512:$src)>;
13157+
def : Pat<(v32i16 (lrint (loadv32f16 addr:$src))), (VCVTPH2WZrm addr:$src)>;
13158+
def : Pat<(v16i32 (lrint (v16f16 VR256X:$src))), (VCVTPH2DQZrr VR256X:$src)>;
13159+
def : Pat<(v16i32 (lrint (loadv16f16 addr:$src))), (VCVTPH2DQZrm addr:$src)>;
13160+
def : Pat<(v8i64 (lrint (v8f16 VR128X:$src))), (VCVTPH2QQZrr VR128X:$src)>;
13161+
def : Pat<(v8i64 (lrint (loadv8f16 addr:$src))), (VCVTPH2QQZrm addr:$src)>;
13162+
def : Pat<(v8i64 (llrint (v8f16 VR128X:$src))), (VCVTPH2QQZrr VR128X:$src)>;
13163+
def : Pat<(v8i64 (llrint (loadv8f16 addr:$src))), (VCVTPH2QQZrm addr:$src)>;
13164+
}
13165+
1314613166
// Convert Signed/Unsigned Quadword to Half
1314713167
multiclass avx512_cvtqq2ph<bits<8> opc, string OpcodeStr, SDPatternOperator OpNode,
1314813168
SDPatternOperator MaskOpNode, SDNode OpNodeRnd,
@@ -13269,6 +13289,19 @@ defm VCVTTSH2USI64Z: avx512_cvt_s_all<0x78, "vcvttsh2usi", f16x_info, i64x_info,
1326913289
any_fp_to_uint, X86cvtts2UInt, X86cvtts2UIntSAE, WriteCvtSS2I,
1327013290
"{q}", HasFP16>, T_MAP5, XS, REX_W, EVEX_CD8<16, CD8VT1>;
1327113291

13292+
// Selection patterns for scalar FP16 `lrint` with i16 / i32 results.
// The i16 result is obtained by running the 32-bit conversion and taking the
// low 16 bits via EXTRACT_SUBREG sub_16bit; there is no load-folding pattern
// for the i16 form.
// NOTE(review): these select VCVTTSH2SI. The VCVTT* scalar definitions just
// above are built on any_fp_to_uint-style nodes, which suggests a truncating
// convert, whereas lrint is expected to honour the current rounding mode.
// Confirm whether the non-truncating VCVTSH2SI form should be used instead.
let Predicates = [HasFP16] in {
13293+
def : Pat<(i16 (lrint FR16:$src)), (EXTRACT_SUBREG (VCVTTSH2SIZrr FR16:$src), sub_16bit)>;
13294+
def : Pat<(i32 (lrint FR16:$src)), (VCVTTSH2SIZrr FR16:$src)>;
13295+
def : Pat<(i32 (lrint (loadf16 addr:$src))), (VCVTTSH2SIZrm addr:$src)>;
13296+
}
13297+
13298+
let Predicates = [HasFP16, In64BitMode] in {
13299+
def : Pat<(i64 (lrint FR16:$src)), (VCVTTSH2SI64Zrr FR16:$src)>;
13300+
def : Pat<(i64 (lrint (loadf16 addr:$src))), (VCVTTSH2SI64Zrm addr:$src)>;
13301+
def : Pat<(i64 (llrint FR16:$src)), (VCVTTSH2SI64Zrr FR16:$src)>;
13302+
def : Pat<(i64 (llrint (loadf16 addr:$src))), (VCVTTSH2SI64Zrm addr:$src)>;
13303+
}
13304+
1327213305
let Predicates = [HasFP16] in {
1327313306
defm VCVTSI2SHZ : avx512_vcvtsi_common<0x2A, X86SintToFp, X86SintToFpRnd, WriteCvtI2SS, GR32,
1327413307
v8f16x_info, i32mem, loadi32, "cvtsi2sh", "l">,

0 commit comments

Comments
 (0)