Skip to content

Commit a00946f

Browse files
authored
[flang] Simplify hlfir.sum total reductions. (#119482)
I am trying to switch to keeping the reduction value in a temporary scalar location so that I can use hlfir::genLoopNest easily. This also allows using omp.loop_nest with worksharing for OpenMP.
1 parent af5d3af commit a00946f

File tree

4 files changed

+323
-182
lines changed

4 files changed

+323
-182
lines changed

flang/include/flang/Optimizer/Builder/HLFIRTools.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,10 @@ struct LoopNest {
366366
/// Generate a fir.do_loop nest looping from 1 to extents[i].
367367
/// \p isUnordered specifies whether the loops in the loop nest
368368
/// are unordered.
369+
///
370+
/// NOTE: genLoopNestWithReductions() should be used in favor
371+
/// of this method, though, it cannot generate OpenMP workshare
372+
/// loop constructs currently.
369373
LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
370374
mlir::ValueRange extents, bool isUnordered = false,
371375
bool emitWorkshareLoop = false);
@@ -376,6 +380,37 @@ inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
376380
isUnordered, emitWorkshareLoop);
377381
}
378382

383+
/// The type of a callback that generates the body of a reduction
384+
/// loop nest. It takes a location and a builder, as usual.
385+
/// In addition, the first set of values are the values of the loops'
386+
/// induction variables. The second set of values are the values
387+
/// of the reductions on entry to the innermost loop.
388+
/// The callback must return the updated values of the reductions.
389+
using ReductionLoopBodyGenerator = std::function<llvm::SmallVector<mlir::Value>(
390+
mlir::Location, fir::FirOpBuilder &, mlir::ValueRange, mlir::ValueRange)>;
391+
392+
/// Generate a loop nest loopong from 1 to \p extents[i] and reducing
393+
/// a set of values.
394+
/// \p isUnordered specifies whether the loops in the loop nest
395+
/// are unordered.
396+
/// \p reductionInits are the initial values of the reductions
397+
/// on entry to the outermost loop.
398+
/// \p genBody callback is repsonsible for generating the code
399+
/// that updates the reduction values in the innermost loop.
400+
///
401+
/// NOTE: the implementation of this function may decide
402+
/// to perform the reductions on SSA or in memory.
403+
/// In the latter case, this function is responsible for
404+
/// allocating/loading/storing the reduction variables,
405+
/// and making sure they have proper data sharing attributes
406+
/// in case any parallel constructs are present around the point
407+
/// of the loop nest insertion, or if the function decides
408+
/// to use any worksharing loop constructs for the loop nest.
409+
llvm::SmallVector<mlir::Value> genLoopNestWithReductions(
410+
mlir::Location loc, fir::FirOpBuilder &builder, mlir::ValueRange extents,
411+
mlir::ValueRange reductionInits, const ReductionLoopBodyGenerator &genBody,
412+
bool isUnordered = false);
413+
379414
/// Inline the body of an hlfir.elemental at the current insertion point
380415
/// given a list of one based indices. This generates the computation
381416
/// of one element of the elemental expression. Return the YieldElementOp

flang/lib/Optimizer/Builder/HLFIRTools.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,56 @@ hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
910910
return loopNest;
911911
}
912912

