Skip to content

Commit b225934

Browse files
authored
[flang] Avoid needless overflow when folding NORM2 (#67499)
The code that folds the relatively new NORM2 intrinsic function can produce overflow in cases where it's not warranted. Rearrange to NORM2 = M * SQRT((A(:)/M)**2) where M is MAXVAL(ABS(A)).
1 parent 39f4ec5 commit b225934

File tree

3 files changed

+33
-10
lines changed

3 files changed

+33
-10
lines changed

flang/lib/Evaluate/fold-real.cpp

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,15 +52,28 @@ template <int KIND> class Norm2Accumulator {
5252
const Constant<T> &array, const Constant<T> &maxAbs, Rounding rounding)
5353
: array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
5454
void operator()(Scalar<T> &element, const ConstantSubscripts &at) {
55-
// Kahan summation of scaled elements
55+
// Kahan summation of scaled elements:
56+
// Naively,
57+
// NORM2(A(:)) = SQRT(SUM(A(:)**2))
58+
// For any T > 0, we have mathematically
59+
// SQRT(SUM(A(:)**2))
60+
// = SQRT(T**2 * (SUM(A(:)**2) / T**2))
61+
// = SQRT(T**2 * SUM(A(:)**2 / T**2))
62+
// = SQRT(T**2 * SUM((A(:)/T)**2))
63+
// = SQRT(T**2) * SQRT(SUM((A(:)/T)**2))
64+
// = T * SQRT(SUM((A(:)/T)**2))
65+
// By letting T = MAXVAL(ABS(A)), we ensure that
66+
// ALL(ABS(A(:)/T) <= 1), so ALL((A(:)/T)**2 <= 1), and the SUM will
67+
// not overflow unless absolutely necessary.
5668
auto scale{maxAbs_.At(maxAbsAt_)};
5769
if (scale.IsZero()) {
58-
// If maxAbs is zero, so are all elements, and result
70+
// Maximum value is zero, and so will the result be.
71+
// Avoid division by zero below.
5972
element = scale;
6073
} else {
6174
auto item{array_.At(at)};
6275
auto scaled{item.Divide(scale).value};
63-
auto square{item.Multiply(scaled).value};
76+
auto square{scaled.Multiply(scaled).value};
6477
auto next{square.Add(correction_, rounding_)};
6578
overflow_ |= next.flags.test(RealFlag::Overflow);
6679
auto sum{element.Add(next.value, rounding_)};
@@ -73,13 +86,16 @@ template <int KIND> class Norm2Accumulator {
7386
}
7487
bool overflow() const { return overflow_; }
7588
void Done(Scalar<T> &result) {
89+
// result+correction == SUM((data(:)/maxAbs)**2)
90+
// result = maxAbs * SQRT(result+correction)
7691
auto corrected{result.Add(correction_, rounding_)};
7792
overflow_ |= corrected.flags.test(RealFlag::Overflow);
7893
correction_ = Scalar<T>{};
79-
auto rescaled{corrected.value.Multiply(maxAbs_.At(maxAbsAt_))};
94+
auto root{corrected.value.SQRT().value};
95+
auto product{root.Multiply(maxAbs_.At(maxAbsAt_))};
8096
maxAbs_.IncrementSubscripts(maxAbsAt_);
81-
overflow_ |= rescaled.flags.test(RealFlag::Overflow);
82-
result = rescaled.value.SQRT().value;
97+
overflow_ |= product.flags.test(RealFlag::Overflow);
98+
result = product.value;
8399
}
84100

85101
private:

flang/lib/Evaluate/fold-reduction.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ template <typename T, bool ABS = false> class MaxvalMinvalAccumulator {
228228
test.Rewrite(context_, std::move(test)))};
229229
CHECK(folded.has_value());
230230
if (folded->IsTrue()) {
231-
element = array_.At(at);
231+
element = aAt;
232232
}
233233
}
234234
void Done(Scalar<T> &) const {}

flang/test/Evaluate/fold-norm2.f90

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,20 @@ module m
1717
real(dp), parameter :: a(3,4) = &
1818
reshape([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], shape(a))
1919
real(dp), parameter :: nAll = norm2(a)
20-
real(dp), parameter :: check_nAll = sqrt(sum(a * a))
20+
real(dp), parameter :: check_nAll = 11._dp * sqrt(sum((a/11._dp)**2))
2121
logical, parameter :: test_all = nAll == check_nAll
2222
real(dp), parameter :: norms1(4) = norm2(a, dim=1)
23-
real(dp), parameter :: check_norms1(4) = sqrt(sum(a * a, dim=1))
23+
real(dp), parameter :: check_norms1(4) = [ &
24+
2.236067977499789805051477742381393909454345703125_8, &
25+
7.07106781186547550532850436866283416748046875_8, &
26+
1.2206555615733702069292121450416743755340576171875e1_8, &
27+
1.7378147196982769884243680280633270740509033203125e1_8 ]
2428
logical, parameter :: test_norms1 = all(norms1 == check_norms1)
2529
real(dp), parameter :: norms2(3) = norm2(a, dim=2)
26-
real(dp), parameter :: check_norms2(3) = sqrt(sum(a * a, dim=2))
30+
real(dp), parameter :: check_norms2(3) = [ &
31+
1.1224972160321822656214862945489585399627685546875e1_8, &
32+
1.28840987267251261272349438513629138469696044921875e1_8, &
33+
1.4628738838327791427218471653759479522705078125e1_8 ]
2734
logical, parameter :: test_norms2 = all(norms2 == check_norms2)
2835
logical, parameter :: test_normZ = norm2([0.,0.,0.]) == 0.
2936
end

0 commit comments

Comments
 (0)