Skip to content

Commit cc13f3b

Browse files
authored
Correctly round FP -> BF16 when SDAG expands such nodes (llvm#82399)
We did something pretty naive: - round FP64 -> BF16 by first rounding to FP32 - skip FP32 -> BF16 rounding entirely - taking the top 16 bits of a FP32 which will turn some NaNs into infinities Let's do this in a more principled way by rounding types with more precision than FP32 to FP32 using round-inexact-to-odd which will negate double rounding issues.
1 parent cc374d8 commit cc13f3b

16 files changed

+17199
-4440
lines changed

llvm/include/llvm/CodeGen/TargetLowering.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5124,6 +5124,19 @@ class TargetLowering : public TargetLoweringBase {
51245124
/// \returns The expansion result
51255125
SDValue expandFP_TO_INT_SAT(SDNode *N, SelectionDAG &DAG) const;
51265126

5127+
/// Truncate Op to ResultVT. If the result is exact, leave it alone. If it is
5128+
/// not exact, force the result to be odd.
5129+
/// \param ResultVT The type of result.
5130+
/// \param Op The value to round.
5131+
/// \returns The expansion result
5132+
SDValue expandRoundInexactToOdd(EVT ResultVT, SDValue Op, const SDLoc &DL,
5133+
SelectionDAG &DAG) const;
5134+
5135+
/// Expand round(fp) to fp conversion
5136+
/// \param N Node to expand
5137+
/// \returns The expansion result
5138+
SDValue expandFP_ROUND(SDNode *Node, SelectionDAG &DAG) const;
5139+
51275140
/// Expand check for floating point class.
51285141
/// \param ResultVT The type of intrinsic call result.
51295142
/// \param Op The tested value.

llvm/lib/CodeGen/SelectionDAG/LegalizeDAG.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3217,10 +3217,8 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
32173217
}
32183218
break;
32193219
case ISD::FP_ROUND: {
3220-
EVT VT = Node->getValueType(0);
3221-
if (VT.getScalarType() == MVT::bf16) {
3222-
Results.push_back(
3223-
DAG.getNode(ISD::FP_TO_BF16, SDLoc(Node), VT, Node->getOperand(0)));
3220+
if ((Tmp1 = TLI.expandFP_ROUND(Node, DAG))) {
3221+
Results.push_back(Tmp1);
32243222
break;
32253223
}
32263224

@@ -3293,6 +3291,10 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) {
32933291
if (Op.getValueType() != MVT::f32)
32943292
Op = DAG.getNode(ISD::FP_ROUND, dl, MVT::f32, Op,
32953293
DAG.getIntPtrConstant(0, dl, /*isTarget=*/true));
3294+
// Certain SNaNs will turn into infinities if we do a simple shift right.
3295+
if (!DAG.isKnownNeverSNaN(Op)) {
3296+
Op = DAG.getNode(ISD::FCANONICALIZE, dl, MVT::f32, Op, Node->getFlags());
3297+
}
32963298
Op = DAG.getNode(
32973299
ISD::SRL, dl, MVT::i32, DAG.getNode(ISD::BITCAST, dl, MVT::i32, Op),
32983300
DAG.getConstant(16, dl,

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10855,6 +10855,128 @@ SDValue TargetLowering::expandFP_TO_INT_SAT(SDNode *Node,
1085510855
return DAG.getSelect(dl, DstVT, IsNan, ZeroInt, Select);
1085610856
}
1085710857

10858+
SDValue TargetLowering::expandRoundInexactToOdd(EVT ResultVT, SDValue Op,
10859+
const SDLoc &dl,
10860+
SelectionDAG &DAG) const {
10861+
EVT OperandVT = Op.getValueType();
10862+
if (OperandVT.getScalarType() == ResultVT.getScalarType())
10863+
return Op;
10864+
EVT ResultIntVT = ResultVT.changeTypeToInteger();
10865+
// We are rounding binary64/binary128 -> binary32 -> bfloat16. This
10866+
// can induce double-rounding which may alter the results. We can
10867+
// correct for this using a trick explained in: Boldo, Sylvie, and
10868+
// Guillaume Melquiond. "When double rounding is odd." 17th IMACS
10869+
// World Congress. 2005.
10870+
unsigned BitSize = OperandVT.getScalarSizeInBits();
10871+
EVT WideIntVT = OperandVT.changeTypeToInteger();
10872+
SDValue OpAsInt = DAG.getBitcast(WideIntVT, Op);
10873+
SDValue SignBit =
10874+
DAG.getNode(ISD::AND, dl, WideIntVT, OpAsInt,
10875+
DAG.getConstant(APInt::getSignMask(BitSize), dl, WideIntVT));
10876+
SDValue AbsWide;
10877+
if (isOperationLegalOrCustom(ISD::FABS, OperandVT)) {
10878+
AbsWide = DAG.getNode(ISD::FABS, dl, OperandVT, Op);
10879+
} else {
10880+
SDValue ClearedSign = DAG.getNode(
10881+
ISD::AND, dl, WideIntVT, OpAsInt,
10882+
DAG.getConstant(APInt::getSignedMaxValue(BitSize), dl, WideIntVT));
10883+
AbsWide = DAG.getBitcast(OperandVT, ClearedSign);
10884+
}
10885+
SDValue AbsNarrow = DAG.getFPExtendOrRound(AbsWide, dl, ResultVT);
10886+
SDValue AbsNarrowAsWide = DAG.getFPExtendOrRound(AbsNarrow, dl, OperandVT);
10887+
10888+
// We can keep the narrow value as-is if narrowing was exact (no
10889+
// rounding error), the wide value was NaN (the narrow value is also
10890+
// NaN and should be preserved) or if we rounded to the odd value.
10891+
SDValue NarrowBits = DAG.getNode(ISD::BITCAST, dl, ResultIntVT, AbsNarrow);
10892+
SDValue One = DAG.getConstant(1, dl, ResultIntVT);
10893+
SDValue NegativeOne = DAG.getAllOnesConstant(dl, ResultIntVT);
10894+
SDValue And = DAG.getNode(ISD::AND, dl, ResultIntVT, NarrowBits, One);
10895+
EVT ResultIntVTCCVT = getSetCCResultType(
10896+
DAG.getDataLayout(), *DAG.getContext(), And.getValueType());
10897+
SDValue Zero = DAG.getConstant(0, dl, ResultIntVT);
10898+
SDValue AlreadyOdd = DAG.getSetCC(dl, ResultIntVTCCVT, And, Zero, ISD::SETNE);
10899+
10900+
EVT WideSetCCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
10901+
AbsWide.getValueType());
10902+
SDValue KeepNarrow =
10903+
DAG.getSetCC(dl, WideSetCCVT, AbsWide, AbsNarrowAsWide, ISD::SETUEQ);
10904+
KeepNarrow = DAG.getNode(ISD::OR, dl, WideSetCCVT, KeepNarrow, AlreadyOdd);
10905+
// We morally performed a round-down if `abs_narrow` is smaller than
10906+
// `abs_wide`.
10907+
SDValue NarrowIsRd =
10908+
DAG.getSetCC(dl, WideSetCCVT, AbsWide, AbsNarrowAsWide, ISD::SETOGT);
10909+
// If the narrow value is odd or exact, pick it.
10910+
// Otherwise, narrow is even and corresponds to either the rounded-up
10911+
// or rounded-down value. If narrow is the rounded-down value, we want
10912+
// the rounded-up value as it will be odd.
10913+
SDValue Adjust = DAG.getSelect(dl, ResultIntVT, NarrowIsRd, One, NegativeOne);
10914+
Adjust = DAG.getSelect(dl, ResultIntVT, KeepNarrow, Zero, Adjust);
10915+
int ShiftAmount = BitSize - ResultVT.getScalarSizeInBits();
10916+
SDValue ShiftCnst = DAG.getShiftAmountConstant(ShiftAmount, WideIntVT, dl);
10917+
SignBit = DAG.getNode(ISD::SRL, dl, WideIntVT, SignBit, ShiftCnst);
10918+
SignBit = DAG.getNode(ISD::TRUNCATE, dl, ResultIntVT, SignBit);
10919+
Op = DAG.getNode(ISD::OR, dl, ResultIntVT, Adjust, SignBit);
10920+
return DAG.getNode(ISD::BITCAST, dl, ResultVT, Op);
10921+
}
10922+
10923+
SDValue TargetLowering::expandFP_ROUND(SDNode *Node, SelectionDAG &DAG) const {
10924+
assert(Node->getOpcode() == ISD::FP_ROUND && "Unexpected opcode!");
10925+
SDValue Op = Node->getOperand(0);
10926+
EVT VT = Node->getValueType(0);
10927+
SDLoc dl(Node);
10928+
if (VT.getScalarType() == MVT::bf16) {
10929+
if (Node->getConstantOperandVal(1) == 1) {
10930+
return DAG.getNode(ISD::FP_TO_BF16, dl, VT, Node->getOperand(0));
10931+
}
10932+
EVT OperandVT = Op.getValueType();
10933+
SDValue IsNaN = DAG.getSetCC(
10934+
dl,
10935+
getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), OperandVT),
10936+
Op, Op, ISD::SETUO);
10937+
10938+
// We are rounding binary64/binary128 -> binary32 -> bfloat16. This
10939+
// can induce double-rounding which may alter the results. We can
10940+
// correct for this using a trick explained in: Boldo, Sylvie, and
10941+
// Guillaume Melquiond. "When double rounding is odd." 17th IMACS
10942+
// World Congress. 2005.
10943+
EVT F32 = VT.isVector() ? VT.changeVectorElementType(MVT::f32) : MVT::f32;
10944+
EVT I32 = F32.changeTypeToInteger();
10945+
Op = expandRoundInexactToOdd(F32, Op, dl, DAG);
10946+
Op = DAG.getNode(ISD::BITCAST, dl, I32, Op);
10947+
10948+
// Extract the sign bit.
10949+
SDValue SignBit =
10950+
DAG.getNode(ISD::AND, dl, I32, Op,
10951+
DAG.getConstant(APInt::getSignMask(32), dl, I32));
10952+
// Set the quiet bit.
10953+
SDValue NaN = DAG.getNode(ISD::OR, dl, I32, SignBit,
10954+
DAG.getConstant(0x400000, dl, I32));
10955+
10956+
// Factor in the contribution of the low 16 bits.
10957+
SDValue One = DAG.getConstant(1, dl, I32);
10958+
SDValue Lsb = DAG.getNode(ISD::SRL, dl, I32, Op,
10959+
DAG.getShiftAmountConstant(16, I32, dl));
10960+
Lsb = DAG.getNode(ISD::AND, dl, I32, Lsb, One);
10961+
SDValue RoundingBias =
10962+
DAG.getNode(ISD::ADD, dl, I32, DAG.getConstant(0x7fff, dl, I32), Lsb);
10963+
SDValue Add = DAG.getNode(ISD::ADD, dl, I32, Op, RoundingBias);
10964+
10965+
// Don't round if we had a NaN, we don't want to turn 0x7fffffff into
10966+
// 0x80000000.
10967+
Op = DAG.getSelect(dl, I32, IsNaN, NaN, Add);
10968+
10969+
// Now that we have rounded, shift the bits into position.
10970+
Op = DAG.getNode(ISD::SRL, dl, I32, Op,
10971+
DAG.getShiftAmountConstant(16, I32, dl));
10972+
Op = DAG.getNode(ISD::BITCAST, dl, I32, Op);
10973+
EVT I16 = I32.isVector() ? I32.changeVectorElementType(MVT::i16) : MVT::i16;
10974+
Op = DAG.getNode(ISD::TRUNCATE, dl, I16, Op);
10975+
return DAG.getNode(ISD::BITCAST, dl, VT, Op);
10976+
}
10977+
return SDValue();
10978+
}
10979+
1085810980
SDValue TargetLowering::expandVectorSplice(SDNode *Node,
1085910981
SelectionDAG &DAG) const {
1086010982
assert(Node->getOpcode() == ISD::VECTOR_SPLICE && "Unexpected opcode!");

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -776,6 +776,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
776776
AddPromotedToType(Op, MVT::bf16, MVT::f32);
777777
}
778778

779+
if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 71) {
780+
setOperationAction(ISD::BF16_TO_FP, MVT::f32, Expand);
781+
}
782+
if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
783+
setOperationAction(ISD::FP_EXTEND, MVT::f64, Custom);
784+
setOperationAction(ISD::FP_ROUND, MVT::bf16, Custom);
785+
setOperationAction(ISD::BF16_TO_FP, MVT::f64, Custom);
786+
}
787+
779788
// sm_80 only has conversions between f32 and bf16. Custom lower all other
780789
// bf16 conversions.
781790
if (STI.hasBF16Math() &&
@@ -2465,6 +2474,72 @@ SDValue NVPTXTargetLowering::LowerFP_TO_INT(SDValue Op,
24652474
return Op;
24662475
}
24672476

2477+
SDValue NVPTXTargetLowering::LowerFP_ROUND(SDValue Op,
2478+
SelectionDAG &DAG) const {
2479+
EVT NarrowVT = Op.getValueType();
2480+
SDValue Wide = Op.getOperand(0);
2481+
EVT WideVT = Wide.getValueType();
2482+
if (NarrowVT.getScalarType() == MVT::bf16) {
2483+
const TargetLowering *TLI = STI.getTargetLowering();
2484+
if (STI.getSmVersion() < 80 || STI.getPTXVersion() < 70) {
2485+
return TLI->expandFP_ROUND(Op.getNode(), DAG);
2486+
}
2487+
if (STI.getSmVersion() < 90 || STI.getPTXVersion() < 78) {
2488+
// This combination was the first to support f32 -> bf16.
2489+
if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 70) {
2490+
if (WideVT.getScalarType() == MVT::f32) {
2491+
return Op;
2492+
}
2493+
if (WideVT.getScalarType() == MVT::f64) {
2494+
SDLoc Loc(Op);
2495+
// Round-inexact-to-odd f64 to f32, then do the final rounding using
2496+
// the hardware f32 -> bf16 instruction.
2497+
SDValue rod = TLI->expandRoundInexactToOdd(
2498+
WideVT.isVector() ? WideVT.changeVectorElementType(MVT::f32)
2499+
: MVT::f32,
2500+
Wide, Loc, DAG);
2501+
return DAG.getFPExtendOrRound(rod, Loc, NarrowVT);
2502+
}
2503+
}
2504+
return TLI->expandFP_ROUND(Op.getNode(), DAG);
2505+
}
2506+
}
2507+
2508+
// Everything else is considered legal.
2509+
return Op;
2510+
}
2511+
2512+
SDValue NVPTXTargetLowering::LowerFP_EXTEND(SDValue Op,
2513+
SelectionDAG &DAG) const {
2514+
SDValue Narrow = Op.getOperand(0);
2515+
EVT NarrowVT = Narrow.getValueType();
2516+
EVT WideVT = Op.getValueType();
2517+
if (NarrowVT.getScalarType() == MVT::bf16) {
2518+
if (WideVT.getScalarType() == MVT::f32 &&
2519+
(STI.getSmVersion() < 80 || STI.getPTXVersion() < 71)) {
2520+
SDLoc Loc(Op);
2521+
return DAG.getNode(ISD::BF16_TO_FP, Loc, WideVT, Narrow);
2522+
}
2523+
if (WideVT.getScalarType() == MVT::f64 &&
2524+
(STI.getSmVersion() < 90 || STI.getPTXVersion() < 78)) {
2525+
EVT F32 = NarrowVT.isVector() ? NarrowVT.changeVectorElementType(MVT::f32)
2526+
: MVT::f32;
2527+
EVT F64 = NarrowVT.isVector() ? NarrowVT.changeVectorElementType(MVT::f64)
2528+
: MVT::f64;
2529+
SDLoc Loc(Op);
2530+
if (STI.getSmVersion() >= 80 && STI.getPTXVersion() >= 71) {
2531+
Op = DAG.getNode(ISD::FP_EXTEND, Loc, F32, Narrow);
2532+
} else {
2533+
Op = DAG.getNode(ISD::BF16_TO_FP, Loc, F32, Narrow);
2534+
}
2535+
return DAG.getNode(ISD::FP_EXTEND, Loc, F64, Op);
2536+
}
2537+
}
2538+
2539+
// Everything else is considered legal.
2540+
return Op;
2541+
}
2542+
24682543
static SDValue LowerVectorArith(SDValue Op, SelectionDAG &DAG) {
24692544
SDLoc DL(Op);
24702545
if (Op.getValueType() != MVT::v2i16)
@@ -2527,6 +2602,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
25272602
case ISD::FP_TO_SINT:
25282603
case ISD::FP_TO_UINT:
25292604
return LowerFP_TO_INT(Op, DAG);
2605+
case ISD::FP_ROUND:
2606+
return LowerFP_ROUND(Op, DAG);
2607+
case ISD::FP_EXTEND:
2608+
return LowerFP_EXTEND(Op, DAG);
25302609
case ISD::VAARG:
25312610
return LowerVAARG(Op, DAG);
25322611
case ISD::VASTART:

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,9 @@ class NVPTXTargetLowering : public TargetLowering {
618618
SDValue LowerINT_TO_FP(SDValue Op, SelectionDAG &DAG) const;
619619
SDValue LowerFP_TO_INT(SDValue Op, SelectionDAG &DAG) const;
620620

621+
SDValue LowerFP_ROUND(SDValue Op, SelectionDAG &DAG) const;
622+
SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
623+
621624
SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
622625
SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
623626

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -662,7 +662,7 @@ let hasSideEffects = false in {
662662
// bf16->f32 was introduced early.
663663
[hasPTX<71>, hasSM<80>],
664664
// bf16->everything else needs sm90/ptx78
665-
[hasPTX<78>, hasSM<90>])>;
665+
[hasPTX<78>, hasSM<90>])>;
666666
def _f32 :
667667
NVPTXInst<(outs RC:$dst),
668668
(ins Float32Regs:$src, CvtMode:$mode),
@@ -3647,15 +3647,15 @@ def : Pat<(f16 (fpround Float32Regs:$a)),
36473647

36483648
// fpround f32 -> bf16
36493649
def : Pat<(bf16 (fpround Float32Regs:$a)),
3650-
(CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
3650+
(CVT_bf16_f32 Float32Regs:$a, CvtRN)>, Requires<[hasPTX<70>, hasSM<80>]>;
36513651

36523652
// fpround f64 -> f16
36533653
def : Pat<(f16 (fpround Float64Regs:$a)),
36543654
(CVT_f16_f64 Float64Regs:$a, CvtRN)>;
36553655

36563656
// fpround f64 -> bf16
36573657
def : Pat<(bf16 (fpround Float64Regs:$a)),
3658-
(CVT_bf16_f64 Float64Regs:$a, CvtRN)>;
3658+
(CVT_bf16_f64 Float64Regs:$a, CvtRN)>, Requires<[hasPTX<78>, hasSM<90>]>;
36593659
// fpround f64 -> f32
36603660
def : Pat<(f32 (fpround Float64Regs:$a)),
36613661
(CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>;
@@ -3671,15 +3671,15 @@ def : Pat<(f32 (fpextend (f16 Int16Regs:$a))),
36713671
def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))),
36723672
(CVT_f32_bf16 Int16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
36733673
def : Pat<(f32 (fpextend (bf16 Int16Regs:$a))),
3674-
(CVT_f32_bf16 Int16Regs:$a, CvtNONE)>;
3674+
(CVT_f32_bf16 Int16Regs:$a, CvtNONE)>, Requires<[hasPTX<71>, hasSM<80>]>;
36753675

36763676
// fpextend f16 -> f64
36773677
def : Pat<(f64 (fpextend (f16 Int16Regs:$a))),
36783678
(CVT_f64_f16 Int16Regs:$a, CvtNONE)>;
36793679

36803680
// fpextend bf16 -> f64
36813681
def : Pat<(f64 (fpextend (bf16 Int16Regs:$a))),
3682-
(CVT_f64_bf16 Int16Regs:$a, CvtNONE)>;
3682+
(CVT_f64_bf16 Int16Regs:$a, CvtNONE)>, Requires<[hasPTX<78>, hasSM<90>]>;
36833683

36843684
// fpextend f32 -> f64
36853685
def : Pat<(f64 (fpextend Float32Regs:$a)),

0 commit comments

Comments
 (0)