913+
llvm::SmallVector<mlir::Value> hlfir::genLoopNestWithReductions(
914+
mlir::Location loc, fir::FirOpBuilder &builder, mlir::ValueRange extents,
915+
mlir::ValueRange reductionInits, const ReductionLoopBodyGenerator &genBody,
916+
bool isUnordered) {
917+
assert(!extents.empty() && "must have at least one extent");
918+
// Build loop nest from column to row.
919+
auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
920+
mlir::Type indexType = builder.getIndexType();
921+
unsigned dim = extents.size() - 1;
922+
fir::DoLoopOp outerLoop = nullptr;
923+
fir::DoLoopOp parentLoop = nullptr;
924+
llvm::SmallVector<mlir::Value> oneBasedIndices;
925+
oneBasedIndices.resize(dim + 1);
926+
for (auto extent : llvm::reverse(extents)) {
927+
auto ub = builder.createConvert(loc, indexType, extent);
928+
929+
// The outermost loop takes reductionInits as the initial
930+
// values of its iter-args.
931+
// A child loop takes its iter-args from the region iter-args
932+
// of its parent loop.
933+
fir::DoLoopOp doLoop;
934+
if (!parentLoop) {
935+
doLoop = builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered,
936+
/*finalCountValue=*/false,
937+
reductionInits);
938+
} else {
939+
doLoop = builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered,
940+
/*finalCountValue=*/false,
941+
parentLoop.getRegionIterArgs());
942+
// Return the results of the child loop from its parent loop.
943+
builder.create<fir::ResultOp>(loc, doLoop.getResults());
944+
}
945+
946+
builder.setInsertionPointToStart(doLoop.getBody());
947+
// Reverse the indices so they are in column-major order.
948+
oneBasedIndices[dim--] = doLoop.getInductionVar();
949+
if (!outerLoop)
950+
outerLoop = doLoop;
951+
parentLoop = doLoop;
952+
}
953+
954+
llvm::SmallVector<mlir::Value> reductionValues;
955+
reductionValues =
956+
genBody(loc, builder, oneBasedIndices, parentLoop.getRegionIterArgs());
957+
builder.setInsertionPointToEnd(parentLoop.getBody());
958+
builder.create<fir::ResultOp>(loc, reductionValues);
959+
builder.setInsertionPointAfter(outerLoop);
960+
return outerLoop->getResults();
961+
}
962+
913963
static fir::ExtendedValue translateVariableToExtendedValue(
914964
mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity variable,
915965
bool forceHlfirBase = false, bool contiguousHint = false) {

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 128 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -106,34 +106,43 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
106106
mlir::PatternRewriter &rewriter) const override {
107107
mlir::Location loc = sum.getLoc();
108108
fir::FirOpBuilder builder{rewriter, sum.getOperation()};
109-
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
110-
assert(expr && "expected an expression type for the result of hlfir.sum");
111-
mlir::Type elementType = expr.getElementType();
109+
mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
112110
hlfir::Entity array = hlfir::Entity{sum.getArray()};
113111
mlir::Value mask = sum.getMask();
114112
mlir::Value dim = sum.getDim();
115-
int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
113+
bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
114+
int64_t dimVal =
115+
isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
116116
mlir::Value resultShape, dimExtent;
117-
std::tie(resultShape, dimExtent) =
118-
genResultShape(loc, builder, array, dimVal);
117+
llvm::SmallVector<mlir::Value> arrayExtents;
118+
if (isTotalReduction)
119+
arrayExtents = genArrayExtents(loc, builder, array);
120+
else
121+
std::tie(resultShape, dimExtent) =
122+
genResultShapeForPartialReduction(loc, builder, array, dimVal);
123+
124+
// If the mask is present and is a scalar, then we'd better load its value
125+
// outside of the reduction loop making the loop unswitching easier.
126+
mlir::Value isPresentPred, maskValue;
127+
if (mask) {
128+
if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
129+
// MASK represented by a box might be dynamically optional,
130+
// so we have to check for its presence before accessing it.
131+
isPresentPred =
132+
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
133+
}
134+
135+
if (hlfir::Entity{mask}.isScalar())
136+
maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
137+
}
119138

