Skip to content

Commit 3502d34

Browse files
authored
[flang] Adjust transformational folding to match runtime (#90132)
The transformational intrinsic functions MATMUL, DOT_PRODUCT, and NORM2 all involve summing up intermediate products into accumulators. In the constant folding library, this is done with extended precision Kahan summation for REAL and COMPLEX arguments, but in the runtime implementations it is not, and this leads to discrepancies between folded results and dynamic results. Disable the use of Kahan summation in folding to resolve these discrepancies, but don't discard the code, in case we want to add Kahan summation in the runtime for some or all of these intrinsic functions.
1 parent a1c1279 commit 3502d34

File tree

4 files changed

+71
-43
lines changed

4 files changed

+71
-43
lines changed

flang/lib/Evaluate/fold-implementation.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@
4545

4646
namespace Fortran::evaluate {
4747

48+
// Don't use Kahan extended precision summation any more when folding
49+
// transformational intrinsic functions other than SUM, since it is
50+
// not used in the runtime implementations of those functions and we
51+
// want results to match.
52+
static constexpr bool useKahanSummation{false};
53+
4854
// Utilities
4955
template <typename T> class Folder {
5056
public:

flang/lib/Evaluate/fold-matmul.h

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -58,18 +58,25 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
5858
Element bElt{mb->At(bAt)};
5959
if constexpr (T::category == TypeCategory::Real ||
6060
T::category == TypeCategory::Complex) {
61-
// Kahan summation
62-
auto product{aElt.Multiply(bElt, rounding)};
61+
auto product{aElt.Multiply(bElt)};
6362
overflow |= product.flags.test(RealFlag::Overflow);
64-
auto next{correction.Add(product.value, rounding)};
65-
overflow |= next.flags.test(RealFlag::Overflow);
66-
auto added{sum.Add(next.value, rounding)};
67-
overflow |= added.flags.test(RealFlag::Overflow);
68-
correction = added.value.Subtract(sum, rounding)
69-
.value.Subtract(next.value, rounding)
70-
.value;
71-
sum = std::move(added.value);
63+
if constexpr (useKahanSummation) {
64+
auto next{correction.Add(product.value, rounding)};
65+
overflow |= next.flags.test(RealFlag::Overflow);
66+
auto added{sum.Add(next.value, rounding)};
67+
overflow |= added.flags.test(RealFlag::Overflow);
68+
correction = added.value.Subtract(sum, rounding)
69+
.value.Subtract(next.value, rounding)
70+
.value;
71+
sum = std::move(added.value);
72+
} else {
73+
auto added{sum.Add(product.value)};
74+
overflow |= added.flags.test(RealFlag::Overflow);
75+
sum = std::move(added.value);
76+
}
7277
} else if constexpr (T::category == TypeCategory::Integer) {
78+
// Don't use Kahan summation in numeric MATMUL folding;
79+
// the runtime doesn't use it, and results should match.
7380
auto product{aElt.MultiplySigned(bElt)};
7481
overflow |= product.SignedMultiplicationOverflowed();
7582
auto added{sum.AddSigned(product.lower)};

flang/lib/Evaluate/fold-real.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ template <int KIND> class Norm2Accumulator {
5454
: array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
5555
void operator()(
5656
Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
57-
// Kahan summation of scaled elements:
57+
// Summation of scaled elements:
5858
// Naively,
5959
// NORM2(A(:)) = SQRT(SUM(A(:)**2))
6060
// For any T > 0, we have mathematically
@@ -76,24 +76,27 @@ template <int KIND> class Norm2Accumulator {
7676
auto item{array_.At(at)};
7777
auto scaled{item.Divide(scale).value};
7878
auto square{scaled.Multiply(scaled).value};
79-
auto next{square.Add(correction_, rounding_)};
80-
overflow_ |= next.flags.test(RealFlag::Overflow);
81-
auto sum{element.Add(next.value, rounding_)};
82-
overflow_ |= sum.flags.test(RealFlag::Overflow);
83-
correction_ = sum.value.Subtract(element, rounding_)
84-
.value.Subtract(next.value, rounding_)
85-
.value;
86-
element = sum.value;
79+
if constexpr (useKahanSummation) {
80+
auto next{square.Add(correction_, rounding_)};
81+
overflow_ |= next.flags.test(RealFlag::Overflow);
82+
auto sum{element.Add(next.value, rounding_)};
83+
overflow_ |= sum.flags.test(RealFlag::Overflow);
84+
correction_ = sum.value.Subtract(element, rounding_)
85+
.value.Subtract(next.value, rounding_)
86+
.value;
87+
element = sum.value;
88+
} else {
89+
auto sum{element.Add(square, rounding_)};
90+
overflow_ |= sum.flags.test(RealFlag::Overflow);
91+
element = sum.value;
92+
}
8793
}
8894
}
8995
bool overflow() const { return overflow_; }
9096
void Done(Scalar<T> &result) {
91-
// result+correction == SUM((data(:)/maxAbs)**2)
92-
// result = maxAbs * SQRT(result+correction)
93-
auto corrected{result.Add(correction_, rounding_)};
94-
overflow_ |= corrected.flags.test(RealFlag::Overflow);
95-
correction_ = Scalar<T>{};
96-
auto root{corrected.value.SQRT().value};
97+
// incoming result = SUM((data(:)/maxAbs)**2)
98+
// outgoing result = maxAbs * SQRT(result)
99+
auto root{result.SQRT().value};
97100
auto product{root.Multiply(maxAbs_.At(maxAbsAt_))};
98101
maxAbs_.IncrementSubscripts(maxAbsAt_);
99102
overflow_ |= product.flags.test(RealFlag::Overflow);

flang/lib/Evaluate/fold-reduction.h

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,23 @@ static Expr<T> FoldDotProduct(
4343
Expr<T> products{Fold(
4444
context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
4545
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
46-
Element correction{}; // Use Kahan summation for greater precision.
46+
[[maybe_unused]] Element correction{};
4747
const auto &rounding{context.targetCharacteristics().roundingMode()};
4848
for (const Element &x : cProducts.values()) {
49-
auto next{correction.Add(x, rounding)};
50-
overflow |= next.flags.test(RealFlag::Overflow);
51-
auto added{sum.Add(next.value, rounding)};
52-
overflow |= added.flags.test(RealFlag::Overflow);
53-
correction = added.value.Subtract(sum, rounding)
54-
.value.Subtract(next.value, rounding)
55-
.value;
56-
sum = std::move(added.value);
49+
if constexpr (useKahanSummation) {
50+
auto next{correction.Add(x, rounding)};
51+
overflow |= next.flags.test(RealFlag::Overflow);
52+
auto added{sum.Add(next.value, rounding)};
53+
overflow |= added.flags.test(RealFlag::Overflow);
54+
correction = added.value.Subtract(sum, rounding)
55+
.value.Subtract(next.value, rounding)
56+
.value;
57+
sum = std::move(added.value);
58+
} else {
59+
auto added{sum.Add(x, rounding)};
60+
overflow |= added.flags.test(RealFlag::Overflow);
61+
sum = std::move(added.value);
62+
}
5763
}
5864
} else if constexpr (T::category == TypeCategory::Logical) {
5965
Expr<T> conjunctions{Fold(context,
@@ -80,17 +86,23 @@ static Expr<T> FoldDotProduct(
8086
Expr<T> products{
8187
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
8288
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
83-
Element correction{}; // Use Kahan summation for greater precision.
89+
[[maybe_unused]] Element correction{};
8490
const auto &rounding{context.targetCharacteristics().roundingMode()};
8591
for (const Element &x : cProducts.values()) {
86-
auto next{correction.Add(x, rounding)};
87-
overflow |= next.flags.test(RealFlag::Overflow);
88-
auto added{sum.Add(next.value, rounding)};
89-
overflow |= added.flags.test(RealFlag::Overflow);
90-
correction = added.value.Subtract(sum, rounding)
91-
.value.Subtract(next.value, rounding)
92-
.value;
93-
sum = std::move(added.value);
92+
if constexpr (useKahanSummation) {
93+
auto next{correction.Add(x, rounding)};
94+
overflow |= next.flags.test(RealFlag::Overflow);
95+
auto added{sum.Add(next.value, rounding)};
96+
overflow |= added.flags.test(RealFlag::Overflow);
97+
correction = added.value.Subtract(sum, rounding)
98+
.value.Subtract(next.value, rounding)
99+
.value;
100+
sum = std::move(added.value);
101+
} else {
102+
auto added{sum.Add(x, rounding)};
103+
overflow |= added.flags.test(RealFlag::Overflow);
104+
sum = std::move(added.value);
105+
}
94106
}
95107
}
96108
if (overflow) {

0 commit comments

Comments
 (0)