@@ -79,9 +79,20 @@ template <typename T> T __mul_hi(T a, T b) {
79
79
return (mul >> (sizeof (T) * 8 ));
80
80
}
81
81
82
- // T is minimum of 64 bits- long or longlong
83
- template <typename T> inline T __long_mul_hi (T a, T b) {
84
- int halfsize = (sizeof (T) * 8 ) / 2 ;
82
+ // A helper function for mul_hi built-in for long
83
+ template <typename T> inline T __get_high_half (T a0b0, T a0b1, T a1b0, T a1b1) {
84
+ constexpr int halfsize = (sizeof (T) * 8 ) / 2 ;
85
+ // To get the upper 64 bits:
86
+ // 64 bits from a1b1, upper 32 bits from [a1b0 + (a0b1 + a0b0>>32 (carry bit
87
+ // in 33rd bit))] with carry bit on 64th bit - use of hadd. Add the a1b1 to
88
+ // the above 32 bit result.
89
+ return a1b1 + (__hadd (a1b0, (a0b1 + (a0b0 >> halfsize))) >> (halfsize - 1 ));
90
+ }
91
+
92
+ // A helper function for mul_hi built-in for long
93
+ template <typename T>
94
+ inline void __get_half_products (T a, T b, T &a0b0, T &a0b1, T &a1b0, T &a1b1) {
95
+ constexpr int halfsize = (sizeof (T) * 8 ) / 2 ;
85
96
T a1 = a >> halfsize;
86
97
T a0 = (a << halfsize) >> halfsize;
87
98
T b1 = b >> halfsize;
@@ -90,26 +101,53 @@ template <typename T> inline T __long_mul_hi(T a, T b) {
90
101
// a1b1 - for bits - [64-128)
91
102
// a1b0 a0b1 for bits - [32-96)
92
103
// a0b0 for bits - [0-64)
93
- T a1b1 = a1 * b1;
94
- T a0b1 = a0 * b1;
95
- T a1b0 = a1 * b0;
96
- T a0b0 = a0 * b0;
104
+ a1b1 = a1 * b1;
105
+ a0b1 = a0 * b1;
106
+ a1b0 = a1 * b0;
107
+ a0b0 = a0 * b0;
108
+ }
109
+
110
+ // T is minimum of 64 bits- long or longlong
111
+ template <typename T> inline T __u_long_mul_hi (T a, T b) {
112
+ T a0b0, a0b1, a1b0, a1b1;
113
+ __get_half_products (a, b, a0b0, a0b1, a1b0, a1b1);
114
+ T result = __get_high_half (a0b0, a0b1, a1b0, a1b1);
115
+ return result;
116
+ }
117
+
118
+ template <typename T> inline T __s_long_mul_hi (T a, T b) {
119
+ using UT = typename std::make_unsigned<T>::type;
120
+ UT absA = std::abs (a);
121
+ UT absB = std::abs (b);
122
+
123
+ UT a0b0, a0b1, a1b0, a1b1;
124
+ __get_half_products (absA, absB, a0b0, a0b1, a1b0, a1b1);
125
+ T result = __get_high_half (a0b0, a0b1, a1b0, a1b1);
126
+
127
+ bool isResultNegative = (a < 0 ) != (b < 0 );
128
+ if (isResultNegative) {
129
+ result = ~result;
130
+
131
+ // Find the low half to see if we need to carry
132
+ constexpr int halfsize = (sizeof (T) * 8 ) / 2 ;
133
+ UT low = a0b0 + ((a0b1 + a1b0) << halfsize);
134
+ if (low == 0 )
135
+ ++result;
136
+ }
97
137
98
- // To get the upper 64 bits:
99
- // 64 bits from a1b1, upper 32 bits from [a1b0 + (a0b1 + a0b0>>32 (carry bit
100
- // in 33rd bit))] with carry bit on 64th bit - use of hadd. Add the a1b1 to
101
- // the above 32 bit result.
102
- T result =
103
- a1b1 + (__hadd (a1b0, (a0b1 + (a0b0 >> halfsize))) >> (halfsize - 1 ));
104
138
return result;
105
139
}
106
140
107
141
template <typename T> inline T __mad_hi (T a, T b, T c) {
108
142
return __mul_hi (a, b) + c;
109
143
}
110
144
111
- template <typename T> inline T __long_mad_hi (T a, T b, T c) {
112
- return __long_mul_hi (a, b) + c;
145
+ template <typename T> inline T __u_long_mad_hi (T a, T b, T c) {
146
+ return __u_long_mul_hi (a, b) + c;
147
+ }
148
+
149
+ template <typename T> inline T __s_long_mad_hi (T a, T b, T c) {
150
+ return __s_long_mul_hi (a, b) + c;
113
151
}
114
152
115
153
template <typename T> inline T __s_mad_sat (T a, T b, T c) {
@@ -123,7 +161,7 @@ template <typename T> inline T __s_mad_sat(T a, T b, T c) {
123
161
124
162
template <typename T> inline T __s_long_mad_sat (T a, T b, T c) {
125
163
bool neg_prod = (a < 0 ) ^ (b < 0 );
126
- T mulhi = __long_mul_hi (a, b);
164
+ T mulhi = __s_long_mul_hi (a, b);
127
165
128
166
// check mul_hi. If it is any value != 0.
129
167
// if prod is +ve, any value in mulhi means we need to saturate.
@@ -145,7 +183,7 @@ template <typename T> inline T __u_mad_sat(T a, T b, T c) {
145
183
}
146
184
147
185
template <typename T> inline T __u_long_mad_sat (T a, T b, T c) {
148
- T mulhi = __long_mul_hi (a, b);
186
+ T mulhi = __u_long_mul_hi (a, b);
149
187
// check mul_hi. If it is any value != 0.
150
188
if (mulhi != 0 )
151
189
return d::max_v<T>();
@@ -421,7 +459,7 @@ cl_char s_mul_hi(cl_char a, cl_char b) { return __mul_hi(a, b); }
421
459
cl_short s_mul_hi (cl_short a, cl_short b) { return __mul_hi (a, b); }
422
460
cl_int s_mul_hi (cl_int a, cl_int b) { return __mul_hi (a, b); }
423
461
cl_long s_mul_hi (s::cl_long x, s::cl_long y) __NOEXC {
424
- return __long_mul_hi (x, y);
462
+ return __s_long_mul_hi (x, y);
425
463
}
426
464
MAKE_1V_2V (s_mul_hi, s::cl_char, s::cl_char, s::cl_char)
427
465
MAKE_1V_2V (s_mul_hi, s::cl_short, s::cl_short, s::cl_short)
@@ -433,7 +471,7 @@ cl_uchar u_mul_hi(cl_uchar a, cl_uchar b) { return __mul_hi(a, b); }
433
471
cl_ushort u_mul_hi (cl_ushort a, cl_ushort b) { return __mul_hi (a, b); }
434
472
cl_uint u_mul_hi (cl_uint a, cl_uint b) { return __mul_hi (a, b); }
435
473
cl_ulong u_mul_hi (s::cl_ulong x, s::cl_ulong y) __NOEXC {
436
- return __long_mul_hi (x, y);
474
+ return __u_long_mul_hi (x, y);
437
475
}
438
476
MAKE_1V_2V (u_mul_hi, s::cl_uchar, s::cl_uchar, s::cl_uchar)
439
477
MAKE_1V_2V (u_mul_hi, s::cl_ushort, s::cl_ushort, s::cl_ushort)
@@ -452,7 +490,7 @@ cl_int s_mad_hi(s::cl_int x, s::cl_int minval, s::cl_int maxval) __NOEXC {
452
490
return __mad_hi (x, minval, maxval);
453
491
}
454
492
cl_long s_mad_hi (s::cl_long x, s::cl_long minval, s::cl_long maxval) __NOEXC {
455
- return __long_mad_hi (x, minval, maxval);
493
+ return __s_long_mad_hi (x, minval, maxval);
456
494
}
457
495
MAKE_1V_2V_3V (s_mad_hi, s::cl_char, s::cl_char, s::cl_char, s::cl_char)
458
496
MAKE_1V_2V_3V (s_mad_hi, s::cl_short, s::cl_short, s::cl_short, s::cl_short)
@@ -473,7 +511,7 @@ cl_uint u_mad_hi(s::cl_uint x, s::cl_uint minval, s::cl_uint maxval) __NOEXC {
473
511
}
474
512
cl_ulong u_mad_hi (s::cl_ulong x, s::cl_ulong minval,
475
513
s::cl_ulong maxval) __NOEXC {
476
- return __long_mad_hi (x, minval, maxval);
514
+ return __u_long_mad_hi (x, minval, maxval);
477
515
}
478
516
MAKE_1V_2V_3V (u_mad_hi, s::cl_uchar, s::cl_uchar, s::cl_uchar, s::cl_uchar)
479
517
MAKE_1V_2V_3V (u_mad_hi, s::cl_ushort, s::cl_ushort, s::cl_ushort, s::cl_ushort)
0 commit comments