Skip to content

Commit 21508fa

Browse files
committed
libclc: clspv: fix fma, add vstore and fix inlining issues
https://reviews.llvm.org/D147773 Patch by Romaric Jodin <[email protected]>
1 parent f859835 commit 21508fa

File tree

6 files changed

+298
-124
lines changed

6 files changed

+298
-124
lines changed

libclc/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -271,11 +271,11 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
271271
set( spvflags --spirv-max-version=1.1 )
272272
elseif( ${ARCH} STREQUAL "clspv" )
273273
set( t "spir--" )
274-
set( build_flags )
274+
set( build_flags "-Wno-unknown-assumption")
275275
set( opt_flags -O3 )
276276
elseif( ${ARCH} STREQUAL "clspv64" )
277277
set( t "spir64--" )
278-
set( build_flags )
278+
set( build_flags "-Wno-unknown-assumption")
279279
set( opt_flags -O3 )
280280
else()
281281
set( build_flags )

libclc/clspv/lib/SOURCES

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
math/fma.cl
22
math/nextafter.cl
3+
shared/vstore_half.cl
34
subnormal_config.cl
45
../../generic/lib/geometric/distance.cl
56
../../generic/lib/geometric/length.cl
@@ -45,6 +46,12 @@ subnormal_config.cl
4546
../../generic/lib/math/frexp.cl
4647
../../generic/lib/math/half_cos.cl
4748
../../generic/lib/math/half_divide.cl
49+
../../generic/lib/math/half_exp.cl
50+
../../generic/lib/math/half_exp10.cl
51+
../../generic/lib/math/half_exp2.cl
52+
../../generic/lib/math/half_log.cl
53+
../../generic/lib/math/half_log10.cl
54+
../../generic/lib/math/half_log2.cl
4855
../../generic/lib/math/half_powr.cl
4956
../../generic/lib/math/half_recip.cl
5057
../../generic/lib/math/half_sin.cl

libclc/clspv/lib/math/fma.cl

Lines changed: 135 additions & 120 deletions
Original file line number | Diff line number | Diff line change
@@ -34,6 +34,92 @@ struct fp {
3434
uint sign;
3535
};
3636

37+
/* Build a 64-bit value, represented as a uint2, from its high and low
 * 32-bit halves. */
static uint2 u2_set(uint hi, uint lo) {
  uint2 out;
  out.hi = hi;
  out.lo = lo;
  return out;
}
43+
44+
/* Zero-extend a single 32-bit value into the uint2 64-bit representation. */
static uint2 u2_set_u(uint val) {
  return u2_set(0, val);
}
45+
46+
/* Full 32x32 -> 64-bit unsigned multiply, returned as a uint2 (hi:lo). */
static uint2 u2_mul(uint a, uint b) {
  uint2 product;
  product.lo = a * b;          /* low 32 bits of the product */
  product.hi = mul_hi(a, b);   /* high 32 bits via the OpenCL builtin */
  return product;
}
52+
53+
/* Logical shift left of a 64-bit value stored as a uint2 (hi:lo).
 *
 * A shift of 64 or more clears the value entirely. Without that guard,
 * the `val.lo << (shift - 32)` branch would use an out-of-range shift
 * count; OpenCL C masks shift counts modulo the type width, so the
 * result would be a wrong non-zero value rather than 0. */
static uint2 u2_sll(uint2 val, uint shift) {
  if (shift == 0)
    return val;
  if (shift >= 64) {
    /* Everything shifted out. */
    val.hi = 0;
    val.lo = 0;
  } else if (shift < 32) {
    val.hi <<= shift;
    val.hi |= val.lo >> (32 - shift); /* bits carried from lo into hi */
    val.lo <<= shift;
  } else {
    val.hi = val.lo << (shift - 32);
    val.lo = 0;
  }
  return val;
}
66+
67+
/* Logical shift right of a 64-bit value stored as a uint2 (hi:lo).
 *
 * A shift of 64 or more clears the value entirely. Without that guard,
 * the `val.hi >> (shift - 32)` branch would use an out-of-range shift
 * count; OpenCL C masks shift counts modulo the type width, so the
 * result would be a wrong non-zero value rather than 0. */
