Skip to content

[RISCV] Custom promote f16/bf16 (s/u)int_to_fp. #107026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FABS, MVT::bf16, Custom);
setOperationAction(ISD::FNEG, MVT::bf16, Custom);
setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Custom);
setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, XLenVT, Custom);
}

if (Subtarget.hasStdExtZfhminOrZhinxmin()) {
Expand All @@ -478,6 +479,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FABS, MVT::f16, Custom);
setOperationAction(ISD::FNEG, MVT::f16, Custom);
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Custom);
setOperationAction({ISD::SINT_TO_FP, ISD::UINT_TO_FP}, XLenVT, Custom);
}

setOperationAction(ISD::STRICT_FP_ROUND, MVT::f16, Legal);
Expand Down Expand Up @@ -590,9 +592,10 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction({ISD::FP_TO_UINT_SAT, ISD::FP_TO_SINT_SAT}, XLenVT,
Custom);

setOperationAction({ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT,
ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP},
XLenVT, Legal);
setOperationAction({ISD::STRICT_FP_TO_UINT, ISD::STRICT_FP_TO_SINT}, XLenVT,
Legal);
setOperationAction({ISD::STRICT_UINT_TO_FP, ISD::STRICT_SINT_TO_FP}, XLenVT,
Custom);

setOperationAction(ISD::GET_ROUNDING, XLenVT, Custom);
setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
Expand Down Expand Up @@ -2953,6 +2956,33 @@ InstructionCost RISCVTargetLowering::getVSlideVICost(MVT VT) const {
return getLMULCost(VT);
}

static SDValue lowerINT_TO_FP(SDValue Op, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
// f16 conversions are promoted to f32 when Zfh/Zhinx are not supported.
// bf16 conversions are always promoted to f32.
if ((Op.getValueType() == MVT::f16 && !Subtarget.hasStdExtZfhOrZhinx()) ||
Op.getValueType() == MVT::bf16) {
bool IsStrict = Op->isStrictFPOpcode();

SDLoc DL(Op);
if (IsStrict) {
SDValue Val = DAG.getNode(Op.getOpcode(), DL, {MVT::f32, MVT::Other},
{Op.getOperand(0), Op.getOperand(1)});
return DAG.getNode(ISD::STRICT_FP_ROUND, DL,
{Op.getValueType(), MVT::Other},
{Val.getValue(1), Val.getValue(0),
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true)});
}
return DAG.getNode(
ISD::FP_ROUND, DL, Op.getValueType(),
DAG.getNode(Op.getOpcode(), DL, MVT::f32, Op.getOperand(0)),
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
}

// Other operations are legal.
return Op;
}

