Skip to content

[flang][runtime] Address PRODUCT numeric discrepancy, folding vs runtime #90125

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions flang/runtime/product.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ CppTypeFor<TypeCategory::Integer, 16> RTDEF(ProductInteger16)(
CppTypeFor<TypeCategory::Real, 4> RTDEF(ProductReal4)(const Descriptor &x,
const char *source, int line, int dim, const Descriptor *mask) {
return GetTotalReduction<TypeCategory::Real, 4>(x, source, line, dim, mask,
NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 4>>{x},
"PRODUCT");
}
CppTypeFor<TypeCategory::Real, 8> RTDEF(ProductReal8)(const Descriptor &x,
Expand Down Expand Up @@ -137,7 +137,7 @@ void RTDEF(CppProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
const Descriptor &x, const char *source, int line, int dim,
const Descriptor *mask) {
result = GetTotalReduction<TypeCategory::Complex, 4>(x, source, line, dim,
mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 4>>{x},
"PRODUCT");
}
void RTDEF(CppProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
Expand Down Expand Up @@ -169,8 +169,8 @@ void RTDEF(CppProductComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
void RTDEF(ProductDim)(Descriptor &result, const Descriptor &x, int dim,
const char *source, int line, const Descriptor *mask) {
TypedPartialNumericReduction<NonComplexProductAccumulator,
NonComplexProductAccumulator, ComplexProductAccumulator>(
result, x, dim, source, line, mask, "PRODUCT");
NonComplexProductAccumulator, ComplexProductAccumulator,
/*MIN_REAL_KIND=*/4>(result, x, dim, source, line, mask, "PRODUCT");
}

RT_EXT_API_GROUP_END
Expand Down
11 changes: 5 additions & 6 deletions flang/runtime/reduction-templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -240,11 +240,10 @@ inline RT_API_ATTRS void PartialIntegerReduction(Descriptor &result,
kind, terminator, result, x, dim, mask, terminator, intrinsic);
}

template <TypeCategory CAT, template <typename> class ACCUM>
template <TypeCategory CAT, template <typename> class ACCUM, int MIN_KIND>
struct PartialFloatingReductionHelper {
template <int KIND> struct Functor {
static constexpr int Intermediate{
std::max(KIND, 8)}; // use at least "double" for intermediate results
static constexpr int Intermediate{std::max(KIND, MIN_KIND)};
RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x,
int dim, const Descriptor *mask, Terminator &terminator,
const char *intrinsic) const {
Expand All @@ -260,7 +259,7 @@ struct PartialFloatingReductionHelper {

template <template <typename> class INTEGER_ACCUM,
template <typename> class REAL_ACCUM,
template <typename> class COMPLEX_ACCUM>
template <typename> class COMPLEX_ACCUM, int MIN_REAL_KIND>
inline RT_API_ATTRS void TypedPartialNumericReduction(Descriptor &result,
const Descriptor &x, int dim, const char *source, int line,
const Descriptor *mask, const char *intrinsic) {
Expand All @@ -274,13 +273,13 @@ inline RT_API_ATTRS void TypedPartialNumericReduction(Descriptor &result,
break;
case TypeCategory::Real:
ApplyFloatingPointKind<PartialFloatingReductionHelper<TypeCategory::Real,
REAL_ACCUM>::template Functor,
REAL_ACCUM, MIN_REAL_KIND>::template Functor,
void>(catKind->second, terminator, result, x, dim, mask, terminator,
intrinsic);
break;
case TypeCategory::Complex:
ApplyFloatingPointKind<PartialFloatingReductionHelper<TypeCategory::Complex,
COMPLEX_ACCUM>::template Functor,
COMPLEX_ACCUM, MIN_REAL_KIND>::template Functor,
void>(catKind->second, terminator, result, x, dim, mask, terminator,
intrinsic);
break;
Expand Down
7 changes: 4 additions & 3 deletions flang/runtime/sum.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ CppTypeFor<TypeCategory::Integer, 16> RTDEF(SumInteger16)(const Descriptor &x,
CppTypeFor<TypeCategory::Real, 4> RTDEF(SumReal4)(const Descriptor &x,
const char *source, int line, int dim, const Descriptor *mask) {
return GetTotalReduction<TypeCategory::Real, 4>(
x, source, line, dim, mask, RealSumAccumulator<double>{x}, "SUM");
x, source, line, dim, mask, RealSumAccumulator<float>{x}, "SUM");
}
CppTypeFor<TypeCategory::Real, 8> RTDEF(SumReal8)(const Descriptor &x,
const char *source, int line, int dim, const Descriptor *mask) {
Expand All @@ -160,7 +160,7 @@ void RTDEF(CppSumComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
const Descriptor &x, const char *source, int line, int dim,
const Descriptor *mask) {
result = GetTotalReduction<TypeCategory::Complex, 4>(
x, source, line, dim, mask, ComplexSumAccumulator<double>{x}, "SUM");
x, source, line, dim, mask, ComplexSumAccumulator<float>{x}, "SUM");
}
void RTDEF(CppSumComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
const Descriptor &x, const char *source, int line, int dim,
Expand Down Expand Up @@ -188,7 +188,8 @@ void RTDEF(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
void RTDEF(SumDim)(Descriptor &result, const Descriptor &x, int dim,
const char *source, int line, const Descriptor *mask) {
TypedPartialNumericReduction<IntegerSumAccumulator, RealSumAccumulator,
ComplexSumAccumulator>(result, x, dim, source, line, mask, "SUM");
ComplexSumAccumulator, /*MIN_REAL_KIND=*/4>(
result, x, dim, source, line, mask, "SUM");
}

RT_EXT_API_GROUP_END
Expand Down