Skip to content

Commit 8a3b7d9

Browse files
sergey-semenov
authored and bader committed
[SYCL] Fix the mul_hi built-in on host device
This patch fixes incorrect handling of negative arguments in the host implementation of the mul_hi built-in. Signed-off-by: Sergey Semenov <[email protected]>
1 parent ab3e71e commit 8a3b7d9

File tree

2 files changed

+92
-21
lines changed

2 files changed

+92
-21
lines changed

sycl/source/detail/builtins_integer.cpp

Lines changed: 59 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,20 @@ template <typename T> T __mul_hi(T a, T b) {
7979
return (mul >> (sizeof(T) * 8));
8080
}
8181

82-
// T is minimum of 64 bits- long or longlong
83-
template <typename T> inline T __long_mul_hi(T a, T b) {
84-
int halfsize = (sizeof(T) * 8) / 2;
82+
// Helper for the long mul_hi built-ins: folds the four half-width partial
// products of a multiplication into the upper half of the full-width result.
//   a1b1        - contributes to the upper half directly
//   a1b0, a0b1  - middle column, only their upper halves reach the result
//   a0b0        - only its upper half carries into the middle column
template <typename T>
inline T __get_high_half(T a0b0, T a0b1, T a1b0, T a1b1) {
  constexpr int halfsize = (sizeof(T) * 8) / 2;
  // Middle-column term together with the carry coming out of a0b0.
  const auto middle = a0b1 + (a0b0 >> halfsize);
  // The middle column is summed through __hadd (presumably a halving add that
  // keeps the carry bit - confirm against its definition), which is why the
  // shift below is one bit shorter than halfsize.
  const auto halved = __hadd(a1b0, middle);
  return a1b1 + (halved >> (halfsize - 1));
}
91+
92+
// A helper function for mul_hi built-in for long
93+
template <typename T>
94+
inline void __get_half_products(T a, T b, T &a0b0, T &a0b1, T &a1b0, T &a1b1) {
95+
constexpr int halfsize = (sizeof(T) * 8) / 2;
8596
T a1 = a >> halfsize;
8697
T a0 = (a << halfsize) >> halfsize;
8798
T b1 = b >> halfsize;
@@ -90,26 +101,53 @@ template <typename T> inline T __long_mul_hi(T a, T b) {
90101
// a1b1 - for bits - [64-128)
91102
// a1b0 a0b1 for bits - [32-96)
92103
// a0b0 for bits - [0-64)
93-
T a1b1 = a1 * b1;
94-
T a0b1 = a0 * b1;
95-
T a1b0 = a1 * b0;
96-
T a0b0 = a0 * b0;
104+
a1b1 = a1 * b1;
105+
a0b1 = a0 * b1;
106+
a1b0 = a1 * b0;
107+
a0b0 = a0 * b0;
108+
}
109+
110+
// Upper half of the full-width product a * b for unsigned operands.
// T is expected to be at least 64 bits wide (long or long long).
template <typename T> inline T __u_long_mul_hi(T a, T b) {
  // Half-width partial products: low*low, low*high, high*low, high*high.
  T lowLow, lowHigh, highLow, highHigh;
  __get_half_products(a, b, lowLow, lowHigh, highLow, highHigh);
  return __get_high_half(lowLow, lowHigh, highLow, highHigh);
}
117+
118+
// Upper half of the full-width signed product a * b.
//
// The product is formed from the operand magnitudes with unsigned arithmetic
// and then negated (two's complement, across both halves) when exactly one of
// the operands is negative.
template <typename T> inline T __s_long_mul_hi(T a, T b) {
  using UT = typename std::make_unsigned<T>::type;
  // Take the magnitudes in the unsigned domain. std::abs is undefined for the
  // minimum value of T (its magnitude is not representable in T), whereas
  // 0 - UT(x) wraps around to the correct value.
  const UT absA = a < 0 ? static_cast<UT>(UT(0) - static_cast<UT>(a))
                        : static_cast<UT>(a);
  const UT absB = b < 0 ? static_cast<UT>(UT(0) - static_cast<UT>(b))
                        : static_cast<UT>(b);

  UT a0b0, a0b1, a1b0, a1b1;
  __get_half_products(absA, absB, a0b0, a0b1, a1b0, a1b1);
  // Both magnitudes are at most 2^(N-1), so the high half of their product is
  // at most 2^(N-2) and always fits into the signed type.
  T result = static_cast<T>(__get_high_half(a0b0, a0b1, a1b0, a1b1));

  const bool isResultNegative = (a < 0) != (b < 0);
  if (isResultNegative) {
    // Negating the double-width value is ~value + 1 over both halves.
    result = ~result;

    // The +1 is applied to the low half and only carries into the high half
    // when the low half is zero.
    constexpr int halfsize = (sizeof(T) * 8) / 2;
    const UT low = a0b0 + ((a0b1 + a1b0) << halfsize);
    if (low == 0)
      ++result;
  }

  return result;
}
}
106140

107141
// mad_hi for the narrower integer types: the high half of a * b, as computed
// by __mul_hi, plus the addend c.
template <typename T> inline T __mad_hi(T a, T b, T c) {
  const T highHalf = __mul_hi(a, b);
  return highHalf + c;
}
110144

111-
template <typename T> inline T __long_mad_hi(T a, T b, T c) {
112-
return __long_mul_hi(a, b) + c;
145+
// mad_hi for unsigned long operands: the high half of a * b, as computed by
// __u_long_mul_hi, plus the addend c.
template <typename T> inline T __u_long_mad_hi(T a, T b, T c) {
  const T highHalf = __u_long_mul_hi(a, b);
  return highHalf + c;
}
148+
149+
// mad_hi for signed long operands: the high half of a * b, as computed by
// __s_long_mul_hi, plus the addend c.
template <typename T> inline T __s_long_mad_hi(T a, T b, T c) {
  const T highHalf = __s_long_mul_hi(a, b);
  return highHalf + c;
}
114152

115153
template <typename T> inline T __s_mad_sat(T a, T b, T c) {
@@ -123,7 +161,7 @@ template <typename T> inline T __s_mad_sat(T a, T b, T c) {
123161

124162
template <typename T> inline T __s_long_mad_sat(T a, T b, T c) {
125163
bool neg_prod = (a < 0) ^ (b < 0);
126-
T mulhi = __long_mul_hi(a, b);
164+
T mulhi = __s_long_mul_hi(a, b);
127165

128166
// check mul_hi. If it is any value != 0.
129167
// if prod is +ve, any value in mulhi means we need to saturate.
@@ -145,7 +183,7 @@ template <typename T> inline T __u_mad_sat(T a, T b, T c) {
145183
}
146184

147185
template <typename T> inline T __u_long_mad_sat(T a, T b, T c) {
148-
T mulhi = __long_mul_hi(a, b);
186+
T mulhi = __u_long_mul_hi(a, b);
149187
// check mul_hi. If it is any value != 0.
150188
if (mulhi != 0)
151189
return d::max_v<T>();
@@ -421,7 +459,7 @@ cl_char s_mul_hi(cl_char a, cl_char b) { return __mul_hi(a, b); }
421459
cl_short s_mul_hi(cl_short a, cl_short b) { return __mul_hi(a, b); }
422460
cl_int s_mul_hi(cl_int a, cl_int b) { return __mul_hi(a, b); }
423461
cl_long s_mul_hi(s::cl_long x, s::cl_long y) __NOEXC {
424-
return __long_mul_hi(x, y);
462+
return __s_long_mul_hi(x, y);
425463
}
426464
MAKE_1V_2V(s_mul_hi, s::cl_char, s::cl_char, s::cl_char)
427465
MAKE_1V_2V(s_mul_hi, s::cl_short, s::cl_short, s::cl_short)
@@ -433,7 +471,7 @@ cl_uchar u_mul_hi(cl_uchar a, cl_uchar b) { return __mul_hi(a, b); }
433471
cl_ushort u_mul_hi(cl_ushort a, cl_ushort b) { return __mul_hi(a, b); }
434472
cl_uint u_mul_hi(cl_uint a, cl_uint b) { return __mul_hi(a, b); }
435473
cl_ulong u_mul_hi(s::cl_ulong x, s::cl_ulong y) __NOEXC {
436-
return __long_mul_hi(x, y);
474+
return __u_long_mul_hi(x, y);
437475
}
438476
MAKE_1V_2V(u_mul_hi, s::cl_uchar, s::cl_uchar, s::cl_uchar)
439477
MAKE_1V_2V(u_mul_hi, s::cl_ushort, s::cl_ushort, s::cl_ushort)
@@ -452,7 +490,7 @@ cl_int s_mad_hi(s::cl_int x, s::cl_int minval, s::cl_int maxval) __NOEXC {
452490
return __mad_hi(x, minval, maxval);
453491
}
454492
cl_long s_mad_hi(s::cl_long x, s::cl_long minval, s::cl_long maxval) __NOEXC {
455-
return __long_mad_hi(x, minval, maxval);
493+
return __s_long_mad_hi(x, minval, maxval);
456494
}
457495
MAKE_1V_2V_3V(s_mad_hi, s::cl_char, s::cl_char, s::cl_char, s::cl_char)
458496
MAKE_1V_2V_3V(s_mad_hi, s::cl_short, s::cl_short, s::cl_short, s::cl_short)
@@ -473,7 +511,7 @@ cl_uint u_mad_hi(s::cl_uint x, s::cl_uint minval, s::cl_uint maxval) __NOEXC {
473511
}
474512
cl_ulong u_mad_hi(s::cl_ulong x, s::cl_ulong minval,
475513
s::cl_ulong maxval) __NOEXC {
476-
return __long_mad_hi(x, minval, maxval);
514+
return __u_long_mad_hi(x, minval, maxval);
477515
}
478516
MAKE_1V_2V_3V(u_mad_hi, s::cl_uchar, s::cl_uchar, s::cl_uchar, s::cl_uchar)
479517
MAKE_1V_2V_3V(u_mad_hi, s::cl_ushort, s::cl_ushort, s::cl_ushort, s::cl_ushort)

sycl/test/built-ins/scalar_integer.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,39 @@ int main() {
303303
assert(r == 0x10);
304304
}
305305

306+
// mul_hi with negative result w/ carry
307+
{
308+
s::cl_int r{0};
309+
{
310+
s::buffer<s::cl_int, 1> BufR(&r, s::range<1>(1));
311+
s::queue myQueue;
312+
myQueue.submit([&](s::handler &cgh) {
313+
auto AccR = BufR.get_access<s::access::mode::write>(cgh);
314+
cgh.single_task<class mul_hiSI1SI2>([=]() {
315+
AccR[0] = s::mul_hi(s::cl_int{-0x10000000}, s::cl_int{0x00000100});
316+
}); // -2^28 * 2^8 = -2^36 -> -0x10 (FFFFFFF0) 00000000.
317+
});
318+
}
319+
assert(r == -0x10);
320+
}
321+
322+
// mul_hi with negative result w/o carry
323+
{
324+
s::cl_int r{0};
325+
{
326+
s::buffer<s::cl_int, 1> BufR(&r, s::range<1>(1));
327+
s::queue myQueue;
328+
myQueue.submit([&](s::handler &cgh) {
329+
auto AccR = BufR.get_access<s::access::mode::write>(cgh);
330+
cgh.single_task<class mul_hiSI1SI3>([=]() {
331+
AccR[0] = s::mul_hi(s::cl_int{-0x10000000}, s::cl_int{0x00000101});
332+
}); // -2^28 * (2^8 + 1) = -2^36 - 2^28 -> -0x11 (FFFFFFEF) -0x10000000
333+
// (F0000000).
334+
});
335+
}
336+
assert(r == -0x11);
337+
}
338+
306339
// rotate
307340
{
308341
s::cl_int r{ 0 };

0 commit comments

Comments
 (0)