Skip to content

Commit 63d6e24

Browse files
committed
[APFloat] Add APFloat support for E8M0 type
This patch adds an APFloat type for the unsigned E8M0 format. This format is used for representing the "scale-format" in the MX specification: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf This format does not support {Inf, denorms, zeroes}. Like FP32, this format's exponent is 8 bits wide (all of the bits here) and the bias value is 127. However, it differs from IEEE-FP32 in that the minExponent is -127 (instead of -126). The APFloat utility functions are updated to handle these constraints for this format. * The bias calculation is different and the convertIEEE* APIs are updated to handle this. * Since there are no significand bits, the isSignificandAll{Zeroes/Ones} methods are updated accordingly. * Although the format does not have any precision, the precision in the fltSemantics is set to 1 for consistency with APFloat's internal representation. * Many utility functions are updated to handle the fact that this format does not support Zero. * Provide a separate initFromAPInt() implementation to handle the quirks of the format. * Add specific tests to verify the range of values for this format. Signed-off-by: Durgadoss R <[email protected]>
1 parent a8e1c6f commit 63d6e24

File tree

3 files changed

+374
-21
lines changed

3 files changed

+374
-21
lines changed

llvm/include/llvm/ADT/APFloat.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,12 @@ struct APFloatBase {
195195
// improved range compared to half (16-bit) formats, at (potentially)
196196
// greater throughput than single precision (32-bit) formats.
197197
S_FloatTF32,
198+
// 8-bit floating point number with (all the) 8 bits for the exponent
199+
// like in FP32. There are no zeroes, no infinities, and no denormal values.
200+
// NaN is represented with all bits set to 1. Bias is 127.
201+
// This represents the scale data type in the MX specification from
202+
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
203+
S_Float8E8M0FN,
198204
// 6-bit floating point number with bit layout S1E3M2. Unlike IEEE-754
199205
// types, there are no infinity or NaN values. The format is detailed in
200206
// https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
@@ -229,6 +235,7 @@ struct APFloatBase {
229235
static const fltSemantics &Float8E4M3B11FNUZ() LLVM_READNONE;
230236
static const fltSemantics &Float8E3M4() LLVM_READNONE;
231237
static const fltSemantics &FloatTF32() LLVM_READNONE;
238+
static const fltSemantics &Float8E8M0FN() LLVM_READNONE;
232239
static const fltSemantics &Float6E3M2FN() LLVM_READNONE;
233240
static const fltSemantics &Float6E2M3FN() LLVM_READNONE;
234241
static const fltSemantics &Float4E2M1FN() LLVM_READNONE;
@@ -652,6 +659,7 @@ class IEEEFloat final : public APFloatBase {
652659
APInt convertFloat8E4M3B11FNUZAPFloatToAPInt() const;
653660
APInt convertFloat8E3M4APFloatToAPInt() const;
654661
APInt convertFloatTF32APFloatToAPInt() const;
662+
APInt convertFloat8E8M0FNAPFloatToAPInt() const;
655663
APInt convertFloat6E3M2FNAPFloatToAPInt() const;
656664
APInt convertFloat6E2M3FNAPFloatToAPInt() const;
657665
APInt convertFloat4E2M1FNAPFloatToAPInt() const;
@@ -672,6 +680,7 @@ class IEEEFloat final : public APFloatBase {
672680
void initFromFloat8E4M3B11FNUZAPInt(const APInt &api);
673681
void initFromFloat8E3M4APInt(const APInt &api);
674682
void initFromFloatTF32APInt(const APInt &api);
683+
void initFromFloat8E8M0FNAPInt(const APInt &api);
675684
void initFromFloat6E3M2FNAPInt(const APInt &api);
676685
void initFromFloat6E2M3FNAPInt(const APInt &api);
677686
void initFromFloat4E2M1FNAPInt(const APInt &api);
@@ -1091,6 +1100,26 @@ class APFloat : public APFloatBase {
10911100
}
10921101
}
10931102

1103+
static bool hasZero(const fltSemantics &Sem) {
1104+
switch (SemanticsToEnum(Sem)) {
1105+
default:
1106+
return true;
1107+
// The Float8E8M0FN does not have an encoding for Zeroes.
1108+
case APFloat::S_Float8E8M0FN:
1109+
return false;
1110+
}
1111+
}
1112+
1113+
static bool hasExponentOnly(const fltSemantics &Sem) {
1114+
switch (SemanticsToEnum(Sem)) {
1115+
default:
1116+
return false;
1117+
// The Float8E8M0FN has exponent only and no significand.
1118+
case APFloat::S_Float8E8M0FN:
1119+
return true;
1120+
}
1121+
}
1122+
10941123
/// Used to insert APFloat objects, or objects that contain APFloat objects,
10951124
/// into FoldingSets.
10961125
void Profile(FoldingSetNodeID &NID) const;

llvm/lib/Support/APFloat.cpp

Lines changed: 124 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,9 @@ static constexpr fltSemantics semFloat8E4M3B11FNUZ = {
145145
4, -10, 4, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::NegativeZero};
146146
static constexpr fltSemantics semFloat8E3M4 = {3, -2, 5, 8};
147147
static constexpr fltSemantics semFloatTF32 = {127, -126, 11, 19};
148+
// Semantics table for the exponent-only E8M0 format (S_Float8E8M0FN).
// The numeric fields presumably are, in order: max exponent (127), min
// exponent (-127), precision (1) and size in bits (8) -- the field order of
// fltSemantics is not visible here, confirm against its declaration.
// The precision of 1 is only the internal integer bit; the format stores no
// significand bits. The only non-finite value is NaN, encoded as all ones.
static constexpr fltSemantics semFloat8E8M0FN = {
149+
127, -127, 1, 8, fltNonfiniteBehavior::NanOnly, fltNanEncoding::AllOnes};
150+
148151
static constexpr fltSemantics semFloat6E3M2FN = {
149152
4, -2, 3, 6, fltNonfiniteBehavior::FiniteOnly};
150153
static constexpr fltSemantics semFloat6E2M3FN = {
@@ -222,6 +225,8 @@ const llvm::fltSemantics &APFloatBase::EnumToSemantics(Semantics S) {
222225
return Float8E3M4();
223226
case S_FloatTF32:
224227
return FloatTF32();
228+
case S_Float8E8M0FN:
229+
return Float8E8M0FN();
225230
case S_Float6E3M2FN:
226231
return Float6E3M2FN();
227232
case S_Float6E2M3FN:
@@ -264,6 +269,8 @@ APFloatBase::SemanticsToEnum(const llvm::fltSemantics &Sem) {
264269
return S_Float8E3M4;
265270
else if (&Sem == &llvm::APFloat::FloatTF32())
266271
return S_FloatTF32;
272+
else if (&Sem == &llvm::APFloat::Float8E8M0FN())
273+
return S_Float8E8M0FN;
267274
else if (&Sem == &llvm::APFloat::Float6E3M2FN())
268275
return S_Float6E3M2FN;
269276
else if (&Sem == &llvm::APFloat::Float6E2M3FN())
@@ -294,6 +301,7 @@ const fltSemantics &APFloatBase::Float8E4M3B11FNUZ() {
294301
}
295302
// Public accessors for the per-format semantics objects. Each one returns a
// reference to the corresponding sem* constant; callers compare these
// references by address to identify a format.
const fltSemantics &APFloatBase::Float8E3M4() { return semFloat8E3M4; }
296303
const fltSemantics &APFloatBase::FloatTF32() { return semFloatTF32; }
304+
const fltSemantics &APFloatBase::Float8E8M0FN() { return semFloat8E8M0FN; }
297305
const fltSemantics &APFloatBase::Float6E3M2FN() { return semFloat6E3M2FN; }
298306
const fltSemantics &APFloatBase::Float6E2M3FN() { return semFloat6E2M3FN; }
299307
const fltSemantics &APFloatBase::Float4E2M1FN() { return semFloat4E2M1FN; }
@@ -396,6 +404,8 @@ static inline Error createError(const Twine &Err) {
396404
}
397405

398406
static constexpr inline unsigned int partCountForBits(unsigned int bits) {
407+
if (bits == 0)
408+
return 1;
399409
return ((bits) + APFloatBase::integerPartWidth - 1) / APFloatBase::integerPartWidth;
400410
}
401411

@@ -955,6 +965,12 @@ void IEEEFloat::makeNaN(bool SNaN, bool Negative, const APInt *fill) {
955965
significand[part] = 0;
956966
}
957967

968+
// For the E8M0 types, precision is just 1 and the
969+
// NaNBit handling below is not relevant.
970+
// So, exit early.
971+
if (semantics == &semFloat8E8M0FN)
972+
return;
973+
958974
unsigned QNaNBit = semantics->precision - 2;
959975

960976
if (SNaN) {
@@ -1007,6 +1023,10 @@ IEEEFloat &IEEEFloat::operator=(IEEEFloat &&rhs) {
10071023
}
10081024

10091025
bool IEEEFloat::isDenormal() const {
1026+
// No denormals in Float8E8M0FN
1027+
if (semantics == &semFloat8E8M0FN)
1028+
return false;
1029+
10101030
return isFiniteNonZero() && (exponent == semantics->minExponent) &&
10111031
(APInt::tcExtractBit(significandParts(),
10121032
semantics->precision - 1) == 0);
@@ -1028,6 +1048,10 @@ bool IEEEFloat::isSmallestNormalized() const {
10281048
bool IEEEFloat::isSignificandAllOnes() const {
10291049
// Test if the significand excluding the integral bit is all ones. This allows
10301050
// us to test for binade boundaries.
1051+
// For the E8M0 format, this is always false since there are no
1052+
// actual significand bits.
1053+
if (semantics == &semFloat8E8M0FN)
1054+
return false;
10311055
const integerPart *Parts = significandParts();
10321056
const unsigned PartCount = partCountForBits(semantics->precision);
10331057
for (unsigned i = 0; i < PartCount - 1; i++)
@@ -1075,6 +1099,11 @@ bool IEEEFloat::isSignificandAllOnesExceptLSB() const {
10751099
}
10761100

10771101
bool IEEEFloat::isSignificandAllZeros() const {
1102+
// For the E8M0 format, this is always true since there are no
1103+
// actual significand bits.
1104+
if (semantics == &semFloat8E8M0FN)
1105+
return true;
1106+
10781107
// Test if the significand excluding the integral bit is all zeros. This
10791108
// allows us to test for binade boundaries.
10801109
const integerPart *Parts = significandParts();
@@ -1113,6 +1142,8 @@ bool IEEEFloat::isSignificandAllZerosExceptMSB() const {
11131142
}
11141143

11151144
bool IEEEFloat::isLargest() const {
1145+
if (semantics == &semFloat8E8M0FN)
1146+
return isFiniteNonZero() && exponent == semantics->maxExponent;
11161147
if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly &&
11171148
semantics->nanEncoding == fltNanEncoding::AllOnes) {
11181149
// The largest number by magnitude in our format will be the floating point
@@ -1165,6 +1196,12 @@ IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics, integerPart value) {
11651196

11661197
// Construct the default value for the given semantics: zero for every
// format that can represent it, otherwise the closest available value.
IEEEFloat::IEEEFloat(const fltSemantics &ourSemantics) {
  initialize(&ourSemantics);
  if (semantics != &semFloat8E8M0FN) {
    makeZero(false);
  } else {
    // E8M0 has no encoding for zero, so start from the smallest
    // normalized value, which is the nearest representable one.
    makeSmallestNormalized(false);
  }
}
11701207

@@ -1727,6 +1764,11 @@ IEEEFloat::opStatus IEEEFloat::normalize(roundingMode rounding_mode,
17271764
/* Canonicalize zeroes. */
17281765
if (omsb == 0) {
17291766
category = fcZero;
1767+
// The E8M0 type cannot represent the value zero and
1768+
// thus the category cannot be fcZero. So, get the
1769+
// closest representation to fcZero instead.
1770+
if (semantics == &semFloat8E8M0FN)
1771+
makeSmallestNormalized(false);
17301772
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
17311773
sign = false;
17321774
}
@@ -2606,6 +2648,11 @@ IEEEFloat::opStatus IEEEFloat::convert(const fltSemantics &toSemantics,
26062648
fs = opOK;
26072649
}
26082650

2651+
// The E8M0 type cannot represent the value zero and
2652+
// thus the category cannot be fcZero. So, get the
2653+
// closest representation to fcZero instead.
2654+
if (category == fcZero && semantics == &semFloat8E8M0FN)
2655+
makeSmallestNormalized(false);
26092656
return fs;
26102657
}
26112658

@@ -3070,6 +3117,11 @@ IEEEFloat::convertFromDecimalString(StringRef str, roundingMode rounding_mode) {
30703117
fs = opOK;
30713118
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
30723119
sign = false;
3120+
// The E8M0 type cannot represent the value zero and
3121+
// thus the category cannot be fcZero. So, get the
3122+
// closest representation to fcZero instead.
3123+
if (semantics == &semFloat8E8M0FN)
3124+
makeSmallestNormalized(false);
30733125

30743126
/* Check whether the normalized exponent is high enough to overflow
30753127
max during the log-rebasing in the max-exponent check below. */
@@ -3533,15 +3585,16 @@ APInt IEEEFloat::convertPPCDoubleDoubleAPFloatToAPInt() const {
35333585
template <const fltSemantics &S>
35343586
APInt IEEEFloat::convertIEEEFloatToAPInt() const {
35353587
assert(semantics == &S);
3536-
3537-
constexpr int bias = -(S.minExponent - 1);
3588+
const int bias =
3589+
(semantics == &semFloat8E8M0FN) ? -S.minExponent : -(S.minExponent - 1);
35383590
constexpr unsigned int trailing_significand_bits = S.precision - 1;
35393591
constexpr int integer_bit_part = trailing_significand_bits / integerPartWidth;
35403592
constexpr integerPart integer_bit =
35413593
integerPart{1} << (trailing_significand_bits % integerPartWidth);
35423594
constexpr uint64_t significand_mask = integer_bit - 1;
35433595
constexpr unsigned int exponent_bits =
3544-
S.sizeInBits - 1 - trailing_significand_bits;
3596+
trailing_significand_bits ? (S.sizeInBits - 1 - trailing_significand_bits)
3597+
: S.sizeInBits;
35453598
static_assert(exponent_bits < 64);
35463599
constexpr uint64_t exponent_mask = (uint64_t{1} << exponent_bits) - 1;
35473600

@@ -3557,6 +3610,8 @@ APInt IEEEFloat::convertIEEEFloatToAPInt() const {
35573610
!(significandParts()[integer_bit_part] & integer_bit))
35583611
myexponent = 0; // denormal
35593612
} else if (category == fcZero) {
3613+
if (semantics == &semFloat8E8M0FN)
3614+
llvm_unreachable("semantics does not support zero!");
35603615
myexponent = ::exponentZero(S) + bias;
35613616
mysignificand.fill(0);
35623617
} else if (category == fcInfinity) {
@@ -3659,6 +3714,11 @@ APInt IEEEFloat::convertFloatTF32APFloatToAPInt() const {
36593714
return convertIEEEFloatToAPInt<semFloatTF32>();
36603715
}
36613716

3717+
// Produce the APInt bit pattern of this value in the E8M0 format by
// delegating to the generic convertIEEEFloatToAPInt template instantiated
// with semFloat8E8M0FN.
APInt IEEEFloat::convertFloat8E8M0FNAPFloatToAPInt() const {
3718+
// Same sanity check as the sibling small-format converters.
assert(partCount() == 1);
3719+
return convertIEEEFloatToAPInt<semFloat8E8M0FN>();
3720+
}
3721+
36623722
APInt IEEEFloat::convertFloat6E3M2FNAPFloatToAPInt() const {
36633723
assert(partCount() == 1);
36643724
return convertIEEEFloatToAPInt<semFloat6E3M2FN>();
@@ -3721,6 +3781,9 @@ APInt IEEEFloat::bitcastToAPInt() const {
37213781
if (semantics == (const llvm::fltSemantics *)&semFloatTF32)
37223782
return convertFloatTF32APFloatToAPInt();
37233783

3784+
if (semantics == (const llvm::fltSemantics *)&semFloat8E8M0FN)
3785+
return convertFloat8E8M0FNAPFloatToAPInt();
3786+
37243787
if (semantics == (const llvm::fltSemantics *)&semFloat6E3M2FN)
37253788
return convertFloat6E3M2FNAPFloatToAPInt();
37263789

@@ -3819,6 +3882,40 @@ void IEEEFloat::initFromPPCDoubleDoubleAPInt(const APInt &api) {
38193882
}
38203883
}
38213884

3885+
// The E8M0 format has the following characteristics:
3886+
// It is an 8-bit unsigned format with only exponents (no actual significand)
3887+
// No encodings for {zero, infinities or denorms}
3888+
// NaN is represented by all 1's
3889+
// Bias is 127
3890+
void IEEEFloat::initFromFloat8E8M0FNAPInt(const APInt &api) {
3891+
const uint64_t exponent_mask = 0xff;
3892+
uint64_t val = api.getRawData()[0];
3893+
uint64_t myexponent = (val & exponent_mask);
3894+
3895+
initialize(&semFloat8E8M0FN);
3896+
assert(partCount() == 1);
3897+
3898+
// This format has unsigned representation only
3899+
sign = 0;
3900+
3901+
// Set the significand
3902+
// This format does not have any significand but the 'Pth' precision bit is
3903+
// always set to 1 for consistency in APFloat's internal representation.
3904+
uint64_t mysignificand = 1;
3905+
significandParts()[0] = mysignificand;
3906+
3907+
// This format can either have a NaN or fcNormal
3908+
// All 1's i.e. 255 is a NaN
3909+
if (val == exponent_mask) {
3910+
category = fcNaN;
3911+
exponent = exponentNaN();
3912+
return;
3913+
}
3914+
// Handle fcNormal...
3915+
category = fcNormal;
3916+
exponent = myexponent - 127; // 127 is bias
3917+
return;
3918+
}
38223919
template <const fltSemantics &S>
38233920
void IEEEFloat::initFromIEEEAPInt(const APInt &api) {
38243921
assert(api.getBitWidth() == S.sizeInBits);
@@ -3999,6 +4096,8 @@ void IEEEFloat::initFromAPInt(const fltSemantics *Sem, const APInt &api) {
39994096
return initFromFloat8E3M4APInt(api);
40004097
if (Sem == &semFloatTF32)
40014098
return initFromFloatTF32APInt(api);
4099+
if (Sem == &semFloat8E8M0FN)
4100+
return initFromFloat8E8M0FNAPInt(api);
40024101
if (Sem == &semFloat6E3M2FN)
40034102
return initFromFloat6E3M2FNAPInt(api);
40044103
if (Sem == &semFloat6E2M3FN)
@@ -4032,6 +4131,13 @@ void IEEEFloat::makeLargest(bool Negative) {
40324131
significand[PartCount - 1] = (NumUnusedHighBits < integerPartWidth)
40334132
? (~integerPart(0) >> NumUnusedHighBits)
40344133
: 0;
4134+
// For E8M0 format, we only have the 'internal' precision bit
4135+
// (aka 'P' the precision bit) which is always set to 1.
4136+
// Hence, the below logic of setting the LSB to 0 does not apply.
4137+
// For other cases, the LSB is meant to be any bit other than
4138+
// the Pth precision bit.
4139+
if (semantics == &semFloat8E8M0FN)
4140+
return;
40354141

40364142
if (semantics->nonFiniteBehavior == fltNonfiniteBehavior::NanOnly &&
40374143
semantics->nanEncoding == fltNanEncoding::AllOnes)
@@ -4509,6 +4615,11 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) {
45094615
exponent = 0;
45104616
if (semantics->nanEncoding == fltNanEncoding::NegativeZero)
45114617
sign = false;
4618+
// The E8M0 type cannot represent the value zero and
4619+
// thus the category cannot be fcZero. So, get the
4620+
// closest representation to fcZero instead.
4621+
if (semantics == &semFloat8E8M0FN)
4622+
makeSmallestNormalized(false);
45124623
break;
45134624
}
45144625

@@ -4575,6 +4686,11 @@ IEEEFloat::opStatus IEEEFloat::next(bool nextDown) {
45754686
// denormal always increment since moving denormals and the numbers in the
45764687
// smallest normal binade have the same exponent in our representation.
45774688
bool WillCrossBinadeBoundary = !isDenormal() && isSignificandAllOnes();
4689+
// The E8M0 format does not support Denorms.
4690+
// Since there are only exponents, any increment always crosses the
4691+
// 'BinadeBoundary'. So, make this true always.
4692+
if (semantics == &semFloat8E8M0FN)
4693+
WillCrossBinadeBoundary = true;
45784694

45794695
if (WillCrossBinadeBoundary) {
45804696
integerPart *Parts = significandParts();
@@ -4626,6 +4742,11 @@ void IEEEFloat::makeInf(bool Negative) {
46264742
}
46274743

46284744
void IEEEFloat::makeZero(bool Negative) {
4745+
// The E8M0 type cannot represent the value zero.
4746+
if (semantics == &semFloat8E8M0FN) {
4747+
assert(false && "This floating point format does not support Zero\n");
4748+
return;
4749+
}
46294750
category = fcZero;
46304751
sign = Negative;
46314752
if (semantics->nanEncoding == fltNanEncoding::NegativeZero) {

0 commit comments

Comments
 (0)