Skip to content

Commit 1ef5e6d

Browse files
committed
[flang] Make SQRT folding exact
Replace the latter half of the SQRT() folding algorithm with code that calculates the exact root with two extra bits of precision for rounding, and then lets the usual normalization and rounding code (NormalizeAndRound) do the right thing. Extend the tests to catch regressions.
Differential Revision: https://reviews.llvm.org/D128395
1 parent dfaa388 commit 1ef5e6d

File tree

3 files changed

+61
-45
lines changed

3 files changed

+61
-45
lines changed

flang/lib/Evaluate/real.cpp

Lines changed: 24 additions & 45 deletions
Original file line number | Diff line number | Diff line change
@@ -274,6 +274,7 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::SQRT(Rounding rounding) const {
274274
// SQRT(-0) == -0 in IEEE-754.
275275
result.value = NegativeZero();
276276
} else {
277+
result.flags.set(RealFlag::InvalidArgument);
277278
result.value = NotANumber();
278279
}
279280
} else if (IsInfinite()) {
@@ -297,53 +298,31 @@ ValueWithRealFlags<Real<W, P>> Real<W, P>::SQRT(Rounding rounding) const {
297298
result.value.GetFraction());
298299
return result;
299300
}
300-
// Compute the square root of the reduced value with the slow but
301-
// reliable bit-at-a-time method. Start with a clear significand and
302-
// half of the unbiased exponent, and then try to set significand bits
303-
// in descending order of magnitude without exceeding the exact result.
304-
expo = expo / 2 + exponentBias;
305-
result.value.Normalize(false, expo, Fraction::MASKL(1));
306-
Real initialSq{result.value.Multiply(result.value).value};
307-
if (Compare(initialSq) == Relation::Less) {
308-
// Initial estimate is too large; this can happen for values just
309-
// under 1.0.
310-
--expo;
311-
result.value.Normalize(false, expo, Fraction::MASKL(1));
312-
}
313-
for (int bit{significandBits - 1}; bit >= 0; --bit) {
314-
Word word{result.value.word_};
315-
result.value.word_ = word.IBSET(bit);
316-
auto squared{result.value.Multiply(result.value, rounding)};
317-
if (squared.flags.test(RealFlag::Overflow) ||
318-
squared.flags.test(RealFlag::Underflow) ||
319-
Compare(squared.value) == Relation::Less) {
320-
result.value.word_ = word;
321-
}
322-
}
323-
// The computed square root has a square that's not greater than the
324-
// original argument. Check this square against the square of the next
325-
// larger Real and return that one if its square is closer in magnitude to
326-
// the original argument.
327-
Real resultSq{result.value.Multiply(result.value).value};
328-
Real diff{Subtract(resultSq).value.ABS()};
329-
if (diff.IsZero()) {
330-
return result; // exact
331-
}
332-
Real ulp;
333-
ulp.Normalize(false, expo, Fraction::MASKR(1));
334-
Real nextAfter{result.value.Add(ulp).value};
335-
auto nextAfterSq{nextAfter.Multiply(nextAfter)};
336-
if (!nextAfterSq.flags.test(RealFlag::Overflow) &&
337-
!nextAfterSq.flags.test(RealFlag::Underflow)) {
338-
Real nextAfterDiff{Subtract(nextAfterSq.value).value.ABS()};
339-
if (nextAfterDiff.Compare(diff) == Relation::Less) {
340-
result.value = nextAfter;
341-
if (nextAfterDiff.IsZero()) {
342-
return result; // exact
343-
}
301+
// (-1) <= expo <= 1; use it as a shift to set the desired square.
302+
using Extended = typename value::Integer<(binaryPrecision + 2)>;
303+
Extended goal{
304+
Extended::ConvertUnsigned(GetFraction()).value.SHIFTL(expo + 1)};
305+
// Calculate the exact square root by maximizing a value whose square
306+
// does not exceed the goal. Use two extra bits of precision for
307+
// rounding.
308+
bool sticky{true};
309+
Extended extFrac{};
310+
for (int bit{Extended::bits - 1}; bit >= 0; --bit) {
311+
Extended next{extFrac.IBSET(bit)};
312+
auto squared{next.MultiplyUnsigned(next)};
313+
auto cmp{squared.upper.CompareUnsigned(goal)};
314+
if (cmp == Ordering::Less) {
315+
extFrac = next;
316+
} else if (cmp == Ordering::Equal && squared.lower.IsZero()) {
317+
extFrac = next;
318+
sticky = false;
319+
break; // exact result
344320
}
345321
}
346-
result.flags.set(RealFlag::Inexact);
322+
RoundingBits roundingBits{extFrac.BTEST(1), extFrac.BTEST(0), sticky};
323+
NormalizeAndRound(result, false, exponentBias,
324+
Fraction::ConvertUnsigned(extFrac.SHIFTR(2)).value, rounding,
325+
roundingBits);
347326
}
348327
return result;
349328
}

