Skip to content

Commit 052eb47

Browse files
aratajew authored and igcbot committed
Replace incorrect implementation of mad_sat builtin with libclc version
The IGC mad_sat implementation was returning wrong results for some input data. Its assumptions about the high part of the multiplication result were incorrect, which resulted in returning LONG_MIN even though the result value was within the representable (non-saturating) range and therefore should not have been clamped. To avoid reinventing the wheel, it is better to reuse the mad_sat implementation from libclc, which should already be thoroughly tested.
1 parent 39fa377 commit 052eb47

File tree

4 files changed

+69
-17
lines changed

4 files changed

+69
-17
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#include "../../../Headers/spirv.h"
2+
3+
/*
 * Halving add: returns floor((x + y) / 2) without ever forming the full
 * sum x + y, so the intermediate arithmetic cannot overflow a long.
 */
INLINE OVERLOADABLE long libclc_hadd(long x, long y) {
    /* Halve each operand on its own first. */
    const long half_x = x >> (long)1;
    const long half_y = y >> (long)1;

    /* The two dropped low bits add up to a whole unit only when both are set. */
    const long carry = x & y & (long)1;

    return half_x + half_y + carry;
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#include "../../include/BiF_Definitions.cl"
2+
#include "../../../Headers/spirv.h"
3+
#include "mul_hi.cl"
4+
5+
INLINE OVERLOADABLE long libclc_mad_sat(long x, long y, long z) {
6+
long hi = libclc_mul_hi(x, y);
7+
ulong ulo = x * y;
8+
long slo = x * y;
9+
/* Big overflow of more than 2 bits, add can't fix this */
10+
if (((x < 0) == (y < 0)) && hi != 0)
11+
return LONG_MAX;
12+
/* Low overflow in mul and z not neg enough to correct it */
13+
if (hi == 0 && ulo >= LONG_MAX && (z > 0 || (ulo + z) > LONG_MAX))
14+
return LONG_MAX;
15+
/* Big overflow of more than 2 bits, add can't fix this */
16+
if (((x < 0) != (y < 0)) && hi != -1)
17+
return LONG_MIN;
18+
/* Low overflow in mul and z not pos enough to correct it */
19+
if (hi == -1 && ulo <= ((ulong)LONG_MAX + 1UL) && (z < 0 || z < (LONG_MAX - ulo)))
20+
return LONG_MIN;
21+
/* We have checked all conditions, any overflow in addition returns
22+
* the correct value */
23+
return ulo + z;
24+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include "../../include/BiF_Definitions.cl"
2+
#include "../../../Headers/spirv.h"
3+
#include "hadd.cl"
4+
5+
//FOIL-based long mul_hi
6+
//
7+
// Summary: Treat mul_hi(long x, long y) as:
8+
// (a+b) * (c+d) where a and c are the high-order parts of x and y respectively
9+
// and b and d are the low-order parts of x and y.
10+
// Thinking back to algebra, we use FOIL to do the work.
11+
12+
INLINE OVERLOADABLE long libclc_mul_hi(long x, long y) {
13+
long f, o, i;
14+
ulong l;
15+
16+
//Move the high/low halves of x/y into the lower 32-bits of variables so
17+
//that we can multiply them without worrying about overflow.
18+
long x_hi = x >> 32;
19+
long x_lo = x & UINT_MAX;
20+
long y_hi = y >> 32;
21+
long y_lo = y & UINT_MAX;
22+
23+
//Multiply all of the components according to FOIL method
24+
f = x_hi * y_hi;
25+
o = x_hi * y_lo;
26+
i = x_lo * y_hi;
27+
l = x_lo * y_lo;
28+
29+
//Now add the components back together in the following steps:
30+
//F: doesn't need to be modified
31+
//O/I: Need to be added together.
32+
//L: Shift right by 32-bits, then add into the sum of O and I
33+
//Once O/I/L are summed up, then shift the sum by 32-bits and add to F.
34+
//
35+
//We use hadd to give us a bit of extra precision for the intermediate sums
36+
//but as a result, we shift by 31 bits instead of 32
37+
return (long)(f + (libclc_hadd(o, (i + (long)((ulong)l >> 32))) >> 31));
38+
}

IGC/BiFModule/Implementation/Integer/mad_sat.cl

Lines changed: 2 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ SPDX-License-Identifier: MIT
99
#include "../include/BiF_Definitions.cl"
1010
#include "../../Headers/spirv.h"
1111
#include "../include/mul_hilo.cl"
12+
#include "../ExternalLibraries/libclc/mad_sat.cl"
1213

1314
INLINE
1415
char2 __builtin_spirv_OpenCL_s_mad_sat_v2i8_v2i8_v2i8( char2 a,
@@ -490,23 +491,7 @@ long __builtin_spirv_OpenCL_s_mad_sat_i64_i64_i64( long a,
490491
long b,
491492
long c )
492493
{
493-
long lo;
494-
long hi;
495-
hi = __builtin_spirv___intc_mul_hilo_i64_i64_p0i64(a, b, &lo);
496-
long result_lo = lo + c;
497-
if (c >= 0)
498-
{
499-
if (result_lo < lo)
500-
hi++;
501-
}
502-
else
503-
{
504-
if (result_lo > lo)
505-
hi--;
506-
}
507-
return (hi == 0) ? result_lo :
508-
(hi < 0) ? LONG_MIN :
509-
LONG_MAX;
494+
return libclc_mad_sat(a, b, c);
510495
}
511496

512497
INLINE

0 commit comments

Comments
 (0)