static uint2 u2_srl(uint2 val, uint shift) {
  if (shift == 0)
    return val;
  if (shift >= 64) {
    /* Everything shifted out. */
    val.hi = 0;
    val.lo = 0;
  } else if (shift < 32) {
    val.lo >>= shift;
    val.lo |= val.hi << (32 - shift); /* bits carried from hi into lo */
    val.hi >>= shift;
  } else {
    val.lo = val.hi >> (shift - 32);
    val.hi = 0;
  }
  return val;
}
80+
81+
/* Bitwise OR of a 32-bit value into the low word of a 64-bit uint2. */
static uint2 u2_or(uint2 a, uint b) {
  uint2 result = a;
  result.lo = a.lo | b;
  return result;
}
85+
86+
/* Bitwise AND of two 64-bit values in uint2 representation. */
static uint2 u2_and(uint2 a, uint2 b) {
  uint2 result;
  result.hi = a.hi & b.hi;
  result.lo = a.lo & b.lo;
  return result;
}
91+
92+
/* 64-bit addition of two uint2 values.
 *
 * hadd(x, y) computes (x + y) >> 1 without overflowing, so bit 31 of
 * its result is exactly the carry out of the 32-bit low-word sum. */
static uint2 u2_add(uint2 x, uint2 y) {
  uint carry_out = (hadd(x.lo, y.lo) >> 31) & 0x1;
  uint2 sum;
  sum.lo = x.lo + y.lo;
  sum.hi = x.hi + y.hi + carry_out;
  return sum;
}
98+
99+
/* Add a plain 32-bit value to a 64-bit uint2. */
static uint2 u2_add_u(uint2 a, uint b) {
  return u2_add(a, u2_set_u(b));
}
100+
101+
/* Two's-complement negation of a 64-bit uint2: bitwise NOT, then +1. */
static uint2 u2_inv(uint2 a) {
  uint2 flipped;
  flipped.hi = ~a.hi;
  flipped.lo = ~a.lo;
  return u2_add_u(flipped, 1);
}
106+
107+
/* Count leading zeroes across the full 64-bit (hi:lo) value.
 * If the high word is all zeroes (clz == 32), continue into the low word. */
