Skip to content

Commit ce7700e

Browse files
authored
[flang][runtime] Address PRODUCT numeric discrepancy, folding vs runtime (#90125)
Ensure that the runtime implementations of floating-point reductions use intermediate results of the same precision as the operands, so that results match those from constant folding. (SUM reduction uses Kahan summation in both cases.)
1 parent 1e82d50 commit ce7700e

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

flang/runtime/product.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ CppTypeFor<TypeCategory::Integer, 16> RTDEF(ProductInteger16)(
107107
CppTypeFor<TypeCategory::Real, 4> RTDEF(ProductReal4)(const Descriptor &x,
108108
const char *source, int line, int dim, const Descriptor *mask) {
109109
return GetTotalReduction<TypeCategory::Real, 4>(x, source, line, dim, mask,
110-
NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
110+
NonComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 4>>{x},
111111
"PRODUCT");
112112
}
113113
CppTypeFor<TypeCategory::Real, 8> RTDEF(ProductReal8)(const Descriptor &x,
@@ -137,7 +137,7 @@ void RTDEF(CppProductComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
137137
const Descriptor &x, const char *source, int line, int dim,
138138
const Descriptor *mask) {
139139
result = GetTotalReduction<TypeCategory::Complex, 4>(x, source, line, dim,
140-
mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 8>>{x},
140+
mask, ComplexProductAccumulator<CppTypeFor<TypeCategory::Real, 4>>{x},
141141
"PRODUCT");
142142
}
143143
void RTDEF(CppProductComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
@@ -169,8 +169,8 @@ void RTDEF(CppProductComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
169169
void RTDEF(ProductDim)(Descriptor &result, const Descriptor &x, int dim,
170170
const char *source, int line, const Descriptor *mask) {
171171
TypedPartialNumericReduction<NonComplexProductAccumulator,
172-
NonComplexProductAccumulator, ComplexProductAccumulator>(
173-
result, x, dim, source, line, mask, "PRODUCT");
172+
NonComplexProductAccumulator, ComplexProductAccumulator,
173+
/*MIN_REAL_KIND=*/4>(result, x, dim, source, line, mask, "PRODUCT");
174174
}
175175

176176
RT_EXT_API_GROUP_END

flang/runtime/reduction-templates.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -240,11 +240,10 @@ inline RT_API_ATTRS void PartialIntegerReduction(Descriptor &result,
240240
kind, terminator, result, x, dim, mask, terminator, intrinsic);
241241
}
242242

243-
template <TypeCategory CAT, template <typename> class ACCUM>
243+
template <TypeCategory CAT, template <typename> class ACCUM, int MIN_KIND>
244244
struct PartialFloatingReductionHelper {
245245
template <int KIND> struct Functor {
246-
static constexpr int Intermediate{
247-
std::max(KIND, 8)}; // use at least "double" for intermediate results
246+
static constexpr int Intermediate{std::max(KIND, MIN_KIND)};
248247
RT_API_ATTRS void operator()(Descriptor &result, const Descriptor &x,
249248
int dim, const Descriptor *mask, Terminator &terminator,
250249
const char *intrinsic) const {
@@ -260,7 +259,7 @@ struct PartialFloatingReductionHelper {
260259

261260
template <template <typename> class INTEGER_ACCUM,
262261
template <typename> class REAL_ACCUM,
263-
template <typename> class COMPLEX_ACCUM>
262+
template <typename> class COMPLEX_ACCUM, int MIN_REAL_KIND>
264263
inline RT_API_ATTRS void TypedPartialNumericReduction(Descriptor &result,
265264
const Descriptor &x, int dim, const char *source, int line,
266265
const Descriptor *mask, const char *intrinsic) {
@@ -274,13 +273,13 @@ inline RT_API_ATTRS void TypedPartialNumericReduction(Descriptor &result,
274273
break;
275274
case TypeCategory::Real:
276275
ApplyFloatingPointKind<PartialFloatingReductionHelper<TypeCategory::Real,
277-
REAL_ACCUM>::template Functor,
276+
REAL_ACCUM, MIN_REAL_KIND>::template Functor,
278277
void>(catKind->second, terminator, result, x, dim, mask, terminator,
279278
intrinsic);
280279
break;
281280
case TypeCategory::Complex:
282281
ApplyFloatingPointKind<PartialFloatingReductionHelper<TypeCategory::Complex,
283-
COMPLEX_ACCUM>::template Functor,
282+
COMPLEX_ACCUM, MIN_REAL_KIND>::template Functor,
284283
void>(catKind->second, terminator, result, x, dim, mask, terminator,
285284
intrinsic);
286285
break;

flang/runtime/sum.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ CppTypeFor<TypeCategory::Integer, 16> RTDEF(SumInteger16)(const Descriptor &x,
134134
CppTypeFor<TypeCategory::Real, 4> RTDEF(SumReal4)(const Descriptor &x,
135135
const char *source, int line, int dim, const Descriptor *mask) {
136136
return GetTotalReduction<TypeCategory::Real, 4>(
137-
x, source, line, dim, mask, RealSumAccumulator<double>{x}, "SUM");
137+
x, source, line, dim, mask, RealSumAccumulator<float>{x}, "SUM");
138138
}
139139
CppTypeFor<TypeCategory::Real, 8> RTDEF(SumReal8)(const Descriptor &x,
140140
const char *source, int line, int dim, const Descriptor *mask) {
@@ -160,7 +160,7 @@ void RTDEF(CppSumComplex4)(CppTypeFor<TypeCategory::Complex, 4> &result,
160160
const Descriptor &x, const char *source, int line, int dim,
161161
const Descriptor *mask) {
162162
result = GetTotalReduction<TypeCategory::Complex, 4>(
163-
x, source, line, dim, mask, ComplexSumAccumulator<double>{x}, "SUM");
163+
x, source, line, dim, mask, ComplexSumAccumulator<float>{x}, "SUM");
164164
}
165165
void RTDEF(CppSumComplex8)(CppTypeFor<TypeCategory::Complex, 8> &result,
166166
const Descriptor &x, const char *source, int line, int dim,
@@ -188,7 +188,8 @@ void RTDEF(CppSumComplex16)(CppTypeFor<TypeCategory::Complex, 16> &result,
188188
void RTDEF(SumDim)(Descriptor &result, const Descriptor &x, int dim,
189189
const char *source, int line, const Descriptor *mask) {
190190
TypedPartialNumericReduction<IntegerSumAccumulator, RealSumAccumulator,
191-
ComplexSumAccumulator>(result, x, dim, source, line, mask, "SUM");
191+
ComplexSumAccumulator, /*MIN_REAL_KIND=*/4>(
192+
result, x, dim, source, line, mask, "SUM");
192193
}
193194

194195
RT_EXT_API_GROUP_END

0 commit comments

Comments
 (0)