Skip to content

[SelectionDAG] Add STRICT_BF16_TO_FP and STRICT_FP_TO_BF16 #80056

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions compiler-rt/lib/builtins/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ set(GENERIC_SOURCES

# We only build BF16 files when "__bf16" is available.
set(BF16_SOURCES
extendbfsf2.c
truncdfbf2.c
truncsfbf2.c
)
Expand Down
13 changes: 13 additions & 0 deletions compiler-rt/lib/builtins/extendbfsf2.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
//===-- lib/extendbfsf2.c - bfloat -> single conversion -----------*- C -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#define SRC_BFLOAT
#define DST_SINGLE
#include "fp_extend_impl.inc"

// Converts a bfloat value to single precision. The work is delegated to the
// generic __extendXfYf2__ routine, configured by the SRC_BFLOAT / DST_SINGLE
// macros defined before fp_extend_impl.inc is included.
COMPILER_RT_ABI float __extendbfsf2(src_t a) {
  const float widened = __extendXfYf2__(a);
  return widened;
}
7 changes: 7 additions & 0 deletions compiler-rt/lib/builtins/fp_extend.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,13 @@ static inline int src_rep_t_clz_impl(src_rep_t a) {

#define src_rep_t_clz src_rep_t_clz_impl

#elif defined SRC_BFLOAT
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The changes in compiler-rt were from https://reviews.llvm.org/D151436.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure whether this implementation handles floating-point exceptions correctly, but I don't have a concrete suggestion for it.

// Source-type configuration used when SRC_BFLOAT is defined: the value type is
// __bf16 and the representation type is a 16-bit unsigned integer.
typedef __bf16 src_t;
typedef uint16_t src_rep_t;
// Macro for building constants of the representation type's width.
#define SRC_REP_C UINT16_C
// Significand width of the source format.
// NOTE(review): presumably counts only the explicitly stored bits — confirm
// against how the other source precisions define srcSigBits.
static const int srcSigBits = 7;
// NOTE(review): __builtin_clz operates on unsigned int, so a uint16_t argument
// is promoted and the result also counts the additional high-order zero bits.
// The preceding branch routes src_rep_t_clz through src_rep_t_clz_impl
// instead; confirm that users of this macro are insensitive to that offset.
#define src_rep_t_clz __builtin_clz

#else
#error Source should be half, single, or double precision!
#endif // end source precision
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/ISDOpcodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -921,6 +921,8 @@ enum NodeType {
/// has native conversions.
BF16_TO_FP,
FP_TO_BF16,
STRICT_BF16_TO_FP,
STRICT_FP_TO_BF16,

/// Perform various unary floating-point operations inspired by libm. For
/// FPOWI, the result is undefined if the integer operand doesn't fit into
Expand Down
2 changes: 2 additions & 0 deletions llvm/include/llvm/CodeGen/SelectionDAGNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,8 @@ END_TWO_BYTE_PACK()
return false;
case ISD::STRICT_FP16_TO_FP:
case ISD::STRICT_FP_TO_FP16:
case ISD::STRICT_BF16_TO_FP:
case ISD::STRICT_FP_TO_BF16:
#define DAG_INSTRUCTION(NAME, NARG, ROUND_MODE, INTRINSIC, DAGN) \
case ISD::STRICT_##DAGN:
#include "llvm/IR/ConstrainedOps.def"
Expand Down
1 change: 1 addition & 0 deletions llvm/include/llvm/IR/RuntimeLibcalls.def
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ HANDLE_LIBCALL(FEGETMODE, "fegetmode")
HANDLE_LIBCALL(FESETMODE, "fesetmode")

// Conversion
HANDLE_LIBCALL(FPEXT_BF16_F32, "__extendbfsf2")
HANDLE_LIBCALL(FPEXT_F32_PPCF128, "__gcc_stoq")
HANDLE_LIBCALL(FPEXT_F64_PPCF128, "__gcc_dtoq")
HANDLE_LIBCALL(FPEXT_F80_F128, "__extendxftf2")
Expand Down
13 changes: 13 additions & 0 deletions llvm/include/llvm/Target/TargetSelectionDAG.td
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,8 @@ def fp_to_sint_sat : SDNode<"ISD::FP_TO_SINT_SAT" , SDTFPToIntSatOp>;
def fp_to_uint_sat : SDNode<"ISD::FP_TO_UINT_SAT" , SDTFPToIntSatOp>;
def f16_to_fp : SDNode<"ISD::FP16_TO_FP" , SDTIntToFPOp>;
def fp_to_f16 : SDNode<"ISD::FP_TO_FP16" , SDTFPToIntOp>;
// bf16 conversion nodes. They bind ISD::BF16_TO_FP / ISD::FP_TO_BF16 and use
// the same type profiles (SDTIntToFPOp / SDTFPToIntOp) as the f16_to_fp /
// fp_to_f16 nodes.
def bf16_to_fp : SDNode<"ISD::BF16_TO_FP" , SDTIntToFPOp>;
def fp_to_bf16 : SDNode<"ISD::FP_TO_BF16" , SDTFPToIntOp>;

def strict_fadd : SDNode<"ISD::STRICT_FADD",
SDTFPBinOp, [SDNPHasChain, SDNPCommutative]>;
Expand Down Expand Up @@ -620,6 +622,11 @@ def strict_f16_to_fp : SDNode<"ISD::STRICT_FP16_TO_FP",
def strict_fp_to_f16 : SDNode<"ISD::STRICT_FP_TO_FP16",
SDTFPToIntOp, [SDNPHasChain]>;

// Strict bf16 conversion nodes. They bind the ISD::STRICT_BF16_TO_FP /
// ISD::STRICT_FP_TO_BF16 opcodes and carry a chain (SDNPHasChain).
def strict_bf16_to_fp : SDNode<"ISD::STRICT_BF16_TO_FP",
SDTIntToFPOp, [SDNPHasChain]>;
def strict_fp_to_bf16 : SDNode<"ISD::STRICT_FP_TO_BF16",
SDTFPToIntOp, [SDNPHasChain]>;

def strict_fsetcc : SDNode<"ISD::STRICT_FSETCC", SDTSetCC, [SDNPHasChain]>;
def strict_fsetccs : SDNode<"ISD::STRICT_FSETCCS", SDTSetCC, [SDNPHasChain]>;

Expand Down Expand Up @@ -1591,6 +1598,12 @@ def any_f16_to_fp : PatFrags<(ops node:$src),
def any_fp_to_f16 : PatFrags<(ops node:$src),
[(fp_to_f16 node:$src),
(strict_fp_to_f16 node:$src)]>;
// Pattern fragments that match either the non-strict or the strict form of a
// bf16 conversion, so a single pattern can cover both nodes.
def any_bf16_to_fp : PatFrags<(ops node:$src),
[(bf16_to_fp node:$src),
(strict_bf16_to_fp node:$src)]>;
def any_fp_to_bf16 : PatFrags<(ops node:$src),
[(fp_to_bf16 node:$src),
(strict_fp_to_bf16 node:$src)]>;

multiclass binary_atomic_op_ord {
def NAME#_monotonic : PatFrag<(ops node:$ptr, node:$val),
Expand Down
34 changes: 25 additions & 9 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,7 @@ void SelectionDAGLegalize::LegalizeOp(SDNode *Node) {
Node->getOperand(0).getValueType());
break;
case ISD::STRICT_FP_TO_FP16:
case ISD::STRICT_FP_TO_BF16:
case ISD::STRICT_SINT_TO_FP:
case ISD::STRICT_UINT_TO_FP:
case ISD::STRICT_LRINT:
Expand Down Expand Up @@ -3645,14 +3646,14 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
DAG.getNode(ISD::FP_EXTEND, dl, Node->getValueType(0), Res));
}
break;
case ISD::STRICT_BF16_TO_FP:
case ISD::STRICT_FP16_TO_FP:
if (Node->getValueType(0) != MVT::f32) {
// We can extend to types bigger than f32 in two steps without changing
// the result. Since "f16 -> f32" is much more commonly available, give
// CodeGen the option of emitting that before resorting to a libcall.
SDValue Res =
DAG.getNode(ISD::STRICT_FP16_TO_FP, dl, {MVT::f32, MVT::Other},
{Node->getOperand(0), Node->getOperand(1)});
SDValue Res = DAG.getNode(Node->getOpcode(), dl, {MVT::f32, MVT::Other},
{Node->getOperand(0), Node->getOperand(1)});
Res = DAG.getNode(ISD::STRICT_FP_EXTEND, dl,
{Node->getValueType(0), MVT::Other},
{Res.getValue(1), Res});
Expand Down Expand Up @@ -4651,6 +4652,16 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
Results.push_back(ExpandLibCall(RTLIB::FPEXT_F16_F32, Node, false).first);
}
break;
case ISD::STRICT_BF16_TO_FP:
if (Node->getValueType(0) == MVT::f32) {
TargetLowering::MakeLibCallOptions CallOptions;
std::pair<SDValue, SDValue> Tmp = TLI.makeLibCall(
DAG, RTLIB::FPEXT_BF16_F32, MVT::f32, Node->getOperand(1),
CallOptions, SDLoc(Node), Node->getOperand(0));
Results.push_back(Tmp.first);
Results.push_back(Tmp.second);
}
break;
case ISD::STRICT_FP16_TO_FP: {
if (Node->getValueType(0) == MVT::f32) {
TargetLowering::MakeLibCallOptions CallOptions;
Expand Down Expand Up @@ -4792,12 +4803,17 @@ void SelectionDAGLegalize::ConvertNodeToLibcall(SDNode *Node) {
break;
}
case ISD::STRICT_FP_EXTEND:
case ISD::STRICT_FP_TO_FP16: {
RTLIB::Libcall LC =
Node->getOpcode() == ISD::STRICT_FP_TO_FP16
? RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16)
: RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
Node->getValueType(0));
case ISD::STRICT_FP_TO_FP16:
case ISD::STRICT_FP_TO_BF16: {
RTLIB::Libcall LC = RTLIB::UNKNOWN_LIBCALL;
if (Node->getOpcode() == ISD::STRICT_FP_TO_FP16)
LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::f16);
else if (Node->getOpcode() == ISD::STRICT_FP_TO_BF16)
LC = RTLIB::getFPROUND(Node->getOperand(1).getValueType(), MVT::bf16);
else
LC = RTLIB::getFPEXT(Node->getOperand(1).getValueType(),
Node->getValueType(0));

assert(LC != RTLIB::UNKNOWN_LIBCALL && "Unable to legalize as libcall");

TargetLowering::MakeLibCallOptions CallOptions;
Expand Down
41 changes: 27 additions & 14 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeFloatTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,6 +918,7 @@ bool DAGTypeLegalizer::SoftenFloatOperand(SDNode *N, unsigned OpNo) {
case ISD::STRICT_FP_TO_FP16:
case ISD::FP_TO_FP16: // Same as FP_ROUND for softening purposes
case ISD::FP_TO_BF16:
case ISD::STRICT_FP_TO_BF16:
case ISD::STRICT_FP_ROUND:
case ISD::FP_ROUND: Res = SoftenFloatOp_FP_ROUND(N); break;
case ISD::STRICT_FP_TO_SINT:
Expand Down Expand Up @@ -970,6 +971,7 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
assert(N->getOpcode() == ISD::FP_ROUND || N->getOpcode() == ISD::FP_TO_FP16 ||
N->getOpcode() == ISD::STRICT_FP_TO_FP16 ||
N->getOpcode() == ISD::FP_TO_BF16 ||
N->getOpcode() == ISD::STRICT_FP_TO_BF16 ||
N->getOpcode() == ISD::STRICT_FP_ROUND);

bool IsStrict = N->isStrictFPOpcode();
Expand All @@ -980,7 +982,8 @@ SDValue DAGTypeLegalizer::SoftenFloatOp_FP_ROUND(SDNode *N) {
if (N->getOpcode() == ISD::FP_TO_FP16 ||
N->getOpcode() == ISD::STRICT_FP_TO_FP16)
FloatRVT = MVT::f16;
else if (N->getOpcode() == ISD::FP_TO_BF16)
else if (N->getOpcode() == ISD::FP_TO_BF16 ||
N->getOpcode() == ISD::STRICT_FP_TO_BF16)
FloatRVT = MVT::bf16;

RTLIB::Libcall LC = RTLIB::getFPROUND(SVT, FloatRVT);
Expand Down Expand Up @@ -2193,13 +2196,11 @@ static ISD::NodeType GetPromotionOpcodeStrict(EVT OpVT, EVT RetVT) {
if (RetVT == MVT::f16)
return ISD::STRICT_FP_TO_FP16;

if (OpVT == MVT::bf16) {
// TODO: return ISD::STRICT_BF16_TO_FP;
}
if (OpVT == MVT::bf16)
return ISD::STRICT_BF16_TO_FP;

if (RetVT == MVT::bf16) {
// TODO: return ISD::STRICT_FP_TO_BF16;
}
if (RetVT == MVT::bf16)
return ISD::STRICT_FP_TO_BF16;

report_fatal_error("Attempt at an invalid promotion-related conversion");
}
Expand Down Expand Up @@ -2999,10 +3000,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfRes_FP_ROUND(SDNode *N) {
EVT SVT = N->getOperand(0).getValueType();

if (N->isStrictFPOpcode()) {
assert(RVT == MVT::f16);
SDValue Res =
DAG.getNode(ISD::STRICT_FP_TO_FP16, SDLoc(N), {MVT::i16, MVT::Other},
{N->getOperand(0), N->getOperand(1)});
// FIXME: assume we only have two f16 variants for now.
unsigned Opcode;
if (RVT == MVT::f16)
Opcode = ISD::STRICT_FP_TO_FP16;
else if (RVT == MVT::bf16)
Opcode = ISD::STRICT_FP_TO_BF16;
else
llvm_unreachable("unknown half type");
SDValue Res = DAG.getNode(Opcode, SDLoc(N), {MVT::i16, MVT::Other},
{N->getOperand(0), N->getOperand(1)});
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
return Res;
}
Expand Down Expand Up @@ -3192,10 +3199,16 @@ SDValue DAGTypeLegalizer::SoftPromoteHalfOp_FP_EXTEND(SDNode *N) {
Op = GetSoftPromotedHalf(N->getOperand(IsStrict ? 1 : 0));

if (IsStrict) {
assert(SVT == MVT::f16);
unsigned Opcode;
if (SVT == MVT::f16)
Opcode = ISD::STRICT_FP16_TO_FP;
else if (SVT == MVT::bf16)
Opcode = ISD::STRICT_BF16_TO_FP;
else
llvm_unreachable("unknown half type");
SDValue Res =
DAG.getNode(ISD::STRICT_FP16_TO_FP, SDLoc(N),
{N->getValueType(0), MVT::Other}, {N->getOperand(0), Op});
DAG.getNode(Opcode, SDLoc(N), {N->getValueType(0), MVT::Other},
{N->getOperand(0), Op});
ReplaceValueWith(SDValue(N, 1), Res.getValue(1));
ReplaceValueWith(SDValue(N, 0), Res);
return SDValue();
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::FP_TO_FP16:
Res = PromoteIntRes_FP_TO_FP16_BF16(N);
break;
case ISD::STRICT_FP_TO_BF16:
case ISD::STRICT_FP_TO_FP16:
Res = PromoteIntRes_STRICT_FP_TO_FP16_BF16(N);
break;
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,9 @@ std::string SDNode::getOperationName(const SelectionDAG *G) const {
case ISD::FP_TO_FP16: return "fp_to_fp16";
case ISD::STRICT_FP_TO_FP16: return "strict_fp_to_fp16";
case ISD::BF16_TO_FP: return "bf16_to_fp";
case ISD::STRICT_BF16_TO_FP: return "strict_bf16_to_fp";
case ISD::FP_TO_BF16: return "fp_to_bf16";
case ISD::STRICT_FP_TO_BF16: return "strict_fp_to_bf16";
case ISD::LROUND: return "lround";
case ISD::STRICT_LROUND: return "strict_lround";
case ISD::LLROUND: return "llround";
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/CodeGen/TargetLoweringBase.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,9 @@ RTLIB::Libcall RTLIB::getFPEXT(EVT OpVT, EVT RetVT) {
} else if (OpVT == MVT::f80) {
if (RetVT == MVT::f128)
return FPEXT_F80_F128;
} else if (OpVT == MVT::bf16) {
if (RetVT == MVT::f32)
return FPEXT_BF16_F32;
}

return UNKNOWN_LIBCALL;
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,11 @@ X86TargetLowering::X86TargetLowering(const X86TargetMachine &TM,
setOperationAction(Op, MVT::f128, Expand);
}

for (auto VT : {MVT::f32, MVT::f64, MVT::f80, MVT::f128}) {
setOperationAction(ISD::STRICT_FP_TO_BF16, VT, Expand);
setOperationAction(ISD::STRICT_BF16_TO_FP, VT, Expand);
}

for (MVT VT : {MVT::f32, MVT::f64, MVT::f80, MVT::f128}) {
setLoadExtAction(ISD::EXTLOAD, VT, MVT::f16, Expand);
setLoadExtAction(ISD::EXTLOAD, VT, MVT::bf16, Expand);
Expand Down
Loading