Skip to content

Commit d3f6dd6

Browse files
authored
[InstCombine] Pick bfloat over half when shrinking ops that started with an fpext from bfloat (#82493)
This fixes the case where we would shrink an frem to half and then bitcast to bfloat, producing invalid results. The transformation was written under the assumption that there is only one type with a given bit width. Also add a strategic assert to CastInst::CreateFPCast to turn this miscompilation into a crash.
1 parent 88e31f6 commit d3f6dd6

File tree

3 files changed

+26
-9
lines changed

3 files changed

+26
-9
lines changed

llvm/lib/IR/Instructions.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3525,6 +3525,7 @@ CastInst *CastInst::CreateFPCast(Value *C, Type *Ty,
35253525
"Invalid cast");
35263526
unsigned SrcBits = C->getType()->getScalarSizeInBits();
35273527
unsigned DstBits = Ty->getScalarSizeInBits();
3528+
assert((C->getType() == Ty || SrcBits != DstBits) && "Invalid cast");
35283529
Instruction::CastOps opcode =
35293530
(SrcBits == DstBits ? Instruction::BitCast :
35303531
(SrcBits > DstBits ? Instruction::FPTrunc : Instruction::FPExt));

llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,11 +1543,14 @@ static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
15431543
return !losesInfo;
15441544
}
15451545

1546-
static Type *shrinkFPConstant(ConstantFP *CFP) {
1546+
static Type *shrinkFPConstant(ConstantFP *CFP, bool PreferBFloat) {
15471547
if (CFP->getType() == Type::getPPC_FP128Ty(CFP->getContext()))
15481548
return nullptr; // No constant folding of this.
1549+
// See if the value can be truncated to bfloat and then reextended.
1550+
if (PreferBFloat && fitsInFPType(CFP, APFloat::BFloat()))
1551+
return Type::getBFloatTy(CFP->getContext());
15491552
// See if the value can be truncated to half and then reextended.
1550-
if (fitsInFPType(CFP, APFloat::IEEEhalf()))
1553+
if (!PreferBFloat && fitsInFPType(CFP, APFloat::IEEEhalf()))
15511554
return Type::getHalfTy(CFP->getContext());
15521555
// See if the value can be truncated to float and then reextended.
15531556
if (fitsInFPType(CFP, APFloat::IEEEsingle()))
@@ -1562,7 +1565,7 @@ static Type *shrinkFPConstant(ConstantFP *CFP) {
15621565

15631566
// Determine if this is a vector of ConstantFPs and if so, return the minimal
15641567
// type we can safely truncate all elements to.
1565-
static Type *shrinkFPConstantVector(Value *V) {
1568+
static Type *shrinkFPConstantVector(Value *V, bool PreferBFloat) {
15661569
auto *CV = dyn_cast<Constant>(V);
15671570
auto *CVVTy = dyn_cast<FixedVectorType>(V->getType());
15681571
if (!CV || !CVVTy)
@@ -1582,7 +1585,7 @@ static Type *shrinkFPConstantVector(Value *V) {
15821585
if (!CFP)
15831586
return nullptr;
15841587

1585-
Type *T = shrinkFPConstant(CFP);
1588+
Type *T = shrinkFPConstant(CFP, PreferBFloat);
15861589
if (!T)
15871590
return nullptr;
15881591

@@ -1597,15 +1600,15 @@ static Type *shrinkFPConstantVector(Value *V) {
15971600
}
15981601

15991602
/// Find the minimum FP type we can safely truncate to.
1600-
static Type *getMinimumFPType(Value *V) {
1603+
static Type *getMinimumFPType(Value *V, bool PreferBFloat) {
16011604
if (auto *FPExt = dyn_cast<FPExtInst>(V))
16021605
return FPExt->getOperand(0)->getType();
16031606

16041607
// If this value is a constant, return the constant in the smallest FP type
16051608
// that can accurately represent it. This allows us to turn
16061609
// (float)((double)X+2.0) into x+2.0f.
16071610
if (auto *CFP = dyn_cast<ConstantFP>(V))
1608-
if (Type *T = shrinkFPConstant(CFP))
1611+
if (Type *T = shrinkFPConstant(CFP, PreferBFloat))
16091612
return T;
16101613

16111614
// We can only correctly find a minimum type for a scalable vector when it is
@@ -1617,7 +1620,7 @@ static Type *getMinimumFPType(Value *V) {
16171620

16181621
// Try to shrink a vector of FP constants. This returns nullptr on scalable
16191622
// vectors
1620-
if (Type *T = shrinkFPConstantVector(V))
1623+
if (Type *T = shrinkFPConstantVector(V, PreferBFloat))
16211624
return T;
16221625

16231626
return V->getType();
@@ -1686,8 +1689,10 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
16861689
Type *Ty = FPT.getType();
16871690
auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand(0));
16881691
if (BO && BO->hasOneUse()) {
1689-
Type *LHSMinType = getMinimumFPType(BO->getOperand(0));
1690-
Type *RHSMinType = getMinimumFPType(BO->getOperand(1));
1692+
Type *LHSMinType =
1693+
getMinimumFPType(BO->getOperand(0), /*PreferBFloat=*/Ty->isBFloatTy());
1694+
Type *RHSMinType =
1695+
getMinimumFPType(BO->getOperand(1), /*PreferBFloat=*/Ty->isBFloatTy());
16911696
unsigned OpWidth = BO->getType()->getFPMantissaWidth();
16921697
unsigned LHSWidth = LHSMinType->getFPMantissaWidth();
16931698
unsigned RHSWidth = RHSMinType->getFPMantissaWidth();

llvm/test/Transforms/InstCombine/fpextend.ll

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,3 +437,14 @@ define half @bf16_to_f32_to_f16(bfloat %a) nounwind {
437437
%z = fptrunc float %y to half
438438
ret half %z
439439
}
440+
441+
define bfloat @bf16_frem(bfloat %x) {
442+
; CHECK-LABEL: @bf16_frem(
443+
; CHECK-NEXT: [[FREM:%.*]] = frem bfloat [[X:%.*]], 0xR40C9
444+
; CHECK-NEXT: ret bfloat [[FREM]]
445+
;
446+
%t1 = fpext bfloat %x to float
447+
%t2 = frem float %t1, 6.281250e+00
448+
%t3 = fptrunc float %t2 to bfloat
449+
ret bfloat %t3
450+
}

0 commit comments

Comments
 (0)