static uint u2_clz(uint2 a) {
  uint n = clz(a.hi);
  return (n == 32) ? n + clz(a.lo) : n;
}
114+
115+
/* Equality of two 64-bit uint2 values: both halves must match. */
static bool u2_eq(uint2 a, uint2 b) {
  return (a.hi == b.hi) && (a.lo == b.lo);
}
116+
117+
/* True when the 64-bit uint2 value is zero in both halves. */
static bool u2_zero(uint2 a) {
  return u2_eq(a, u2_set_u(0));
}
118+
119+
/* Unsigned 64-bit greater-than: high words decide, low words break ties. */
static bool u2_gt(uint2 a, uint2 b) {
  if (a.hi != b.hi)
    return a.hi > b.hi;
  return a.lo > b.lo;
}
122+
37123
_CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
38124
/* special cases */
39125
if (isnan(a) || isnan(b) || isnan(c) || isinf(a) || isinf(b)) {
@@ -63,12 +149,9 @@ _CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
63149
st_b.exponent = b == .0f ? 0 : ((as_uint(b) & 0x7f800000) >> 23) - 127;
64150
st_c.exponent = c == .0f ? 0 : ((as_uint(c) & 0x7f800000) >> 23) - 127;
65151

66-
st_a.mantissa.lo = a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000;
67-
st_b.mantissa.lo = b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000;
68-
st_c.mantissa.lo = c == .0f ? 0 : (as_uint(c) & 0x7fffff) | 0x800000;
69-
st_a.mantissa.hi = 0;
70-
st_b.mantissa.hi = 0;
71-
st_c.mantissa.hi = 0;
152+
st_a.mantissa = u2_set_u(a == .0f ? 0 : (as_uint(a) & 0x7fffff) | 0x800000);
153+
st_b.mantissa = u2_set_u(b == .0f ? 0 : (as_uint(b) & 0x7fffff) | 0x800000);
154+
st_c.mantissa = u2_set_u(c == .0f ? 0 : (as_uint(c) & 0x7fffff) | 0x800000);
72155

73156
st_a.sign = as_uint(a) & 0x80000000;
74157
st_b.sign = as_uint(b) & 0x80000000;
@@ -81,162 +164,94 @@ _CLC_DEF _CLC_OVERLOAD float fma(float a, float b, float c) {
81164
// add another bit to detect subtraction underflow
82165
struct fp st_mul;
83166
st_mul.sign = st_a.sign ^ st_b.sign;
84-
st_mul.mantissa.hi = mul_hi(st_a.mantissa.lo, st_b.mantissa.lo);
85-
st_mul.mantissa.lo = st_a.mantissa.lo * st_b.mantissa.lo;
86-
uint upper_14bits = (st_mul.mantissa.lo >> 18) & 0x3fff;
87-
st_mul.mantissa.lo <<= 14;
88-
st_mul.mantissa.hi <<= 14;
89-
st_mul.mantissa.hi |= upper_14bits;
90-
st_mul.exponent = (st_mul.mantissa.lo != 0 || st_mul.mantissa.hi != 0)
91-
? st_a.exponent + st_b.exponent
92-
: 0;
167+
st_mul.mantissa = u2_sll(u2_mul(st_a.mantissa.lo, st_b.mantissa.lo), 14);
168+
st_mul.exponent =
169+
!u2_zero(st_mul.mantissa) ? st_a.exponent + st_b.exponent : 0;
170+
171+
// FIXME: Detecting a == 0 || b == 0 above crashed GCN isel
172+
if (st_mul.exponent == 0 && u2_zero(st_mul.mantissa))
173+
return c;
93174

94175
// Mantissa is 23 fractional bits, shift it the same way as product mantissa
95176
#define C_ADJUST 37ul
96177

97178
// both exponents are bias adjusted
98179
int exp_diff = st_mul.exponent - st_c.exponent;
99180

100-
uint abs_exp_diff = abs(exp_diff);
101-
st_c.mantissa.hi = (st_c.mantissa.lo << 5);
102-
st_c.mantissa.lo = 0;
103-
uint2 cutoff_bits = (uint2)(0, 0);
104-
uint2 cutoff_mask = (uint2)(0, 0);
105-
if (abs_exp_diff < 32) {
106-
cutoff_mask.lo = (1u << abs(exp_diff)) - 1u;
107-
} else if (abs_exp_diff < 64) {
108-
cutoff_mask.lo = 0xffffffff;
109-
uint remaining = abs_exp_diff - 32;
110-
cutoff_mask.hi = (1u << remaining) - 1u;
181+
st_c.mantissa = u2_sll(st_c.mantissa, C_ADJUST);
182+
uint2 cutoff_bits = u2_set_u(0);
183+
uint2 cutoff_mask = u2_add(u2_sll(u2_set_u(1), abs(exp_diff)),
184+
u2_set(0xffffffff, 0xffffffff));
185+
if (exp_diff > 0) {
186+
cutoff_bits =
187+
exp_diff >= 64 ? st_c.mantissa : u2_and(st_c.mantissa, cutoff_mask);
188+
st_c.mantissa =
189+
exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_c.mantissa, exp_diff);
111190
} else {
112-
cutoff_mask = (uint2)(0, 0);
113-
}
114-
uint2 tmp = (exp_diff > 0) ? st_c.mantissa : st_mul.mantissa;
115-
if (abs_exp_diff > 0) {
116-
cutoff_bits = abs_exp_diff >= 64 ? tmp : (tmp & cutoff_mask);
117-
if (abs_exp_diff < 32) {
118-
// shift some of the hi bits into the shifted lo bits.
119-
uint shift_mask = (1u << abs_exp_diff) - 1;
120-
uint upper_saved_bits = tmp.hi & shift_mask;
121-
upper_saved_bits = upper_saved_bits << (32 - abs_exp_diff);
122-
tmp.hi >>= abs_exp_diff;
123-
tmp.lo >>= abs_exp_diff;
124-
tmp.lo |= upper_saved_bits;
125-
} else if (abs_exp_diff < 64) {
126-
tmp.lo = (tmp.hi >> (abs_exp_diff - 32));
127-
tmp.hi = 0;
128-
} else {
129-
tmp = (uint2)(0, 0);
130-
}
191+
cutoff_bits = -exp_diff >= 64 ? st_mul.mantissa
192+
: u2_and(st_mul.mantissa, cutoff_mask);
193+
st_mul.mantissa =
194+
-exp_diff >= 64 ? u2_set_u(0) : u2_srl(st_mul.mantissa, -exp_diff);
131195
}
132-
if (exp_diff > 0)
133-
st_c.mantissa = tmp;
134-
else
135-
st_mul.mantissa = tmp;
136196

137197
struct fp st_fma;
138198
st_fma.sign = st_mul.sign;
139199
st_fma.exponent = max(st_mul.exponent, st_c.exponent);
140-
st_fma.mantissa = (uint2)(0, 0);
141200
if (st_c.sign == st_mul.sign) {
142-
uint carry = (hadd(st_mul.mantissa.lo, st_c.mantissa.lo) >> 31) & 0x1;
143-
st_fma.mantissa = st_mul.mantissa + st_c.mantissa;
144-
st_fma.mantissa.hi += carry;
201+
st_fma.mantissa = u2_add(st_mul.mantissa, st_c.mantissa);
145202
} else {
146203
// cutoff bits borrow one
147-
uint cutoff_borrow = ((cutoff_bits.lo != 0 || cutoff_bits.hi != 0) &&
148-
(st_mul.exponent > st_c.exponent))
149-
? 1
150-
: 0;
151-
uint borrow = 0;
152-
if (st_c.mantissa.lo > st_mul.mantissa.lo) {
153-
borrow = 1;
154-
} else if (st_c.mantissa.lo == UINT_MAX && cutoff_borrow == 1) {
155-
borrow = 1;
156-
} else if ((st_c.mantissa.lo + cutoff_borrow) > st_mul.mantissa.lo) {
157-
borrow = 1;
158-
}
159-
160-
st_fma.mantissa.lo = st_mul.mantissa.lo - st_c.mantissa.lo - cutoff_borrow;
161-
st_fma.mantissa.hi = st_mul.mantissa.hi - st_c.mantissa.hi - borrow;
204+
st_fma.mantissa =
205+
u2_add(u2_add(st_mul.mantissa, u2_inv(st_c.mantissa)),
206+
(!u2_zero(cutoff_bits) && (st_mul.exponent > st_c.exponent)
207+
? u2_set(0xffffffff, 0xffffffff)
208+
: u2_set_u(0)));
162209
}
163210

164211
// underflow: st_c.sign != st_mul.sign, and magnitude switches the sign
165-
if (st_fma.mantissa.hi > INT_MAX) {
166-
st_fma.mantissa = ~st_fma.mantissa;
167-
uint carry = (hadd(st_fma.mantissa.lo, 1u) >> 31) & 0x1;
168-
st_fma.mantissa.lo += 1;
169-
st_fma.mantissa.hi += carry;
170-
212+
if (u2_gt(st_fma.mantissa, u2_set(0x7fffffff, 0xffffffff))) {
213+
st_fma.mantissa = u2_inv(st_fma.mantissa);
171214
st_fma.sign = st_mul.sign ^ 0x80000000;
172215
}
173216

174217
// detect overflow/underflow
175-
uint leading_zeroes = clz(st_fma.mantissa.hi);
176-
if (leading_zeroes == 32) {
177-
leading_zeroes += clz(st_fma.mantissa.lo);
178-
}
179-
int overflow_bits = 3 - leading_zeroes;
218+
int overflow_bits = 3 - u2_clz(st_fma.mantissa);
180219

181220
// adjust exponent
182221
st_fma.exponent += overflow_bits;
183222

184223
// handle underflow
185224
if (overflow_bits < 0) {
186-
uint shift = -overflow_bits;
187-
if (shift < 32) {
188-
uint shift_mask = (1u << shift) - 1;
189-
uint saved_lo_bits = (st_fma.mantissa.lo >> (32 - shift)) & shift_mask;
190-
st_fma.mantissa.lo <<= shift;
191-
st_fma.mantissa.hi <<= shift;
192-
st_fma.mantissa.hi |= saved_lo_bits;
193-
} else if (shift < 64) {
194-
st_fma.mantissa.hi = (st_fma.mantissa.lo << (64 - shift));
195-
st_fma.mantissa.lo = 0;
196-
} else {
197-
st_fma.mantissa = (uint2)(0, 0);
198-
}
199-
225+
st_fma.mantissa = u2_sll(st_fma.mantissa, -overflow_bits);
200226
overflow_bits = 0;
201227
}
202228

203229
// rounding
204-
// overflow_bits is now in the range of [0, 3] making the shift greater than
205-
// 32 bits.
206-
uint2 trunc_mask;
207-
uint trunc_shift = C_ADJUST + overflow_bits - 32;
208-
trunc_mask.hi = (1u << trunc_shift) - 1;
209-
trunc_mask.lo = UINT_MAX;
210-
uint2 trunc_bits = st_fma.mantissa & trunc_mask;
211-
trunc_bits.lo |= (cutoff_bits.hi != 0 || cutoff_bits.lo != 0) ? 1 : 0;
212-
uint2 last_bit;
213-
last_bit.lo = 0;
214-
last_bit.hi = st_fma.mantissa.hi & (1u << trunc_shift);
215-
uint grs_shift = C_ADJUST - 3 + overflow_bits - 32;
216-
uint2 grs_bits;
217-
grs_bits.lo = 0;
218-
grs_bits.hi = 0x4u << grs_shift;
230+
uint2 trunc_mask = u2_add(u2_sll(u2_set_u(1), C_ADJUST + overflow_bits),
231+
u2_set(0xffffffff, 0xffffffff));
232+
uint2 trunc_bits =
233+
u2_or(u2_and(st_fma.mantissa, trunc_mask), !u2_zero(cutoff_bits));
234+
uint2 last_bit =
235+
u2_and(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits));
236+
uint2 grs_bits = u2_sll(u2_set_u(4), C_ADJUST - 3 + overflow_bits);
219237

220238
// round to nearest even
221-
if ((trunc_bits.hi > grs_bits.hi ||
222-
(trunc_bits.hi == grs_bits.hi && trunc_bits.lo > grs_bits.lo)) ||
223-
(trunc_bits.hi == grs_bits.hi && trunc_bits.lo == grs_bits.lo &&
224-
last_bit.hi != 0)) {
225-
uint shift = C_ADJUST + overflow_bits - 32;
226-
st_fma.mantissa.hi += 1u << shift;
239+
if (u2_gt(trunc_bits, grs_bits) ||
240+
(u2_eq(trunc_bits, grs_bits) && !u2_zero(last_bit))) {
241+
st_fma.mantissa =
242+
u2_add(st_fma.mantissa, u2_sll(u2_set_u(1), C_ADJUST + overflow_bits));
227243
}
228244

229-
// Shift mantissa back to bit 23
230-
st_fma.mantissa.lo = (st_fma.mantissa.hi >> (C_ADJUST + overflow_bits - 32));
231-
st_fma.mantissa.hi = 0;
245+
// Shift mantissa back to bit 23
246+
st_fma.mantissa = u2_srl(st_fma.mantissa, C_ADJUST + overflow_bits);
232247

233248
// Detect rounding overflow
234-
if (st_fma.mantissa.lo > 0xffffff) {
249+
if (u2_gt(st_fma.mantissa, u2_set_u(0xffffff))) {
235250
++st_fma.exponent;
236-
st_fma.mantissa.lo >>= 1;
251+
st_fma.mantissa = u2_srl(st_fma.mantissa, 1);
237252
}
238253

239-
if (st_fma.mantissa.lo == 0) {
254+
if (u2_zero(st_fma.mantissa)) {
240255
return 0.0f;
241256
}
242257

0 commit comments

Comments
 (0)