120139
auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
121140
mlir::ValueRange inputIndices) -> hlfir::Entity {
122141
// Loop over all indices in the DIM dimension, and reduce all values.
123-
// We do not need to create the reduction loop always: if we can
124-
// slice the input array given the inputIndices, then we can
125-
// just apply a new SUM operation (total reduction) to the slice.
126-
// For the time being, generate the explicit loop because the slicing
127-
// requires generating an elemental operation for the input array
128-
// (and the mask, if present).
129-
// TODO: produce the slices and new SUM after adding a pattern
130-
// for expanding total reduction SUM case.
131-
mlir::Type indexType = builder.getIndexType();
132-
auto one = builder.createIntegerConstant(loc, indexType, 1);
133-
auto ub = builder.createConvert(loc, indexType, dimExtent);
142+
// If DIM is not present, do total reduction.
134143

135144
// Initial value for the reduction.
136-
mlir::Value initValue = genInitValue(loc, builder, elementType);
145+
mlir::Value reductionInitValue = genInitValue(loc, builder, elementType);
137146

138147
// The reduction loop may be unordered if FastMathFlags::reassoc
139148
// transformations are allowed. The integer reduction is always
@@ -142,79 +151,83 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
142151
static_cast<bool>(sum.getFastmath() &
143152
mlir::arith::FastMathFlags::reassoc);
144153

145-
// If the mask is present and is a scalar, then we'd better load its value
146-
// outside of the reduction loop making the loop unswitching easier.
147-
// Maybe it is worth hoisting it from the elemental operation as well.
148-
mlir::Value isPresentPred, maskValue;
149-
if (mask) {
150-
if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
151-
// MASK represented by a box might be dynamically optional,
152-
// so we have to check for its presence before accessing it.
153-
isPresentPred =
154-
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
154+
llvm::SmallVector<mlir::Value> extents;
155+
if (isTotalReduction)
156+
extents = arrayExtents;
157+
else
158+
extents.push_back(
159+
builder.createConvert(loc, builder.getIndexType(), dimExtent));
160+
161+
auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
162+
mlir::ValueRange oneBasedIndices,
163+
mlir::ValueRange reductionArgs)
164+
-> llvm::SmallVector<mlir::Value, 1> {
165+
// Generate the reduction loop-nest body.
166+
// The initial reduction value in the innermost loop
167+
// is passed via reductionArgs[0].
168+
llvm::SmallVector<mlir::Value> indices;
169+
if (isTotalReduction) {
170+
indices = oneBasedIndices;
171+
} else {
172+
indices = inputIndices;
173+
indices.insert(indices.begin() + dimVal - 1, oneBasedIndices[0]);
155174
}
156175

157-
if (hlfir::Entity{mask}.isScalar())
158-
maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
159-
}
176+
mlir::Value reductionValue = reductionArgs[0];
177+
fir::IfOp ifOp;
178+
if (mask) {
179+
// Make the reduction value update conditional on the value
180+
// of the mask.
181+
if (!maskValue) {
182+
// If the mask is an array, use the elemental and the loop indices
183+
// to address the proper mask element.
184+
maskValue =
185+
genMaskValue(loc, builder, mask, isPresentPred, indices);
186+
}
187+
mlir::Value isUnmasked = builder.create<fir::ConvertOp>(
188+
loc, builder.getI1Type(), maskValue);
189+
ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
190+
/*withElseRegion=*/true);
191+
// In the 'else' block return the current reduction value.
192+
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
193+
builder.create<fir::ResultOp>(loc, reductionValue);
194+
195+
// In the 'then' block do the actual addition.
196+
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
197+
}
160198

161-
// NOTE: the outer elemental operation may be lowered into
162-
// omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
163-
// loop may appear disjoint from the workshare loop nest.
164-
// Moreover, the inner loop is not strictly nested (due to the reduction
165-
// starting value initialization), and the above omp dialect operations
166-
// cannot produce results.
167-
// It is unclear what we should do about it yet.
168-
auto doLoop = builder.create<fir::DoLoopOp>(
169-
loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
170-
mlir::ValueRange{initValue});
171-
172-
// Address the input array using the reduction loop's IV
173-
// for the DIM dimension.
174-
mlir::Value iv = doLoop.getInductionVar();
175-
llvm::SmallVector<mlir::Value> indices{inputIndices};
176-
indices.insert(indices.begin() + dimVal - 1, iv);
177-
178-
mlir::OpBuilder::InsertionGuard guard(builder);
179-
builder.setInsertionPointToStart(doLoop.getBody());
180-
mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
181-
fir::IfOp ifOp;
182-
if (mask) {
183-
// Make the reduction value update conditional on the value
184-
// of the mask.
185-
if (!maskValue) {
186-
// If the mask is an array, use the elemental and the loop indices
187-
// to address the proper mask element.
188-
maskValue = genMaskValue(loc, builder, mask, isPresentPred, indices);
199+
hlfir::Entity element =
200+
hlfir::getElementAt(loc, builder, array, indices);
201+
hlfir::Entity elementValue =
202+
hlfir::loadTrivialScalar(loc, builder, element);
203+
// NOTE: we can use "Kahan summation" same way as the runtime
204+
// (e.g. when fast-math is not allowed), but let's start with
205+
// the simple version.
206+
reductionValue =
207+
genScalarAdd(loc, builder, reductionValue, elementValue);
208+
209+
if (ifOp) {
210+
builder.create<fir::ResultOp>(loc, reductionValue);
211+
builder.setInsertionPointAfter(ifOp);
212+
reductionValue = ifOp.getResult(0);
189213
}
190-
mlir::Value isUnmasked =
191-
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
192-
ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
193-
/*withElseRegion=*/true);
194-
// In the 'else' block return the current reduction value.
195-
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
196-
builder.create<fir::ResultOp>(loc, reductionValue);
197-
198-
// In the 'then' block do the actual addition.
199-
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
200-
}
201214

