Skip to content

Commit a5ce664

Browse files
committed
[RISCV] Remove RISCVISD::FP_ROUND_BF16.
Use isel patterns on regular FP_ROUND. For double->bf16 we need to emit two instructions. Note the double->bf16 conversion does double rounding, but I don't know a good way to fix that.
1 parent 8662714 commit a5ce664

File tree

3 files changed

+7
-38
lines changed

3 files changed

+7
-38
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,6 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
447447
if (Subtarget.hasStdExtZfbfmin()) {
448448
setOperationAction(ISD::BITCAST, MVT::i16, Custom);
449449
setOperationAction(ISD::BITCAST, MVT::bf16, Custom);
450-
setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
451450
setOperationAction(ISD::ConstantFP, MVT::bf16, Expand);
452451
setOperationAction(ISD::SELECT_CC, MVT::bf16, Expand);
453452
setOperationAction(ISD::SELECT, MVT::bf16, Custom);
@@ -6631,30 +6630,9 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
66316630
!Subtarget.hasVInstructionsF16()))
66326631
return SplitVectorOp(Op, DAG);
66336632
return lowerFMAXIMUM_FMINIMUM(Op, DAG, Subtarget);
6634-
case ISD::FP_EXTEND: {
6635-
if (!Op.getValueType().isVector())
6636-
return Op;
6633+
case ISD::FP_EXTEND:
6634+
case ISD::FP_ROUND:
66376635
return lowerVectorFPExtendOrRoundLike(Op, DAG);
6638-
}
6639-
case ISD::FP_ROUND: {
6640-
SDLoc DL(Op);
6641-
EVT VT = Op.getValueType();
6642-
SDValue Op0 = Op.getOperand(0);
6643-
EVT Op0VT = Op0.getValueType();
6644-
if (VT == MVT::bf16 && Op0VT == MVT::f32 && Subtarget.hasStdExtZfbfmin())
6645-
return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, Op0);
6646-
if (VT == MVT::bf16 && Op0VT == MVT::f64 && Subtarget.hasStdExtZfbfmin() &&
6647-
Subtarget.hasStdExtDOrZdinx()) {
6648-
SDValue FloatVal =
6649-
DAG.getNode(ISD::FP_ROUND, DL, MVT::f32, Op0,
6650-
DAG.getIntPtrConstant(0, DL, /*isTarget=*/true));
6651-
return DAG.getNode(RISCVISD::FP_ROUND_BF16, DL, MVT::bf16, FloatVal);
6652-
}
6653-
6654-
if (!Op.getValueType().isVector())
6655-
return Op;
6656-
return lowerVectorFPExtendOrRoundLike(Op, DAG);
6657-
}
66586636
case ISD::STRICT_FP_ROUND:
66596637
case ISD::STRICT_FP_EXTEND:
66606638
return lowerStrictFPExtendOrRoundLike(Op, DAG);
@@ -20588,7 +20566,6 @@ const char *RISCVTargetLowering::getTargetNodeName(unsigned Opcode) const {
2058820566
NODE_NAME_CASE(FCVT_WU_RV64)
2058920567
NODE_NAME_CASE(STRICT_FCVT_W_RV64)
2059020568
NODE_NAME_CASE(STRICT_FCVT_WU_RV64)
20591-
NODE_NAME_CASE(FP_ROUND_BF16)
2059220569
NODE_NAME_CASE(FROUND)
2059320570
NODE_NAME_CASE(FCLASS)
2059420571
NODE_NAME_CASE(FSGNJX)

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,6 @@ enum NodeType : unsigned {
116116
FCVT_W_RV64,
117117
FCVT_WU_RV64,
118118

119-
FP_ROUND_BF16,
120-
121119
// Rounds an FP value to its corresponding integer in the same FP format.
122120
// First operand is the value to round, the second operand is the largest
123121
// integer that can be represented exactly in the FP format. This will be

llvm/lib/Target/RISCV/RISCVInstrInfoZfbfmin.td

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,6 @@
1313
//
1414
//===----------------------------------------------------------------------===//
1515

16-
//===----------------------------------------------------------------------===//
17-
// RISC-V specific DAG Nodes.
18-
//===----------------------------------------------------------------------===//
19-
20-
def SDT_RISCVFP_ROUND_BF16
21-
: SDTypeProfile<1, 1, [SDTCisVT<0, bf16>, SDTCisVT<1, f32>]>;
22-
23-
def riscv_fpround_bf16
24-
: SDNode<"RISCVISD::FP_ROUND_BF16", SDT_RISCVFP_ROUND_BF16>;
25-
2616
//===----------------------------------------------------------------------===//
2717
// Instructions
2818
//===----------------------------------------------------------------------===//
@@ -60,7 +50,7 @@ def : StPat<store, FSH, FPR16, bf16>;
6050

6151
/// Float conversion operations
6252
// f32 -> bf16, bf16 -> f32
63-
def : Pat<(bf16 (riscv_fpround_bf16 FPR32:$rs1)),
53+
def : Pat<(bf16 (fpround FPR32:$rs1)),
6454
(FCVT_BF16_S FPR32:$rs1, FRM_DYN)>;
6555
def : Pat<(fpextend (bf16 FPR16:$rs1)),
6656
(FCVT_S_BF16 FPR16:$rs1, FRM_RNE)>;
@@ -72,6 +62,10 @@ def : Pat<(riscv_fmv_x_signexth (bf16 FPR16:$src)), (FMV_X_H FPR16:$src)>;
7262
} // Predicates = [HasStdExtZfbfmin]
7363

7464
let Predicates = [HasStdExtZfbfmin, HasStdExtD] in {
65+
// f64 -> bf16
66+
// FIXME: This pattern double rounds.
67+
def : Pat<(bf16 (fpround FPR64:$rs1)),
68+
(FCVT_BF16_S (FCVT_S_D FPR64:$rs1, FRM_DYN), FRM_DYN)>;
7569
// bf16 -> f64
7670
def : Pat<(fpextend (bf16 FPR16:$rs1)),
7771
(FCVT_D_S (FCVT_S_BF16 FPR16:$rs1, FRM_DYN), FRM_RNE)>;

0 commit comments

Comments
 (0)