@@ -1543,11 +1543,14 @@ static bool fitsInFPType(ConstantFP *CFP, const fltSemantics &Sem) {
1543
1543
return !losesInfo;
1544
1544
}
1545
1545
1546
- static Type *shrinkFPConstant (ConstantFP *CFP) {
1546
+ static Type *shrinkFPConstant (ConstantFP *CFP, bool PreferBFloat ) {
1547
1547
if (CFP->getType () == Type::getPPC_FP128Ty (CFP->getContext ()))
1548
1548
return nullptr ; // No constant folding of this.
1549
+ // See if the value can be truncated to bfloat and then reextended.
1550
+ if (PreferBFloat && fitsInFPType (CFP, APFloat::BFloat ()))
1551
+ return Type::getBFloatTy (CFP->getContext ());
1549
1552
// See if the value can be truncated to half and then reextended.
1550
- if (fitsInFPType (CFP, APFloat::IEEEhalf ()))
1553
+ if (!PreferBFloat && fitsInFPType (CFP, APFloat::IEEEhalf ()))
1551
1554
return Type::getHalfTy (CFP->getContext ());
1552
1555
// See if the value can be truncated to float and then reextended.
1553
1556
if (fitsInFPType (CFP, APFloat::IEEEsingle ()))
@@ -1562,7 +1565,7 @@ static Type *shrinkFPConstant(ConstantFP *CFP) {
1562
1565
1563
1566
// Determine if this is a vector of ConstantFPs and if so, return the minimal
1564
1567
// type we can safely truncate all elements to.
1565
- static Type *shrinkFPConstantVector (Value *V) {
1568
+ static Type *shrinkFPConstantVector (Value *V, bool PreferBFloat ) {
1566
1569
auto *CV = dyn_cast<Constant>(V);
1567
1570
auto *CVVTy = dyn_cast<FixedVectorType>(V->getType ());
1568
1571
if (!CV || !CVVTy)
@@ -1582,7 +1585,7 @@ static Type *shrinkFPConstantVector(Value *V) {
1582
1585
if (!CFP)
1583
1586
return nullptr ;
1584
1587
1585
- Type *T = shrinkFPConstant (CFP);
1588
+ Type *T = shrinkFPConstant (CFP, PreferBFloat );
1586
1589
if (!T)
1587
1590
return nullptr ;
1588
1591
@@ -1597,15 +1600,15 @@ static Type *shrinkFPConstantVector(Value *V) {
1597
1600
}
1598
1601
1599
1602
// / Find the minimum FP type we can safely truncate to.
1600
- static Type *getMinimumFPType (Value *V) {
1603
+ static Type *getMinimumFPType (Value *V, bool PreferBFloat ) {
1601
1604
if (auto *FPExt = dyn_cast<FPExtInst>(V))
1602
1605
return FPExt->getOperand (0 )->getType ();
1603
1606
1604
1607
// If this value is a constant, return the constant in the smallest FP type
1605
1608
// that can accurately represent it. This allows us to turn
1606
1609
// (float)((double)X+2.0) into x+2.0f.
1607
1610
if (auto *CFP = dyn_cast<ConstantFP>(V))
1608
- if (Type *T = shrinkFPConstant (CFP))
1611
+ if (Type *T = shrinkFPConstant (CFP, PreferBFloat ))
1609
1612
return T;
1610
1613
1611
1614
// We can only correctly find a minimum type for a scalable vector when it is
@@ -1617,7 +1620,7 @@ static Type *getMinimumFPType(Value *V) {
1617
1620
1618
1621
// Try to shrink a vector of FP constants. This returns nullptr on scalable
1619
1622
// vectors
1620
- if (Type *T = shrinkFPConstantVector (V))
1623
+ if (Type *T = shrinkFPConstantVector (V, PreferBFloat ))
1621
1624
return T;
1622
1625
1623
1626
return V->getType ();
@@ -1686,8 +1689,10 @@ Instruction *InstCombinerImpl::visitFPTrunc(FPTruncInst &FPT) {
1686
1689
Type *Ty = FPT.getType ();
1687
1690
auto *BO = dyn_cast<BinaryOperator>(FPT.getOperand (0 ));
1688
1691
if (BO && BO->hasOneUse ()) {
1689
- Type *LHSMinType = getMinimumFPType (BO->getOperand (0 ));
1690
- Type *RHSMinType = getMinimumFPType (BO->getOperand (1 ));
1692
+ Type *LHSMinType =
1693
+ getMinimumFPType (BO->getOperand (0 ), /* PreferBFloat=*/ Ty->isBFloatTy ());
1694
+ Type *RHSMinType =
1695
+ getMinimumFPType (BO->getOperand (1 ), /* PreferBFloat=*/ Ty->isBFloatTy ());
1691
1696
unsigned OpWidth = BO->getType ()->getFPMantissaWidth ();
1692
1697
unsigned LHSWidth = LHSMinType->getFPMantissaWidth ();
1693
1698
unsigned RHSWidth = RHSMinType->getFPMantissaWidth ();
0 commit comments