202-
hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
203-
hlfir::Entity elementValue =
204-
hlfir::loadTrivialScalar(loc, builder, element);
205-
// NOTE: we can use "Kahan summation" same way as the runtime
206-
// (e.g. when fast-math is not allowed), but let's start with
207-
// the simple version.
208-
reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
209-
builder.create<fir::ResultOp>(loc, reductionValue);
210-
211-
if (ifOp) {
212-
builder.setInsertionPointAfter(ifOp);
213-
builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
214-
}
215+
return {reductionValue};
216+
};
215217

216-
return hlfir::Entity{doLoop.getResult(0)};
218+
llvm::SmallVector<mlir::Value, 1> reductionFinalValues =
219+
hlfir::genLoopNestWithReductions(loc, builder, extents,
220+
{reductionInitValue}, genBody,
221+
isUnordered);
222+
return hlfir::Entity{reductionFinalValues[0]};
217223
};
224+
225+
if (isTotalReduction) {
226+
hlfir::Entity result = genKernel(loc, builder, mlir::ValueRange{});
227+
rewriter.replaceOp(sum, result);
228+
return mlir::success();
229+
}
230+
218231
hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
219232
loc, builder, elementType, resultShape, {}, genKernel,
220233
/*isUnordered=*/true, /*polymorphicMold=*/nullptr,
@@ -230,20 +243,29 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
230243
}
231244

232245
private:
246+
static llvm::SmallVector<mlir::Value>
247+
genArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder,
248+
hlfir::Entity array) {
249+
mlir::Value inShape = hlfir::genShape(loc, builder, array);
250+
llvm::SmallVector<mlir::Value> inExtents =
251+
hlfir::getExplicitExtentsFromShape(inShape, builder);
252+
if (inShape.getUses().empty())
253+
inShape.getDefiningOp()->erase();
254+
return inExtents;
255+
}
256+
233257
// Return fir.shape specifying the shape of the result
234258
// of a SUM reduction with DIM=dimVal. The second return value
235259
// is the extent of the DIM dimension.
236260
static std::tuple<mlir::Value, mlir::Value>
237-
genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
238-
hlfir::Entity array, int64_t dimVal) {
239-
mlir::Value inShape = hlfir::genShape(loc, builder, array);
261+
genResultShapeForPartialReduction(mlir::Location loc,
262+
fir::FirOpBuilder &builder,
263+
hlfir::Entity array, int64_t dimVal) {
240264
llvm::SmallVector<mlir::Value> inExtents =
241-
hlfir::getExplicitExtentsFromShape(inShape, builder);
265+
genArrayExtents(loc, builder, array);
242266
assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
243267
"DIM must be present and a positive constant not exceeding "
244268
"the array's rank");
245-
if (inShape.getUses().empty())
246-
inShape.getDefiningOp()->erase();
247269

248270
mlir::Value dimExtent = inExtents[dimVal - 1];
249271
inExtents.erase(inExtents.begin() + dimVal - 1);
@@ -459,22 +481,22 @@ class SimplifyHLFIRIntrinsics
459481
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
460482
if (!simplifySum)
461483
return true;
462-
if (mlir::Value dim = sum.getDim()) {
463-
if (auto dimVal = fir::getIntIfConstant(dim)) {
464-
if (!fir::isa_trivial(sum.getType())) {
465-
// Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
466-
// It is only legal when X is 1, and it should probably be
467-
// canonicalized into SUM(a).
468-
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
469-
hlfir::getFortranElementOrSequenceType(
470-
sum.getArray().getType()));
471-
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
472-
// Ignore SUMs with illegal DIM values.
473-
// They may appear in dead code,
474-
// and they do not have to be converted.
475-
return false;
476-
}
477-
}
484+
485+
// Always inline total reductions.
486+
if (hlfir::Entity{sum}.getRank() == 0)
487+
return false;
488+
mlir::Value dim = sum.getDim();
489+
if (!dim)
490+
return false;
491+
492+
if (auto dimVal = fir::getIntIfConstant(dim)) {
493+
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
494+
hlfir::getFortranElementOrSequenceType(sum.getArray().getType()));
495+
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
496+
// Ignore SUMs with illegal DIM values.
497+
// They may appear in dead code,
498+
// and they do not have to be converted.
499+
return false;
478500
}
479501
}
480502
return true;

0 commit comments

Comments
 (0)