@@ -222,17 +222,17 @@ CM_NODEBUG CM_INLINE vector<uint64_t, N>
222
222
__impl_fp2ui__double__ (vector<double , N> a) {
223
223
// vector of doubles -> vector of ints
224
224
vector<uint32_t , 2 *N> LoHi = a.template format <uint32_t >();
225
- const vector<uint32_t , N> Exp_mask ( 0xff << 20 );
226
- const vector<uint32_t , N> Mantissa_mask (( 1u << 20 ) - 1 );
225
+ const vector<uint32_t , N> MantissaMask (( 1u << 20 ) - 1 );
226
+ const vector<uint32_t , N> ExpMask ( 0x7ff );
227
227
const vector<uint32_t , N> Zero (0 );
228
228
const vector<uint32_t , N> Ones (0xffffffff );
229
229
const vector<uint32_t , N> One (1 );
230
230
vector<uint32_t , N> Lo = LoHi.template select <N, 2 >(0 );
231
231
vector<uint32_t , N> Hi = LoHi.template select <N, 2 >(1 );
232
- vector<uint32_t , N> Exp = (Hi >> 20 ) & vector< uint32_t , N>( 0x7ff ) ;
232
+ vector<uint32_t , N> Exp = (Hi >> 20 ) & ExpMask ;
233
233
// mantissa without hidden bit
234
234
vector<uint32_t , N> LoMant = Lo;
235
- vector<uint32_t , N> HiMant = Hi & Mantissa_mask ;
235
+ vector<uint32_t , N> HiMant = Hi & MantissaMask ;
236
236
// for normalized numbers (1 + mant/2^52) * 2 ^ (exp-1023)
237
237
vector<int32_t , N> MantShift = Exp - 1023 - 52 ;
238
238
vector<int32_t , N> OneShift = Exp - 1023 ;
@@ -272,6 +272,7 @@ __impl_fp2ui__double__(vector<double, N> a) {
272
272
// check for Exponent overflow (when sign bit set)
273
273
auto FlagExpO = (Exp > vector<uint32_t , N>(1089 ));
274
274
auto FlagExpUO = FlagNoSignSet & FlagExpO;
275
+ auto IsNaN = (Exp == ExpMask) & ((LoMant != Zero) | (HiMant != Zero));
275
276
if constexpr (isSigned) {
276
277
// calculate (NOT[Lo, Hi] + 1) (integer sign negation)
277
278
vector<uint32_t , N> NegLo = ~LoRes;
@@ -307,29 +308,38 @@ __impl_fp2ui__double__(vector<double, N> a) {
307
308
LoRes.merge (Ones, FlagExpUO);
308
309
HiRes.merge (vector<uint32_t , N>((1u << 31 ) - 1 ), FlagExpUO);
309
310
311
+ // if (IsNaN)
312
+ LoRes.merge (Zero, IsNaN);
313
+ HiRes.merge (Zero, IsNaN);
314
+
310
315
} else {
311
316
// if (FlagSignSet)
312
317
LoRes.merge (Zero, FlagSignSet);
313
318
HiRes.merge (Zero, FlagSignSet);
319
+
314
320
// if (FlagExpUO)
315
321
LoRes.merge (Ones, FlagExpUO);
316
322
HiRes.merge (Ones, FlagExpUO);
323
+
324
+ // if (IsNaN)
325
+ LoRes.merge (Zero, IsNaN);
326
+ HiRes.merge (Zero, IsNaN);
317
327
}
318
328
return __impl_combineLoHi<N>(LoRes, HiRes);
319
329
}
320
330
template <unsigned N, bool isSigned>
321
331
CM_NODEBUG CM_INLINE vector<uint64_t , N> __impl_fp2ui__ (vector<float , N> a) {
322
332
// vector of floats -> vector of ints
323
333
vector<uint32_t , N> Uifl = a.template format <uint32_t >();
324
- const vector<uint32_t , N> Exp_mask (0xff << 23 );
325
- const vector<uint32_t , N> Mantissa_mask ((1u << 23 ) - 1 );
334
+ const vector<uint32_t , N> ExpMask (0xff );
335
+ const vector<uint32_t , N> MantissaMask ((1u << 23 ) - 1 );
326
336
const vector<uint32_t , N> Zero (0 );
327
337
const vector<uint32_t , N> Ones (0xffffffff );
328
338
const vector<uint32_t , N> One (1 );
329
339
330
- vector<uint32_t , N> Exp = (Uifl >> 23 ) & vector< uint32_t , N>( 0xff ) ;
340
+ vector<uint32_t , N> Exp = (Uifl >> 23 ) & ExpMask ;
331
341
// mantissa without hidden bit
332
- vector<uint32_t , N> Pmantissa = Uifl & Mantissa_mask ;
342
+ vector<uint32_t , N> Pmantissa = Uifl & MantissaMask ;
333
343
// take hidden bit into account
334
344
vector<uint32_t , N> Mantissa = Pmantissa | vector<uint32_t , N>(1 << 23 );
335
345
vector<uint32_t , N> Data_h = Mantissa << 8 ;
@@ -354,8 +364,8 @@ CM_NODEBUG CM_INLINE vector<uint64_t, N> __impl_fp2ui__(vector<float, N> a) {
354
364
355
365
// Discard results if shift is greater than 63
356
366
vector<uint32_t , N> Mask = Ones;
357
- auto Flag_discard = (Shift > vector<uint32_t , N>(63 ));
358
- Mask.merge (Zero, Flag_discard );
367
+ auto FlagDiscard = (Shift > vector<uint32_t , N>(63 ));
368
+ Mask.merge (Zero, FlagDiscard );
359
369
Lo = Lo & Mask;
360
370
Hi = Hi & Mask;
361
371
vector<uint32_t , N> SignedBitMask (1u << 31 );
@@ -365,6 +375,7 @@ CM_NODEBUG CM_INLINE vector<uint64_t, N> __impl_fp2ui__(vector<float, N> a) {
365
375
// check for Exponent overflow (when sign bit set)
366
376
auto FlagExpO = (Exp > vector<uint32_t , N>(0xbe ));
367
377
auto FlagExpUO = FlagNoSignSet & FlagExpO;
378
+ auto IsNaN = (Exp == ExpMask) & (Pmantissa != Zero);
368
379
if constexpr (isSigned) {
369
380
// calculate (NOT[Lo, Hi] + 1) (integer sign negation)
370
381
vector<uint32_t , N> NegLo = ~Lo;
@@ -401,13 +412,21 @@ CM_NODEBUG CM_INLINE vector<uint64_t, N> __impl_fp2ui__(vector<float, N> a) {
401
412
Lo.merge (Ones, FlagExpUO);
402
413
Hi.merge (vector<uint32_t , N>((1u << 31 ) - 1 ), FlagExpUO);
403
414
415
+ // if (IsNaN)
416
+ Lo.merge (Zero, IsNaN);
417
+ Hi.merge (Zero, IsNaN);
404
418
} else {
405
419
// if (FlagSignSet)
406
420
Lo.merge (Zero, FlagSignSet);
407
421
Hi.merge (Zero, FlagSignSet);
422
+
408
423
// if (FlagExpUO)
409
424
Lo.merge (Ones, FlagExpUO);
410
425
Hi.merge (Ones, FlagExpUO);
426
+
427
+ // if (IsNaN)
428
+ Lo.merge (Zero, IsNaN);
429
+ Hi.merge (Zero, IsNaN);
411
430
}
412
431
return __impl_combineLoHi<N>(Lo, Hi);
413
432
}
0 commit comments