-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[flang] Fix implementation of Kahan summation #116897
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
In the runtime's implementation of floating-point SUM, the implementation of Kahan's algorithm for increased precision is incorrect. The running correction factor should be subtracted from each new data item, not added to it. This fix ensures that the sum of 100M random default real values between 0. and 1. is close to 5.E7. See https://en.wikipedia.org/wiki/Kahan_summation_algorithm.
@llvm/pr-subscribers-flang-semantics @llvm/pr-subscribers-flang-runtime Author: Peter Klausler (klausler) ChangesIn the runtime's implementation of floating-point SUM, the implementation of Kahan's algorithm for increased precision is incorrect. The running correction factor should be subtracted from each new data item, not added to it. This fix ensures that the sum of 100M random default real values between 0. and 1. is close to 5.E7. See https://en.wikipedia.org/wiki/Kahan_summation_algorithm. Full diff: https://github.com/llvm/llvm-project/pull/116897.diff 4 Files Affected:
diff --git a/flang/lib/Evaluate/fold-matmul.h b/flang/lib/Evaluate/fold-matmul.h
index be9c547d45286c..c3d65a90409098 100644
--- a/flang/lib/Evaluate/fold-matmul.h
+++ b/flang/lib/Evaluate/fold-matmul.h
@@ -61,7 +61,7 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
auto product{aElt.Multiply(bElt)};
overflow |= product.flags.test(RealFlag::Overflow);
if constexpr (useKahanSummation) {
- auto next{correction.Add(product.value, rounding)};
+ auto next{product.value.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
diff --git a/flang/lib/Evaluate/fold-real.cpp b/flang/lib/Evaluate/fold-real.cpp
index 0b79a417942a45..6fb5249c8a5e2e 100644
--- a/flang/lib/Evaluate/fold-real.cpp
+++ b/flang/lib/Evaluate/fold-real.cpp
@@ -78,7 +78,7 @@ template <int KIND> class Norm2Accumulator {
auto scaled{item.Divide(scale).value};
auto square{scaled.Multiply(scaled).value};
if constexpr (useKahanSummation) {
- auto next{square.Add(correction_, rounding_)};
+ auto next{square.Subtract(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
diff --git a/flang/lib/Evaluate/fold-reduction.h b/flang/lib/Evaluate/fold-reduction.h
index 8ca0794ab0fc7c..b1b81d8740d3f3 100644
--- a/flang/lib/Evaluate/fold-reduction.h
+++ b/flang/lib/Evaluate/fold-reduction.h
@@ -47,7 +47,7 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
- auto next{correction.Add(x, rounding)};
+ auto next{x.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
@@ -90,7 +90,7 @@ static Expr<T> FoldDotProduct(
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
if constexpr (useKahanSummation) {
- auto next{correction.Add(x, rounding)};
+ auto next{x.Subtract(correction, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
@@ -348,7 +348,7 @@ template <typename T> class SumAccumulator {
overflow_ |= sum.overflow;
element = sum.value;
} else { // Real & Complex: use Kahan summation
- auto next{array_.At(at).Add(correction_, rounding_)};
+ auto next{array_.At(at).Subtract(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
diff --git a/flang/runtime/sum.cpp b/flang/runtime/sum.cpp
index 04241443275eb9..10b81242546521 100644
--- a/flang/runtime/sum.cpp
+++ b/flang/runtime/sum.cpp
@@ -53,7 +53,7 @@ template <typename INTERMEDIATE> class RealSumAccumulator {
}
template <typename A> RT_API_ATTRS bool Accumulate(A x) {
// Kahan summation
- auto next{x + correction_};
+ auto next{x - correction_};
auto oldSum{sum_};
sum_ += next;
correction_ = (sum_ - oldSum) - next; // algebraically zero
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
In the runtime's implementation of floating-point SUM, the implementation of Kahan's algorithm for increased precision is incorrect. The running correction factor should be subtracted from each new data item, not added to it. This fix ensures that the sum of 100M random default real values between 0. and 1. is close to 5.E7.
See https://en.wikipedia.org/wiki/Kahan_summation_algorithm.