flang/test/Evaluate/folding28.f90

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -49,4 +49,25 @@ module m
4949
logical, parameter :: test_sqrt_zero_4 = sqrt_zero_4 == 0.0
5050
real(8), parameter :: sqrt_zero_8 = sqrt(0.0)
5151
logical, parameter :: test_sqrt_zero_8 = sqrt_zero_8 == 0.0
52+
! Some common values to get right
53+
real(8), parameter :: sqrt_1_8 = sqrt(1.d0)
54+
logical, parameter :: test_sqrt_1_8 = sqrt_1_8 == 1.d0
55+
real(8), parameter :: sqrt_2_8 = sqrt(2.d0)
56+
logical, parameter :: test_sqrt_2_8 = sqrt_2_8 == 1.4142135623730951454746218587388284504413604736328125d0
57+
real(8), parameter :: sqrt_3_8 = sqrt(3.d0)
58+
logical, parameter :: test_sqrt_3_8 = sqrt_3_8 == 1.732050807568877193176604123436845839023590087890625d0
59+
real(8), parameter :: sqrt_4_8 = sqrt(4.d0)
60+
logical, parameter :: test_sqrt_4_8 = sqrt_4_8 == 2.d0
61+
real(8), parameter :: sqrt_5_8 = sqrt(5.d0)
62+
logical, parameter :: test_sqrt_5_8 = sqrt_5_8 == 2.236067977499789805051477742381393909454345703125d0
63+
real(8), parameter :: sqrt_6_8 = sqrt(6.d0)
64+
logical, parameter :: test_sqrt_6_8 = sqrt_6_8 == 2.44948974278317788133563226438127458095550537109375d0
65+
real(8), parameter :: sqrt_7_8 = sqrt(7.d0)
66+
logical, parameter :: test_sqrt_7_8 = sqrt_7_8 == 2.64575131106459071617109657381661236286163330078125d0
67+
real(8), parameter :: sqrt_8_8 = sqrt(8.d0)
68+
logical, parameter :: test_sqrt_8_8 = sqrt_8_8 == 2.828427124746190290949243717477656900882720947265625d0
69+
real(8), parameter :: sqrt_9_8 = sqrt(9.d0)
70+
logical, parameter :: test_sqrt_9_8 = sqrt_9_8 == 3.d0
71+
real(8), parameter :: sqrt_10_8 = sqrt(10.d0)
72+
logical, parameter :: test_sqrt_10_8 = sqrt_10_8 == 3.162277660168379522787063251598738133907318115234375d0
5273
end module

flang/unittests/Evaluate/real.cpp

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -392,6 +392,22 @@ void subsetTests(int pass, Rounding rounding, std::uint32_t opds) {
392392
("%d AINT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
393393
}
394394

395+
{
396+
ValueWithRealFlags<REAL> root{x.SQRT(rounding)};
397+
#ifndef __clang__ // broken and also slow
398+
fpenv.ClearFlags();
399+
#endif
400+
FLT fcheck{std::sqrt(fj)};
401+
auto actualFlags{FlagsToBits(fpenv.CurrentFlags())};
402+
u.f = fcheck;
403+
UINT rcheck{NormalizeNaN(u.ui)};
404+
UINT check = root.value.RawBits().ToUInt64();
405+
MATCH(rcheck, check)
406+
("%d SQRT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
407+
MATCH(actualFlags, FlagsToBits(root.flags))
408+
("%d SQRT(0x%jx)", pass, static_cast<std::intmax_t>(rj));
409+
}
410+
395411
{
396412
MATCH(IsNaN(rj), x.IsNotANumber())
397413
("%d IsNaN(0x%jx)", pass, static_cast<std::intmax_t>(rj));

0 commit comments

Comments
 (0)