@@ -34,6 +34,92 @@ struct fp {
  uint sign;
};

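+// The u2_* helpers below treat a uint2 as an unsigned 64-bit value, with the
+// low 32 bits in .lo and the high 32 bits in .hi.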
+static uint2 u2_set(uint hi, uint lo) {
+  uint2 res;
+  res.lo = lo;
+  res.hi = hi;
+  return res;
+}
+
+static uint2 u2_set_u(uint val) { return u2_set(0, val); }
+
+static uint2 u2_mul(uint a, uint b) {
+  uint2 res;
+  res.hi = mul_hi(a, b);
+  res.lo = a * b;
+  return res;
+}
+
+static uint2 u2_sll(uint2 val, uint shift) {
+  if (shift == 0)
+    return val;
+  if (shift < 32) {
+    val.hi <<= shift;
+    val.hi |= val.lo >> (32 - shift);
+    val.lo <<= shift;
+  } else {
+    val.hi = val.lo << (shift - 32);
+    val.lo = 0;
+  }
+  return val;
+}
+
+static uint2 u2_srl(uint2 val, uint shift) {
+  if (shift == 0)
+    return val;
+  if (shift < 32) {
+    val.lo >>= shift;
+    val.lo |= val.hi << (32 - shift);
+    val.hi >>= shift;
+  } else {
+    val.lo = val.hi >> (shift - 32);
+    val.hi = 0;
+  }
+  return val;
+}
+
+static uint2 u2_or(uint2 a, uint b) {
+  a.lo |= b;
+  return a;
+}
+
+static uint2 u2_and(uint2 a, uint2 b) {
+  a.lo &= b.lo;
+  a.hi &= b.hi;
+  return a;
+}
+
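+// hadd(a, b) is (a + b) >> 1 computed without overflow, so bit 31 of
+// hadd(a.lo, b.lo) is the carry out of the 32-bit addition of the low words.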
+static uint2 u2_add(uint2 a, uint2 b) {
+  uint carry = (hadd(a.lo, b.lo) >> 31) & 0x1;
+  a.lo += b.lo;
+  a.hi += b.hi + carry;
+  return a;
+}
+
+static uint2 u2_add_u(uint2 a, uint b) { return u2_add(a, u2_set_u(b)); }
+
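+// Two's complement negation: ~a + 1.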
+static uint2 u2_inv(uint2 a) {
+  a.lo = ~a.lo;
+  a.hi = ~a.hi;
+  return u2_add_u(a, 1);
+}
+
+static uint u2_clz(uint2 a) {
+  uint leading_zeroes = clz(a.hi);
+  if (leading_zeroes == 32) {
+    leading_zeroes += clz(a.lo);
+  }
+  return leading_zeroes;
+}
+
+static bool u2_eq(uint2 a, uint2 b) { return a.lo == b.lo && a.hi == b.hi; }
+
+static bool u2_zero(uint2 a) { return u2_eq(a, u2_set_u(0)); }
+
+static bool u2_gt(uint2 a, uint2 b) {
+  return a.hi > b.hi || (a.hi == b.hi && a.lo > b.lo);
+}
+
_CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
  /* special cases */
  if (isnan(a) || isnan(b) || isnan(c) || isinf(a) || isinf(b)) {
@@ -63,12 +149,9 @@ _CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
  st_b.exponent = b == .0f ? 0 : ((as_uint(b) & 0x7f800000) >> 23) - 127;
  st_c.exponent = c == .0f ? 0 : ((as_uint(c) & 0x7f800000) >> 23) - 127;

-  st_a.mantissa.lo = a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000;
-  st_b.mantissa.lo = b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000;
-  st_c.mantissa.lo = c == .0f ? 0 : (as_uint(c) & 0x7fffff) | 0x800000;
-  st_a.mantissa.hi = 0;
-  st_b.mantissa.hi = 0;
-  st_c.mantissa.hi = 0;
+  st_a.mantissa = u2_set_u(a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000);
+  st_b.mantissa = u2_set_u(b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000);
+  st_c.mantissa = u2_set_u(c == .0f ? 0 : (as_uint(c) & 0x7fffff) | 0x800000);

  st_a.sign = as_uint(a) & 0x80000000;
  st_b.sign = as_uint(b) & 0x80000000;
@@ -81,162 +164,94 @@ _CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
  // add another bit to detect subtraction underflow
  struct fp st_mul;
  st_mul.sign = st_a.sign ^ st_b.sign;
-  st_mul.mantissa.hi = mul_hi(st_a.mantissa.lo, st_b.mantissa.lo);
-  st_mul.mantissa.lo = st_a.mantissa.lo * st_b.mantissa.lo;
-  uint upper_14bits = (st_mul.mantissa.lo >> 18) & 0x3fff;
-  st_mul.mantissa.lo <<= 14;
-  st_mul.mantissa.hi <<= 14;
-  st_mul.mantissa.hi |= upper_14bits;
-  st_mul.exponent = (st_mul.mantissa.lo != 0 || st_mul.mantissa.hi != 0)
-                        ? st_a.exponent + st_b.exponent
-                        : 0;
+  st_mul.mantissa = u2_sll(u2_mul(st_a.mantissa.lo, st_b.mantissa.lo), 14);
+  st_mul.exponent =
+      !u2_zero(st_mul.mantissa) ? st_a.exponent + st_b.exponent : 0;
+
+  // FIXME: Detecting a == 0 || b == 0 above crashed GCN isel
+  if (st_mul.exponent == 0 && u2_zero(st_mul.mantissa))
+    return c;

  // Mantissa is 23 fractional bits, shift it the same way as product mantissa
#define C_ADJUST 37ul

  // both exponents are bias adjusted
  int exp_diff = st_mul.exponent - st_c.exponent;

-  uint abs_exp_diff = abs(exp_diff);
-  st_c.mantissa.hi = (st_c.mantissa.lo << 5);
-  st_c.mantissa.lo = 0;
-  uint2 cutoff_bits = (uint2)(0, 0);
-  uint2 cutoff_mask = (uint2)(0, 0);
-  if (abs_exp_diff < 32) {
-    cutoff_mask.lo = (1u << abs(exp_diff)) - 1u;
-  } else if (abs_exp_diff < 64) {
-    cutoff_mask.lo = 0xffffffff;
-    uint remaining = abs_exp_diff - 32;
-    cutoff_mask.hi = (1u << remaining) - 1u;
+  st_c.mantissa = u2_sll(st_c.mantissa, C_ADJUST);
+  uint2 cutoff_bits = u2_set_u(0);
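+  // Adding ~0 subtracts one, so this builds the mask (1 << abs(exp_diff)) - 1.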
+  uint2 cutoff_mask = u2_add(u2_sll(u2_set_u(1), abs(exp_diff)),
+                             u2_set(0xffffffff, 0xffffffff));
+  if (exp_diff > 0) {
+    cutoff_bits =
+        exp_diff >= 64 ? st_c.mantissa : u2_and(st_c.mantissa, cutoff_mask);
+    st_c.mantissa =
+        exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_c.mantissa, exp_diff);
  } else {
-    cutoff_mask = (uint2)(0, 0);
-  }
-  uint2 tmp = (exp_diff > 0) ? st_c.mantissa : st_mul.mantissa;
-  if (abs_exp_diff > 0) {
-    cutoff_bits = abs_exp_diff >= 64 ? tmp : (tmp & cutoff_mask);
-    if (abs_exp_diff < 32) {
-      // shift some of the hi bits into the shifted lo bits.
-      uint shift_mask = (1u << abs_exp_diff) - 1;
-      uint upper_saved_bits = tmp.hi & shift_mask;
-      upper_saved_bits = upper_saved_bits << (32 - abs_exp_diff);
-      tmp.hi >>= abs_exp_diff;
-      tmp.lo >>= abs_exp_diff;
-      tmp.lo |= upper_saved_bits;
-    } else if (abs_exp_diff < 64) {
-      tmp.lo = (tmp.hi >> (abs_exp_diff - 32));
-      tmp.hi = 0;
-    } else {
-      tmp = (uint2)(0, 0);
-    }
+    cutoff_bits = -exp_diff >= 64 ? st_mul.mantissa
+                                  : u2_and(st_mul.mantissa, cutoff_mask);
+    st_mul.mantissa =
+        -exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_mul.mantissa, -exp_diff);
  }
-  if (exp_diff > 0)
-    st_c.mantissa = tmp;
-  else
-    st_mul.mantissa = tmp;

  struct fp st_fma;
  st_fma.sign = st_mul.sign;
  st_fma.exponent = max(st_mul.exponent, st_c.exponent);
-  st_fma.mantissa = (uint2)(0, 0);
  if (st_c.sign == st_mul.sign) {
-    uint carry = (hadd(st_mul.mantissa.lo, st_c.mantissa.lo) >> 31) & 0x1;
-    st_fma.mantissa = st_mul.mantissa + st_c.mantissa;
-    st_fma.mantissa.hi += carry;
+    st_fma.mantissa = u2_add(st_mul.mantissa, st_c.mantissa);
  } else {
    // cutoff bits borrow one
-    uint cutoff_borrow = ((cutoff_bits.lo != 0 || cutoff_bits.hi != 0) &&
-                          (st_mul.exponent > st_c.exponent))
-                             ? 1
-                             : 0;
-    uint borrow = 0;
-    if (st_c.mantissa.lo > st_mul.mantissa.lo) {
-      borrow = 1;
-    } else if (st_c.mantissa.lo == UINT_MAX && cutoff_borrow == 1) {
-      borrow = 1;
-    } else if ((st_c.mantissa.lo + cutoff_borrow) > st_mul.mantissa.lo) {
-      borrow = 1;
-    }
-
-    st_fma.mantissa.lo = st_mul.mantissa.lo - st_c.mantissa.lo - cutoff_borrow;
-    st_fma.mantissa.hi = st_mul.mantissa.hi - st_c.mantissa.hi - borrow;
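+    // Subtract by adding the two's complement of c's mantissa; the all-ones
+    // term borrows one more when nonzero cutoff bits were shifted out above.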
+    st_fma.mantissa =
+        u2_add(u2_add(st_mul.mantissa, u2_inv(st_c.mantissa)),
+               (!u2_zero(cutoff_bits) && (st_mul.exponent > st_c.exponent)
+                    ? u2_set(0xffffffff, 0xffffffff)
+                    : u2_set_u(0)));
  }

  // underflow: st_c.sign != st_mul.sign, and magnitude switches the sign
-  if (st_fma.mantissa.hi > INT_MAX) {
-    st_fma.mantissa = ~st_fma.mantissa;
-    uint carry = (hadd(st_fma.mantissa.lo, 1u) >> 31) & 0x1;
-    st_fma.mantissa.lo += 1;
-    st_fma.mantissa.hi += carry;
-
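+  // The top bit set means the two's complement subtraction went negative:
+  // negate the mantissa and flip the result's sign.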
+  if (u2_gt(st_fma.mantissa, u2_set(0x7fffffff, 0xffffffff))) {
+    st_fma.mantissa = u2_inv(st_fma.mantissa);
    st_fma.sign = st_mul.sign ^ 0x80000000;
  }

  // detect overflow/underflow
-  uint leading_zeroes = clz(st_fma.mantissa.hi);
-  if (leading_zeroes == 32) {
-    leading_zeroes += clz(st_fma.mantissa.lo);
-  }
-  int overflow_bits = 3 - leading_zeroes;
+  int overflow_bits = 3 - u2_clz(st_fma.mantissa);

  // adjust exponent
  st_fma.exponent += overflow_bits;

  // handle underflow
  if (overflow_bits < 0) {
-    uint shift = -overflow_bits;
-    if (shift < 32) {
-      uint shift_mask = (1u << shift) - 1;
-      uint saved_lo_bits = (st_fma.mantissa.lo >> (32 - shift)) & shift_mask;
-      st_fma.mantissa.lo <<= shift;
-      st_fma.mantissa.hi <<= shift;
-      st_fma.mantissa.hi |= saved_lo_bits;
-    } else if (shift < 64) {
-      st_fma.mantissa.hi = (st_fma.mantissa.lo << (64 - shift));
-      st_fma.mantissa.lo = 0;
-    } else {
-      st_fma.mantissa = (uint2)(0, 0);
-    }
-
+    st_fma.mantissa = u2_sll(st_fma.mantissa, -overflow_bits);
    overflow_bits = 0;
  }

  // rounding
-  // overflow_bits is now in the range of [0, 3] making the shift greater than
-  // 32 bits.
-  uint2 trunc_mask;
-  uint trunc_shift = C_ADJUST + overflow_bits - 32;
-  trunc_mask.hi = (1u << trunc_shift) - 1;
-  trunc_mask.lo = UINT_MAX;
-  uint2 trunc_bits = st_fma.mantissa & trunc_mask;
-  trunc_bits.lo |= (cutoff_bits.hi != 0 || cutoff_bits.lo != 0) ? 1 : 0;
-  uint2 last_bit;
-  last_bit.lo = 0;
-  last_bit.hi = st_fma.mantissa.hi & (1u << trunc_shift);
-  uint grs_shift = C_ADJUST - 3 + overflow_bits - 32;
-  uint2 grs_bits;
-  grs_bits.lo = 0;
-  grs_bits.hi = 0x4u << grs_shift;
+  uint2 trunc_mask = u2_add(u2_sll(u2_set_u(1), C_ADJUST + overflow_bits),
+                            u2_set(0xffffffff, 0xffffffff));
+  uint2 trunc_bits =
+      u2_or(u2_and(st_fma.mantissa, trunc_mask), !u2_zero(cutoff_bits));
+  uint2 last_bit =
+      u2_and(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits));
+  uint2 grs_bits = u2_sll(u2_set_u(4), C_ADJUST - 3 + overflow_bits);

  // round to nearest even
-  if ((trunc_bits.hi > grs_bits.hi ||
-       (trunc_bits.hi == grs_bits.hi && trunc_bits.lo > grs_bits.lo)) ||
-      (trunc_bits.hi == grs_bits.hi && trunc_bits.lo == grs_bits.lo &&
-       last_bit.hi != 0)) {
-    uint shift = C_ADJUST + overflow_bits - 32;
-    st_fma.mantissa.hi += 1u << shift;
+  if (u2_gt(trunc_bits, grs_bits) ||
+      (u2_eq(trunc_bits, grs_bits) && !u2_zero(last_bit))) {
+    st_fma.mantissa =
+        u2_add(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits));
  }

-  // Shift mantissa back to bit 23
-  st_fma.mantissa.lo = (st_fma.mantissa.hi >> (C_ADJUST + overflow_bits - 32));
-  st_fma.mantissa.hi = 0;
+  // Shift mantissa back to bit 23
+  st_fma.mantissa = u2_srl(st_fma.mantissa, C_ADJUST + overflow_bits);

  // Detect rounding overflow
-  if (st_fma.mantissa.lo > 0xffffff) {
+  if (u2_gt(st_fma.mantissa, u2_set_u(0xffffff))) {
    ++st_fma.exponent;
-    st_fma.mantissa.lo >>= 1;
+    st_fma.mantissa = u2_srl(st_fma.mantissa, 1);
  }

-  if (st_fma.mantissa.lo == 0) {
+  if (u2_zero(st_fma.mantissa)) {
    return 0.0f;
  }
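Not part of the commit: a minimal host-side C sketch of the lo/hi carry and two's-complement tricks the u2_* helpers rely on, checked against native 64-bit arithmetic. The struct layout and the hadd() stand-in below are assumptions made for illustration; in the OpenCL source, uint2 and hadd() are built-ins.

/* Host-side sanity check (illustrative only, not from the patch). */
#include <assert.h>
#include <stdint.h>
#include <stdio.h>

typedef struct { uint32_t lo, hi; } u2;

/* OpenCL hadd(a, b) is (a + b) >> 1 computed without overflow. */
static uint32_t hadd(uint32_t a, uint32_t b) {
  return (uint32_t)(((uint64_t)a + b) >> 1);
}

static u2 u2_add(u2 a, u2 b) {
  /* Bit 31 of hadd(a.lo, b.lo) is the carry out of the low-word addition. */
  uint32_t carry = (hadd(a.lo, b.lo) >> 31) & 0x1;
  a.lo += b.lo;
  a.hi += b.hi + carry;
  return a;
}

static u2 u2_inv(u2 a) {
  /* Two's complement negation: ~a + 1. */
  u2 one = {1, 0};
  a.lo = ~a.lo;
  a.hi = ~a.hi;
  return u2_add(a, one);
}

static uint64_t to_u64(u2 a) { return ((uint64_t)a.hi << 32) | a.lo; }

int main(void) {
  uint64_t x = 0xfedcba9876543210ull, y = 0x0f0f0f0f0f0f0f0full;
  u2 a = {(uint32_t)x, (uint32_t)(x >> 32)};
  u2 b = {(uint32_t)y, (uint32_t)(y >> 32)};
  assert(to_u64(u2_add(a, b)) == x + y);         /* emulated 64-bit add */
  assert(to_u64(u2_add(a, u2_inv(b))) == x - y); /* subtract via ~b + 1 */
  puts("u2 add/inv match native 64-bit arithmetic");
  return 0;
}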