Skip to content

[flang] Adjust transformational folding to match runtime #90132

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions flang/lib/Evaluate/fold-implementation.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@

namespace Fortran::evaluate {

// Don't use Kahan extended precision summation any more when folding
// transformational intrinsic functions other than SUM, since it is
// not used in the runtime implementations of those functions and we
// want results to match.
static constexpr bool useKahanSummation{false};

// Utilities
template <typename T> class Folder {
public:
Expand Down
27 changes: 17 additions & 10 deletions flang/lib/Evaluate/fold-matmul.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,25 @@ static Expr<T> FoldMatmul(FoldingContext &context, FunctionRef<T> &&funcRef) {
Element bElt{mb->At(bAt)};
if constexpr (T::category == TypeCategory::Real ||
T::category == TypeCategory::Complex) {
// Kahan summation
auto product{aElt.Multiply(bElt, rounding)};
auto product{aElt.Multiply(bElt)};
overflow |= product.flags.test(RealFlag::Overflow);
auto next{correction.Add(product.value, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
if constexpr (useKahanSummation) {
auto next{correction.Add(product.value, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
} else {
auto added{sum.Add(product.value)};
overflow |= added.flags.test(RealFlag::Overflow);
sum = std::move(added.value);
}
} else if constexpr (T::category == TypeCategory::Integer) {
// Don't use Kahan summation in numeric MATMUL folding;
// the runtime doesn't use it, and results should match.
auto product{aElt.MultiplySigned(bElt)};
overflow |= product.SignedMultiplicationOverflowed();
auto added{sum.AddSigned(product.lower)};
Expand Down
33 changes: 18 additions & 15 deletions flang/lib/Evaluate/fold-real.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ template <int KIND> class Norm2Accumulator {
: array_{array}, maxAbs_{maxAbs}, rounding_{rounding} {};
void operator()(
Scalar<T> &element, const ConstantSubscripts &at, bool /*first*/) {
// Kahan summation of scaled elements:
// Summation of scaled elements:
// Naively,
// NORM2(A(:)) = SQRT(SUM(A(:)**2))
// For any T > 0, we have mathematically
Expand All @@ -76,24 +76,27 @@ template <int KIND> class Norm2Accumulator {
auto item{array_.At(at)};
auto scaled{item.Divide(scale).value};
auto square{scaled.Multiply(scaled).value};
auto next{square.Add(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
correction_ = sum.value.Subtract(element, rounding_)
.value.Subtract(next.value, rounding_)
.value;
element = sum.value;
if constexpr (useKahanSummation) {
auto next{square.Add(correction_, rounding_)};
overflow_ |= next.flags.test(RealFlag::Overflow);
auto sum{element.Add(next.value, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
correction_ = sum.value.Subtract(element, rounding_)
.value.Subtract(next.value, rounding_)
.value;
element = sum.value;
} else {
auto sum{element.Add(square, rounding_)};
overflow_ |= sum.flags.test(RealFlag::Overflow);
element = sum.value;
}
}
}
bool overflow() const { return overflow_; }
void Done(Scalar<T> &result) {
// result+correction == SUM((data(:)/maxAbs)**2)
// result = maxAbs * SQRT(result+correction)
auto corrected{result.Add(correction_, rounding_)};
overflow_ |= corrected.flags.test(RealFlag::Overflow);
correction_ = Scalar<T>{};
auto root{corrected.value.SQRT().value};
// incoming result = SUM((data(:)/maxAbs)**2)
// outgoing result = maxAbs * SQRT(result)
auto root{result.SQRT().value};
auto product{root.Multiply(maxAbs_.At(maxAbsAt_))};
maxAbs_.IncrementSubscripts(maxAbsAt_);
overflow_ |= product.flags.test(RealFlag::Overflow);
Expand Down
48 changes: 30 additions & 18 deletions flang/lib/Evaluate/fold-reduction.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,23 @@ static Expr<T> FoldDotProduct(
Expr<T> products{Fold(
context, Expr<T>{std::move(conjgA)} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
Element correction{}; // Use Kahan summation for greater precision.
[[maybe_unused]] Element correction{};
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
auto next{correction.Add(x, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
if constexpr (useKahanSummation) {
auto next{correction.Add(x, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
} else {
auto added{sum.Add(x, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
sum = std::move(added.value);
}
}
} else if constexpr (T::category == TypeCategory::Logical) {
Expr<T> conjunctions{Fold(context,
Expand All @@ -80,17 +86,23 @@ static Expr<T> FoldDotProduct(
Expr<T> products{
Fold(context, Expr<T>{Constant<T>{*va}} * Expr<T>{Constant<T>{*vb}})};
Constant<T> &cProducts{DEREF(UnwrapConstantValue<T>(products))};
Element correction{}; // Use Kahan summation for greater precision.
[[maybe_unused]] Element correction{};
const auto &rounding{context.targetCharacteristics().roundingMode()};
for (const Element &x : cProducts.values()) {
auto next{correction.Add(x, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
if constexpr (useKahanSummation) {
auto next{correction.Add(x, rounding)};
overflow |= next.flags.test(RealFlag::Overflow);
auto added{sum.Add(next.value, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
correction = added.value.Subtract(sum, rounding)
.value.Subtract(next.value, rounding)
.value;
sum = std::move(added.value);
} else {
auto added{sum.Add(x, rounding)};
overflow |= added.flags.test(RealFlag::Overflow);
sum = std::move(added.value);
}
}
}
if (overflow) {
Expand Down