Skip to content

Commit 3f59474

Browse files
authored
[flang] Fix implementation of Kahan summation (#116897)
In the runtime's implementation of floating-point SUM, Kahan's algorithm for increased precision is implemented incorrectly: the running correction factor should be subtracted from each new data item, not added to it. This fix ensures that the sum of 100M random default real values between 0. and 1. is close to 5.E7. See https://en.wikipedia.org/wiki/Kahan_summation_algorithm.
1 parent a76609d commit 3f59474

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

flang/lib/Evaluate/fold-matmul.h

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,7 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
6161
auto product{aElt.Multiply(bElt)};
6262
overflow |= product.flags.test(RealFlag::Overflow);
6363
if constexpr (useKahanSummation) {
64-
auto next{correction.Add(product.value, rounding)};
64+
auto next{product.value.Subtract(correction, rounding)};
6565
overflow |= next.flags.test(RealFlag::Overflow);
6666
auto added{sum.Add(next.value, rounding)};
6767
overflow |= added.flags.test(RealFlag::Overflow);

flang/lib/Evaluate/fold-real.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,7 @@ template <int KIND> class Norm2Accumulator {
7878
auto scaled{item.Divide(scale).value};
7979
auto square{scaled.Multiply(scaled).value};
8080
if constexpr (useKahanSummation) {
81-
auto next{square.Add(correction_, rounding_)};
81+
auto next{square.Subtract(correction_, rounding_)};
8282
overflow_ |= next.flags.test(RealFlag::Overflow);
8383
auto sum{element.Add(next.value, rounding_)};
8484
overflow_ |= sum.flags.test(RealFlag::Overflow);

flang/lib/Evaluate/fold-reduction.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ static Expr<T> FoldDotProduct(
4747
const auto &rounding{context.targetCharacteristics().roundingMode()};
4848
for (const Element &x : cProducts.values()) {
4949
if constexpr (useKahanSummation) {
50-
auto next{correction.Add(x, rounding)};
50+
auto next{x.Subtract(correction, rounding)};
5151
overflow |= next.flags.test(RealFlag::Overflow);
5252
auto added{sum.Add(next.value, rounding)};
5353
overflow |= added.flags.test(RealFlag::Overflow);
@@ -90,7 +90,7 @@ static Expr<T> FoldDotProduct(
9090
const auto &rounding{context.targetCharacteristics().roundingMode()};
9191
for (const Element &x : cProducts.values()) {
9292
if constexpr (useKahanSummation) {
93-
auto next{correction.Add(x, rounding)};
93+
auto next{x.Subtract(correction, rounding)};
9494
overflow |= next.flags.test(RealFlag::Overflow);
9595
auto added{sum.Add(next.value, rounding)};
9696
overflow |= added.flags.test(RealFlag::Overflow);
@@ -348,7 +348,7 @@ template <typename T> class SumAccumulator {
348348
overflow_ |= sum.overflow;
349349
element = sum.value;
350350
} else { // Real & Complex: use Kahan summation
351-
auto next{array_.At(at).Add(correction_, rounding_)};
351+
auto next{array_.At(at).Subtract(correction_, rounding_)};
352352
overflow_ |= next.flags.test(RealFlag::Overflow);
353353
auto sum{element.Add(next.value, rounding_)};
354354
overflow_ |= sum.flags.test(RealFlag::Overflow);

flang/runtime/sum.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -53,7 +53,7 @@ template <typename INTERMEDIATE> class RealSumAccumulator {
5353
}
5454
template <typename A> RT_API_ATTRS bool Accumulate(A x) {
5555
// Kahan summation
56-
auto next{x + correction_};
56+
auto next{x - correction_};
5757
auto oldSum{sum_};
5858
sum_ += next;
5959
correction_ = (sum_ - oldSum) - next; // algebraically zero

0 commit comments

Comments
 (0)