static SDValue lowerFP_TO_INT_SAT(SDValue Op, SelectionDAG &DAG,
const RISCVSubtarget &Subtarget) {
// RISC-V FP-to-int conversions saturate to the destination register size, but
Expand Down Expand Up @@ -6631,13 +6661,15 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
// the source. We custom-lower any conversions that do two hops into
// sequences.
MVT VT = Op.getSimpleValueType();
bool IsStrict = Op->isStrictFPOpcode();
SDValue Src = Op.getOperand(0 + IsStrict);
MVT SrcVT = Src.getSimpleValueType();
if (SrcVT.isScalarInteger())
return lowerINT_TO_FP(Op, DAG, Subtarget);
if (!VT.isVector())
return Op;
SDLoc DL(Op);
bool IsStrict = Op->isStrictFPOpcode();
SDValue Src = Op.getOperand(0 + IsStrict);
MVT EltVT = VT.getVectorElementType();
MVT SrcVT = Src.getSimpleValueType();
MVT SrcEltVT = SrcVT.getVectorElementType();
unsigned EltSize = EltVT.getSizeInBits();
unsigned SrcEltSize = SrcEltVT.getSizeInBits();
Expand Down
8 changes: 0 additions & 8 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,13 @@ let Predicates = [HasStdExtZfbfmin] in {
// rounding mode has no effect for bf16->f32.
def : Pat<(i32 (any_fp_to_sint (bf16 FPR16:$rs1))), (FCVT_W_S (FCVT_S_BF16 $rs1, FRM_RNE), FRM_RTZ)>;
def : Pat<(i32 (any_fp_to_uint (bf16 FPR16:$rs1))), (FCVT_WU_S (FCVT_S_BF16 $rs1, FRM_RNE), FRM_RTZ)>;

// [u]int->bf16. Match GCC and default to using dynamic rounding mode.
def : Pat<(bf16 (any_sint_to_fp (i32 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_W $rs1, FRM_DYN), FRM_DYN)>;
def : Pat<(bf16 (any_uint_to_fp (i32 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_WU $rs1, FRM_DYN), FRM_DYN)>;
}

let Predicates = [HasStdExtZfbfmin, IsRV64] in {
// bf16->[u]int64. Round-to-zero must be used for the f32->int step, the
// rounding mode has no effect for bf16->f32.
def : Pat<(i64 (any_fp_to_sint (bf16 FPR16:$rs1))), (FCVT_L_S (FCVT_S_BF16 $rs1, FRM_RNE), FRM_RTZ)>;
def : Pat<(i64 (any_fp_to_uint (bf16 FPR16:$rs1))), (FCVT_LU_S (FCVT_S_BF16 $rs1, FRM_RNE), FRM_RTZ)>;

// [u]int->bf16. Match GCC and default to using dynamic rounding mode.
def : Pat<(bf16 (any_sint_to_fp (i64 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_L $rs1, FRM_DYN), FRM_DYN)>;
def : Pat<(bf16 (any_uint_to_fp (i64 GPR:$rs1))), (FCVT_BF16_S (FCVT_S_LU $rs1, FRM_DYN), FRM_DYN)>;
}

let Predicates = [HasStdExtZfbfmin, HasStdExtD] in {
Expand Down
16 changes: 0 additions & 16 deletions llvm/lib/Target/RISCV/RISCVInstrInfoZfh.td
Original file line number Diff line number Diff line change
Expand Up @@ -604,38 +604,22 @@ let Predicates = [HasStdExtZfhmin, NoStdExtZfh] in {
// half->[u]int. Round-to-zero must be used.
def : Pat<(i32 (any_fp_to_sint (f16 FPR16:$rs1))), (FCVT_W_S (FCVT_S_H $rs1, FRM_RNE), FRM_RTZ)>;
def : Pat<(i32 (any_fp_to_uint (f16 FPR16:$rs1))), (FCVT_WU_S (FCVT_S_H $rs1, FRM_RNE), FRM_RTZ)>;

// [u]int->half. Match GCC and default to using dynamic rounding mode.
def : Pat<(f16 (any_sint_to_fp (i32 GPR:$rs1))), (FCVT_H_S (FCVT_S_W $rs1, FRM_DYN), FRM_DYN)>;
def : Pat<(f16 (any_uint_to_fp (i32 GPR:$rs1))), (FCVT_H_S (FCVT_S_WU $rs1, FRM_DYN), FRM_DYN)>;
} // Predicates = [HasStdExtZfhmin, NoStdExtZfh]

let Predicates = [HasStdExtZhinxmin, NoStdExtZhinx] in {
// half->[u]int. Round-to-zero must be used.
def : Pat<(i32 (any_fp_to_sint FPR16INX:$rs1)), (FCVT_W_S_INX (FCVT_S_H_INX $rs1, FRM_RNE), FRM_RTZ)>;
def : Pat<(i32 (any_fp_to_uint FPR16INX:$rs1)), (FCVT_WU_S_INX (FCVT_S_H_INX $rs1, FRM_RNE), FRM_RTZ)>;

// [u]int->half. Match GCC and default to using dynamic rounding mode.
def : Pat<(any_sint_to_fp (i32 GPR:$rs1)), (FCVT_H_S_INX (FCVT_S_W_INX $rs1, FRM_DYN), FRM_DYN)>;
def : Pat<(any_uint_to_fp (i32 GPR:$rs1)), (FCVT_H_S_INX (FCVT_S_WU_INX $rs1, FRM_DYN), FRM_DYN)>;
} // Predicates = [HasStdExtZhinxmin, NoStdExtZhinx]

let Predicates = [HasStdExtZfhmin, NoStdExtZfh, IsRV64] in {
// half->[u]int64. Round-to-zero must be used.
def : Pat<(i64 (any_fp_to_sint (f16 FPR16:$rs1))), (FCVT_L_S (FCVT_S_H $rs1, FRM_RNE), FRM_RTZ)>;
def : Pat<(i64 (any_fp_to_uint (f16 FPR16:$rs1))), (FCVT_LU_S (FCVT_S_H $rs1, FRM_RNE), FRM_RTZ)>;

// [u]int->fp. Match GCC and default to using dynamic rounding mode.
def : Pat<(f16 (any_sint_to_fp (i64 GPR:$rs1))), (FCVT_H_S (FCVT_S_L $rs1, FRM_DYN), FRM_DYN)>;
def : Pat<(f16 (any_uint_to_fp (i64 GPR:$rs1))), (FCVT_H_S (FCVT_S_LU $rs1, FRM_DYN), FRM_DYN)>;
} // Predicates = [HasStdExtZfhmin, NoStdExtZfh, IsRV64]

let Predicates = [HasStdExtZhinxmin, NoStdExtZhinx, IsRV64] in {
// half->[u]int64. Round-to-zero must be used.
def : Pat<(i64 (any_fp_to_sint FPR16INX:$rs1)), (FCVT_L_S_INX (FCVT_S_H_INX $rs1, FRM_RNE), FRM_RTZ)>;
def : Pat<(i64 (any_fp_to_uint FPR16INX:$rs1)), (FCVT_LU_S_INX (FCVT_S_H_INX $rs1, FRM_RNE), FRM_RTZ)>;

// [u]int->fp. Match GCC and default to using dynamic rounding mode.
def : Pat<(any_sint_to_fp (i64 GPR:$rs1)), (FCVT_H_S_INX (FCVT_S_L_INX $rs1, FRM_DYN), FRM_DYN)>;
def : Pat<(any_uint_to_fp (i64 GPR:$rs1)), (FCVT_H_S_INX (FCVT_S_LU_INX $rs1, FRM_DYN), FRM_DYN)>;
} // Predicates = [HasStdExtZhinxmin, NoStdExtZhinx, IsRV64]
25 changes: 10 additions & 15 deletions llvm/test/CodeGen/RISCV/bfloat-convert.ll
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ define bfloat @fcvt_bf16_si(i16 %a) nounwind {
; CHECK64ZFBFMIN: # %bb.0:
; CHECK64ZFBFMIN-NEXT: slli a0, a0, 48
; CHECK64ZFBFMIN-NEXT: srai a0, a0, 48
; CHECK64ZFBFMIN-NEXT: fcvt.s.l fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.s.w fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
; CHECK64ZFBFMIN-NEXT: ret
;
Expand Down Expand Up @@ -795,7 +795,7 @@ define bfloat @fcvt_bf16_si_signext(i16 signext %a) nounwind {
;
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_si_signext:
; CHECK64ZFBFMIN: # %bb.0:
; CHECK64ZFBFMIN-NEXT: fcvt.s.l fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.s.w fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
; CHECK64ZFBFMIN-NEXT: ret
;
Expand Down Expand Up @@ -845,7 +845,7 @@ define bfloat @fcvt_bf16_ui(i16 %a) nounwind {
; CHECK64ZFBFMIN: # %bb.0:
; CHECK64ZFBFMIN-NEXT: slli a0, a0, 48
; CHECK64ZFBFMIN-NEXT: srli a0, a0, 48
; CHECK64ZFBFMIN-NEXT: fcvt.s.lu fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.s.wu fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
; CHECK64ZFBFMIN-NEXT: ret
;
Expand Down Expand Up @@ -891,7 +891,7 @@ define bfloat @fcvt_bf16_ui_zeroext(i16 zeroext %a) nounwind {
;
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_ui_zeroext:
; CHECK64ZFBFMIN: # %bb.0:
; CHECK64ZFBFMIN-NEXT: fcvt.s.lu fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.s.wu fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
; CHECK64ZFBFMIN-NEXT: ret
;
Expand Down Expand Up @@ -935,8 +935,7 @@ define bfloat @fcvt_bf16_w(i32 %a) nounwind {
;
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_w:
; CHECK64ZFBFMIN: # %bb.0:
; CHECK64ZFBFMIN-NEXT: sext.w a0, a0
; CHECK64ZFBFMIN-NEXT: fcvt.s.l fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.s.w fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
; CHECK64ZFBFMIN-NEXT: ret
;
Expand Down Expand Up @@ -983,7 +982,7 @@ define bfloat @fcvt_bf16_w_load(ptr %p) nounwind {
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_w_load:
; CHECK64ZFBFMIN: # %bb.0:
; CHECK64ZFBFMIN-NEXT: lw a0, 0(a0)
; CHECK64ZFBFMIN-NEXT: fcvt.s.l fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.s.w fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
; CHECK64ZFBFMIN-NEXT: ret
;
Expand Down Expand Up @@ -1029,9 +1028,7 @@ define bfloat @fcvt_bf16_wu(i32 %a) nounwind {
;
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_wu:
; CHECK64ZFBFMIN: # %bb.0:
; CHECK64ZFBFMIN-NEXT: slli a0, a0, 32
; CHECK64ZFBFMIN-NEXT: srli a0, a0, 32
; CHECK64ZFBFMIN-NEXT: fcvt.s.lu fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.s.wu fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
; CHECK64ZFBFMIN-NEXT: ret
;
Expand Down Expand Up @@ -1078,7 +1075,7 @@ define bfloat @fcvt_bf16_wu_load(ptr %p) nounwind {
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_wu_load:
; CHECK64ZFBFMIN: # %bb.0:
; CHECK64ZFBFMIN-NEXT: lwu a0, 0(a0)
; CHECK64ZFBFMIN-NEXT: fcvt.s.lu fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.s.wu fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa0, fa5
; CHECK64ZFBFMIN-NEXT: ret
;
Expand Down Expand Up @@ -1376,7 +1373,7 @@ define signext i32 @fcvt_bf16_w_demanded_bits(i32 signext %0, ptr %1) nounwind {
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_w_demanded_bits:
; CHECK64ZFBFMIN: # %bb.0:
; CHECK64ZFBFMIN-NEXT: addiw a0, a0, 1
; CHECK64ZFBFMIN-NEXT: fcvt.s.l fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.s.w fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa5, fa5
; CHECK64ZFBFMIN-NEXT: fsh fa5, 0(a1)
; CHECK64ZFBFMIN-NEXT: ret
Expand Down Expand Up @@ -1436,9 +1433,7 @@ define signext i32 @fcvt_bf16_wu_demanded_bits(i32 signext %0, ptr %1) nounwind
; CHECK64ZFBFMIN-LABEL: fcvt_bf16_wu_demanded_bits:
; CHECK64ZFBFMIN: # %bb.0:
; CHECK64ZFBFMIN-NEXT: addiw a0, a0, 1
; CHECK64ZFBFMIN-NEXT: slli a2, a0, 32
; CHECK64ZFBFMIN-NEXT: srli a2, a2, 32
; CHECK64ZFBFMIN-NEXT: fcvt.s.lu fa5, a2
; CHECK64ZFBFMIN-NEXT: fcvt.s.wu fa5, a0
; CHECK64ZFBFMIN-NEXT: fcvt.bf16.s fa5, fa5
; CHECK64ZFBFMIN-NEXT: fsh fa5, 0(a1)
; CHECK64ZFBFMIN-NEXT: ret
Expand Down
Loading
Loading