Skip to content

Commit d658ef1

Browse files
vchernon-inteligcbot
authored andcommitted
process NaN in fp emulation
fp2ui should return 0 for NaN values
1 parent 6aadb96 commit d658ef1

File tree

1 file changed

+29
-10
lines changed

1 file changed

+29
-10
lines changed

IGC/VectorCompiler/lib/BiF/fp2ui_conversion.cpp

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -222,17 +222,17 @@ CM_NODEBUG CM_INLINE vector<uint64_t, N>
222222
__impl_fp2ui__double__(vector<double, N> a) {
223223
// vector of floats -> vector of ints
224224
vector<uint32_t, 2 *N> LoHi = a.template format<uint32_t>();
225-
const vector<uint32_t, N> Exp_mask(0xff << 20);
226-
const vector<uint32_t, N> Mantissa_mask((1u << 20) - 1);
225+
const vector<uint32_t, N> MantissaMask((1u << 20) - 1);
226+
const vector<uint32_t, N> ExpMask(0x7ff);
227227
const vector<uint32_t, N> Zero(0);
228228
const vector<uint32_t, N> Ones(0xffffffff);
229229
const vector<uint32_t, N> One(1);
230230
vector<uint32_t, N> Lo = LoHi.template select<N, 2>(0);
231231
vector<uint32_t, N> Hi = LoHi.template select<N, 2>(1);
232-
vector<uint32_t, N> Exp = (Hi >> 20) & vector<uint32_t, N>(0x7ff);
232+
vector<uint32_t, N> Exp = (Hi >> 20) & ExpMask;
233233
// mantissa without hidden bit
234234
vector<uint32_t, N> LoMant = Lo;
235-
vector<uint32_t, N> HiMant = Hi & Mantissa_mask;
235+
vector<uint32_t, N> HiMant = Hi & MantissaMask;
236236
// for normalized numbers (1 + mant/2^52) * 2 ^ (mant-1023)
237237
vector<int32_t, N> MantShift = Exp - 1023 - 52;
238238
vector<int32_t, N> OneShift = Exp - 1023;
@@ -272,6 +272,7 @@ __impl_fp2ui__double__(vector<double, N> a) {
272272
// check for Exponent overflow (when sign bit set)
273273
auto FlagExpO = (Exp > vector<uint32_t, N>(1089));
274274
auto FlagExpUO = FlagNoSignSet & FlagExpO;
275+
auto IsNaN = (Exp == ExpMask) & ((LoMant != Zero) | (HiMant != Zero));
275276
if constexpr (isSigned) {
276277
// calculate (NOT[Lo, Hi] + 1) (integer sign negation)
277278
vector<uint32_t, N> NegLo = ~LoRes;
@@ -307,29 +308,38 @@ __impl_fp2ui__double__(vector<double, N> a) {
307308
LoRes.merge(Ones, FlagExpUO);
308309
HiRes.merge(vector<uint32_t, N>((1u << 31) - 1), FlagExpUO);
309310

311+
// if (IsNaN)
312+
LoRes.merge(Zero, IsNaN);
313+
HiRes.merge(Zero, IsNaN);
314+
310315
} else {
311316
// if (FlagSignSet)
312317
LoRes.merge(Zero, FlagSignSet);
313318
HiRes.merge(Zero, FlagSignSet);
319+
314320
// if (FlagExpUO)
315321
LoRes.merge(Ones, FlagExpUO);
316322
HiRes.merge(Ones, FlagExpUO);
323+
324+
// if (IsNaN)
325+
LoRes.merge(Zero, IsNaN);
326+
HiRes.merge(Zero, IsNaN);
317327
}
318328
return __impl_combineLoHi<N>(LoRes, HiRes);
319329
}
320330
template <unsigned N, bool isSigned>
321331
CM_NODEBUG CM_INLINE vector<uint64_t, N> __impl_fp2ui__(vector<float, N> a) {
322332
// vector of floats -> vector of ints
323333
vector<uint32_t, N> Uifl = a.template format<uint32_t>();
324-
const vector<uint32_t, N> Exp_mask(0xff << 23);
325-
const vector<uint32_t, N> Mantissa_mask((1u << 23) - 1);
334+
const vector<uint32_t, N> ExpMask(0xff);
335+
const vector<uint32_t, N> MantissaMask((1u << 23) - 1);
326336
const vector<uint32_t, N> Zero(0);
327337
const vector<uint32_t, N> Ones(0xffffffff);
328338
const vector<uint32_t, N> One(1);
329339

330-
vector<uint32_t, N> Exp = (Uifl >> 23) & vector<uint32_t, N>(0xff);
340+
vector<uint32_t, N> Exp = (Uifl >> 23) & ExpMask;
331341
// mantissa without hidden bit
332-
vector<uint32_t, N> Pmantissa = Uifl & Mantissa_mask;
342+
vector<uint32_t, N> Pmantissa = Uifl & MantissaMask;
333343
// take hidden bit into account
334344
vector<uint32_t, N> Mantissa = Pmantissa | vector<uint32_t, N>(1 << 23);
335345
vector<uint32_t, N> Data_h = Mantissa << 8;
@@ -354,8 +364,8 @@ CM_NODEBUG CM_INLINE vector<uint64_t, N> __impl_fp2ui__(vector<float, N> a) {
354364

355365
// Discard results if shift is greater than 63
356366
vector<uint32_t, N> Mask = Ones;
357-
auto Flag_discard = (Shift > vector<uint32_t, N>(63));
358-
Mask.merge(Zero, Flag_discard);
367+
auto FlagDiscard = (Shift > vector<uint32_t, N>(63));
368+
Mask.merge(Zero, FlagDiscard);
359369
Lo = Lo & Mask;
360370
Hi = Hi & Mask;
361371
vector<uint32_t, N> SignedBitMask(1u << 31);
@@ -365,6 +375,7 @@ CM_NODEBUG CM_INLINE vector<uint64_t, N> __impl_fp2ui__(vector<float, N> a) {
365375
// check for Exponent overflow (when sign bit set)
366376
auto FlagExpO = (Exp > vector<uint32_t, N>(0xbe));
367377
auto FlagExpUO = FlagNoSignSet & FlagExpO;
378+
auto IsNaN = (Exp == ExpMask) & (Pmantissa != Zero);
368379
if constexpr (isSigned) {
369380
// calculate (NOT[Lo, Hi] + 1) (integer sign negation)
370381
vector<uint32_t, N> NegLo = ~Lo;
@@ -401,13 +412,21 @@ CM_NODEBUG CM_INLINE vector<uint64_t, N> __impl_fp2ui__(vector<float, N> a) {
401412
Lo.merge(Ones, FlagExpUO);
402413
Hi.merge(vector<uint32_t, N>((1u << 31) - 1), FlagExpUO);
403414

415+
// if (IsNaN)
416+
Lo.merge(Zero, IsNaN);
417+
Hi.merge(Zero, IsNaN);
404418
} else {
405419
// if (FlagSignSet)
406420
Lo.merge(Zero, FlagSignSet);
407421
Hi.merge(Zero, FlagSignSet);
422+
408423
// if (FlagExpUO)
409424
Lo.merge(Ones, FlagExpUO);
410425
Hi.merge(Ones, FlagExpUO);
426+
427+
// if (IsNaN)
428+
Lo.merge(Zero, IsNaN);
429+
Hi.merge(Zero, IsNaN);
411430
}
412431
return __impl_combineLoHi<N>(Lo, Hi);
413432
}

0 commit comments

Comments
 (0)