Skip to content

Commit 7a1b2ad

Browse files
committed
[RISCV] Implement straight-forward bf16<->int conversion cases
This ports over the test cases half-convert.ll and implements patterns or RISCVISelLowering.cpp changes for all of the most straight-forward cases (those that don't require changes outside of lib/Target/RISCV). The remaining cases and noted poor codegen for saturating conversions will be handled in follow-up patches. Differential Revision: https://reviews.llvm.org/D156943
1 parent c2d1900 commit 7a1b2ad

File tree

3 files changed

+832
-5
lines changed

3 files changed

+832
-5
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2444,9 +2444,10 @@ static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
24442444
bool IsSigned = Op.getOpcode() == ISD::FP_TO_SINT_SAT;
24452445

24462446
if (!DstVT.isVector()) {
2447-
// In absense of Zfh, promote f16 to f32, then saturate the result.
2448-
if (Src.getSimpleValueType() == MVT::f16 &&
2449-
!Subtarget.hasStdExtZfhOrZhinx()) {
2447+
// For bf16 or for f16 in absense of Zfh, promote to f32, then saturate
2448+
// the result.
2449+
if ((Src.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfhOrZhinx()) ||
2450+
Src.getValueType() == MVT::bf16) {
24502451
Src = DAG.getNode(ISD::FP_EXTEND, SDLoc(Op), MVT::f32, Src);
24512452
}
24522453

@@ -9813,8 +9814,11 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
98139814
Results.push_back(Res.getValue(1));
98149815
return;
98159816
}
9816-
// In absense of Zfh, promote f16 to f32, then convert.
9817-
if (Op0.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfhOrZhinx())
9817+
// For bf16, or f16 in absense of Zfh, promote [b]f16 to f32 and then
9818+
// convert.
9819+
if ((Op0.getValueType() == MVT::f16 &&
9820+
!Subtarget.hasStdExtZfhOrZhinx()) ||
9821+
Op0.getValueType() == MVT::bf16)
98189822
Op0 = DAG.getNode(ISD::FP_EXTEND, DL, MVT::f32, Op0);
98199823

98209824
unsigned Opc = IsSigned ? RISCVISD::FCVT_W_RV64 : RISCVISD::FCVT_WU_RV64;

llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,25 @@ def : Pat<(bf16 (riscv_fmv_h_x GPR:$src)), (FMV_H_X GPR:$src)>;
6161
def : Pat<(riscv_fmv_x_anyexth (bf16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
6262
def : Pat<(riscv_fmv_x_signexth (bf16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
6363
} // Predicates = [HasStdExtZfbfmin]
64+
65+
let Predicates = [HasStdExtZfbfmin, IsRV32] in {
66+
// bf16->[u]int. Round-to-zero must be used for the f32->int step, the
67+
// rounding mode has no effect for bf16->f32.
68+
def : Pat<(i32 (any_fp_to_sint (bf16 FPR16:$rs1))), (FCVT_W_S (FCVT_S_BF16 $rs1, FRM_RNE), FRM_RTZ)>;
69+
def : Pat<(i32 (any_fp_to_uint (bf16 FPR16:$rs1))), (FCVT_WU_S (FCVT_S_BF16 $rs1, FRM_RNE), FRM_RTZ)>;
70+
71+
// [u]int->bf16. Match GCC and default to using dynamic rounding mode.
72+
def : Pat<(bf16 (any_sint_to_fp (i32 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_W $rs1, FRM_DYN), FRM_DYN)>;
73+
def : Pat<(bf16 (any_uint_to_fp (i32 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_WU $rs1, FRM_DYN), FRM_DYN)>;
74+
}
75+
76+
let Predicates = [HasStdExtZfbfmin, IsRV64] in {
77+
// bf16->[u]int64. Round-to-zero must be used for the f32->int step, the
78+
// rounding mode has no effect for bf16->f32.
79+
def : Pat<(i64 (any_fp_to_sint (bf16 FPR16:$rs1))), (FCVT_L_S (FCVT_S_BF16 $rs1, FRM_RNE), FRM_RTZ)>;
80+
def : Pat<(i64 (any_fp_to_uint (bf16 FPR16:$rs1))), (FCVT_LU_S (FCVT_S_BF16 $rs1, FRM_RNE), FRM_RTZ)>;
81+
82+
// [u]int->bf16. Match GCC and default to using dynamic rounding mode.
83+
def : Pat<(bf16 (any_sint_to_fp (i64 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_L $rs1, FRM_DYN), FRM_DYN)>;
84+
def : Pat<(bf16 (any_uint_to_fp (i64 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_LU $rs1, FRM_DYN), FRM_DYN)>;
85+
}

0 commit comments

Comments
 (0)