Skip to content

[flang] Simplify hlfir.sum total reductions. #119482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions flang/include/flang/Optimizer/Builder/HLFIRTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,6 +366,10 @@ struct LoopNest {
/// Generate a fir.do_loop nest looping from 1 to extents[i].
/// \p isUnordered specifies whether the loops in the loop nest
/// are unordered.
///
/// NOTE: prefer genLoopNestWithReductions() over this function
/// where possible. Be aware, however, that genLoopNestWithReductions()
/// currently cannot generate OpenMP workshare loop constructs.
LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange extents, bool isUnordered = false,
bool emitWorkshareLoop = false);
Expand All @@ -376,6 +380,37 @@ inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
isUnordered, emitWorkshareLoop);
}

/// The type of a callback that generates the body of a reduction
/// loop nest. It takes a location and a builder, as usual.
/// The first set of values holds the induction variables of the loops,
/// ordered so that the i-th value belongs to the loop iterating
/// over the i-th extent. The second set of values holds the values
/// of the reductions on entry to the innermost loop.
/// The callback must return the updated values of the reductions,
/// which are then yielded from the innermost loop.
using ReductionLoopBodyGenerator = std::function<llvm::SmallVector<mlir::Value>(
mlir::Location, fir::FirOpBuilder &, mlir::ValueRange, mlir::ValueRange)>;

/// Generate a loop nest looping from 1 to \p extents[i] and reducing
/// a set of values.
/// \p extents must contain at least one extent.
/// \p isUnordered specifies whether the loops in the loop nest
/// are unordered.
/// \p reductionInits are the initial values of the reductions
/// on entry to the outermost loop.
/// \p genBody callback is responsible for generating the code
/// that updates the reduction values in the innermost loop.
/// \return the final values of the reductions, i.e. the results
/// of the outermost loop.
///
/// NOTE: the implementation of this function may decide
/// to perform the reductions on SSA or in memory.
/// In the latter case, this function is responsible for
/// allocating/loading/storing the reduction variables,
/// and making sure they have proper data sharing attributes
/// in case any parallel constructs are present around the point
/// of the loop nest insertion, or if the function decides
/// to use any worksharing loop constructs for the loop nest.
llvm::SmallVector<mlir::Value> genLoopNestWithReductions(
mlir::Location loc, fir::FirOpBuilder &builder, mlir::ValueRange extents,
mlir::ValueRange reductionInits, const ReductionLoopBodyGenerator &genBody,
bool isUnordered = false);

/// Inline the body of an hlfir.elemental at the current insertion point
/// given a list of one based indices. This generates the computation
/// of one element of the elemental expression. Return the YieldElementOp
Expand Down
50 changes: 50 additions & 0 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -910,6 +910,56 @@ hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
return loopNest;
}

llvm::SmallVector<mlir::Value> hlfir::genLoopNestWithReductions(
    mlir::Location loc, fir::FirOpBuilder &builder, mlir::ValueRange extents,
    mlir::ValueRange reductionInits, const ReductionLoopBodyGenerator &genBody,
    bool isUnordered) {
  assert(!extents.empty() && "must have at least one extent");
  // Every loop of the nest starts at 1 and advances by 1, so a single
  // constant serves as both the lower bound and the step.
  auto unit = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
  mlir::Type idxTy = builder.getIndexType();
  const unsigned rank = extents.size();
  // inductionVars[i] is the induction variable of the loop that iterates
  // over extents[i].
  llvm::SmallVector<mlir::Value> inductionVars(rank);
  fir::DoLoopOp outermost = nullptr;
  fir::DoLoopOp innermost = nullptr;

  // Walk the extents from the last to the first: the loop over the last
  // extent becomes the outermost loop, and the loop over the first extent
  // becomes the innermost one.
  for (unsigned dim = rank; dim-- > 0;) {
    auto upperBound = builder.createConvert(loc, idxTy, extents[dim]);

    fir::DoLoopOp loop;
    if (innermost) {
      // A nested loop is seeded with the region iter-args of the loop
      // enclosing it, and its results are immediately yielded from
      // that enclosing loop.
      loop = builder.create<fir::DoLoopOp>(loc, unit, upperBound, unit,
                                           isUnordered,
                                           /*finalCountValue=*/false,
                                           innermost.getRegionIterArgs());
      builder.create<fir::ResultOp>(loc, loop.getResults());
    } else {
      // The outermost loop is seeded with the caller-provided
      // initial reduction values.
      loop = builder.create<fir::DoLoopOp>(loc, unit, upperBound, unit,
                                           isUnordered,
                                           /*finalCountValue=*/false,
                                           reductionInits);
      outermost = loop;
    }

    // Whatever is generated next goes into the body of the new loop.
    builder.setInsertionPointToStart(loop.getBody());
    inductionVars[dim] = loop.getInductionVar();
    innermost = loop;
  }

  // Let the callback emit the innermost body, then yield the updated
  // reduction values it returns from the innermost loop.
  llvm::SmallVector<mlir::Value> updatedReductions =
      genBody(loc, builder, inductionVars, innermost.getRegionIterArgs());
  builder.setInsertionPointToEnd(innermost.getBody());
  builder.create<fir::ResultOp>(loc, updatedReductions);

  // Leave the builder positioned right after the whole nest; the results
  // of the outermost loop are the final reduction values.
  builder.setInsertionPointAfter(outermost);
  return outermost->getResults();
}

static fir::ExtendedValue translateVariableToExtendedValue(
mlir::Location loc, fir::FirOpBuilder &builder, hlfir::Entity variable,
bool forceHlfirBase = false, bool contiguousHint = false) {
Expand Down
234 changes: 128 additions & 106 deletions flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,34 +105,43 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
mlir::PatternRewriter &rewriter) const override {
mlir::Location loc = sum.getLoc();
fir::FirOpBuilder builder{rewriter, sum.getOperation()};
hlfir::ExprType expr = mlir::dyn_cast<hlfir::ExprType>(sum.getType());
assert(expr && "expected an expression type for the result of hlfir.sum");
mlir::Type elementType = expr.getElementType();
mlir::Type elementType = hlfir::getFortranElementType(sum.getType());
hlfir::Entity array = hlfir::Entity{sum.getArray()};
mlir::Value mask = sum.getMask();
mlir::Value dim = sum.getDim();
int64_t dimVal = fir::getIntIfConstant(dim).value_or(0);
bool isTotalReduction = hlfir::Entity{sum}.getRank() == 0;
int64_t dimVal =
isTotalReduction ? 0 : fir::getIntIfConstant(dim).value_or(0);
mlir::Value resultShape, dimExtent;
std::tie(resultShape, dimExtent) =
genResultShape(loc, builder, array, dimVal);
llvm::SmallVector<mlir::Value> arrayExtents;
if (isTotalReduction)
arrayExtents = genArrayExtents(loc, builder, array);
else
std::tie(resultShape, dimExtent) =
genResultShapeForPartialReduction(loc, builder, array, dimVal);

// If the mask is present and is a scalar, then we'd better load its value
// outside of the reduction loop making the loop unswitching easier.
mlir::Value isPresentPred, maskValue;
if (mask) {
if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
// MASK represented by a box might be dynamically optional,
// so we have to check for its presence before accessing it.
isPresentPred =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
}

if (hlfir::Entity{mask}.isScalar())
maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
}

auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange inputIndices) -> hlfir::Entity {
// Loop over all indices in the DIM dimension, and reduce all values.
// We do not need to create the reduction loop always: if we can
// slice the input array given the inputIndices, then we can
// just apply a new SUM operation (total reduction) to the slice.
// For the time being, generate the explicit loop because the slicing
// requires generating an elemental operation for the input array
// (and the mask, if present).
// TODO: produce the slices and new SUM after adding a pattern
// for expanding total reduction SUM case.
mlir::Type indexType = builder.getIndexType();
auto one = builder.createIntegerConstant(loc, indexType, 1);
auto ub = builder.createConvert(loc, indexType, dimExtent);
// If DIM is not present, do total reduction.

// Initial value for the reduction.
mlir::Value initValue = genInitValue(loc, builder, elementType);
mlir::Value reductionInitValue = genInitValue(loc, builder, elementType);

// The reduction loop may be unordered if FastMathFlags::reassoc
// transformations are allowed. The integer reduction is always
Expand All @@ -141,79 +150,83 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
static_cast<bool>(sum.getFastmath() &
mlir::arith::FastMathFlags::reassoc);

// If the mask is present and is a scalar, then we'd better load its value
// outside of the reduction loop making the loop unswitching easier.
// Maybe it is worth hoisting it from the elemental operation as well.
mlir::Value isPresentPred, maskValue;
if (mask) {
if (mlir::isa<fir::BaseBoxType>(mask.getType())) {
// MASK represented by a box might be dynamically optional,
// so we have to check for its presence before accessing it.
isPresentPred =
builder.create<fir::IsPresentOp>(loc, builder.getI1Type(), mask);
llvm::SmallVector<mlir::Value> extents;
if (isTotalReduction)
extents = arrayExtents;
else
extents.push_back(
builder.createConvert(loc, builder.getIndexType(), dimExtent));

auto genBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange oneBasedIndices,
mlir::ValueRange reductionArgs)
-> llvm::SmallVector<mlir::Value, 1> {
// Generate the reduction loop-nest body.
// The initial reduction value in the innermost loop
// is passed via reductionArgs[0].
llvm::SmallVector<mlir::Value> indices;
if (isTotalReduction) {
indices = oneBasedIndices;
} else {
indices = inputIndices;
indices.insert(indices.begin() + dimVal - 1, oneBasedIndices[0]);
}

if (hlfir::Entity{mask}.isScalar())
maskValue = genMaskValue(loc, builder, mask, isPresentPred, {});
}
mlir::Value reductionValue = reductionArgs[0];
fir::IfOp ifOp;
if (mask) {
// Make the reduction value update conditional on the value
// of the mask.
if (!maskValue) {
// If the mask is an array, use the elemental and the loop indices
// to address the proper mask element.
maskValue =
genMaskValue(loc, builder, mask, isPresentPred, indices);
}
mlir::Value isUnmasked = builder.create<fir::ConvertOp>(
loc, builder.getI1Type(), maskValue);
ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
/*withElseRegion=*/true);
// In the 'else' block return the current reduction value.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<fir::ResultOp>(loc, reductionValue);

// In the 'then' block do the actual addition.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
}

// NOTE: the outer elemental operation may be lowered into
// omp.workshare.loop_wrapper/omp.loop_nest later, so the reduction
// loop may appear disjoint from the workshare loop nest.
// Moreover, the inner loop is not strictly nested (due to the reduction
// starting value initialization), and the above omp dialect operations
// cannot produce results.
// It is unclear what we should do about it yet.
auto doLoop = builder.create<fir::DoLoopOp>(
loc, one, ub, one, isUnordered, /*finalCountValue=*/false,
mlir::ValueRange{initValue});

// Address the input array using the reduction loop's IV
// for the DIM dimension.
mlir::Value iv = doLoop.getInductionVar();
llvm::SmallVector<mlir::Value> indices{inputIndices};
indices.insert(indices.begin() + dimVal - 1, iv);

mlir::OpBuilder::InsertionGuard guard(builder);
builder.setInsertionPointToStart(doLoop.getBody());
mlir::Value reductionValue = doLoop.getRegionIterArgs()[0];
fir::IfOp ifOp;
if (mask) {
// Make the reduction value update conditional on the value
// of the mask.
if (!maskValue) {
// If the mask is an array, use the elemental and the loop indices
// to address the proper mask element.
maskValue = genMaskValue(loc, builder, mask, isPresentPred, indices);
hlfir::Entity element =
hlfir::getElementAt(loc, builder, array, indices);
hlfir::Entity elementValue =
hlfir::loadTrivialScalar(loc, builder, element);
// NOTE: we can use "Kahan summation" same way as the runtime
// (e.g. when fast-math is not allowed), but let's start with
// the simple version.
reductionValue =
genScalarAdd(loc, builder, reductionValue, elementValue);

if (ifOp) {
builder.create<fir::ResultOp>(loc, reductionValue);
builder.setInsertionPointAfter(ifOp);
reductionValue = ifOp.getResult(0);
}
mlir::Value isUnmasked =
builder.create<fir::ConvertOp>(loc, builder.getI1Type(), maskValue);
ifOp = builder.create<fir::IfOp>(loc, elementType, isUnmasked,
/*withElseRegion=*/true);
// In the 'else' block return the current reduction value.
builder.setInsertionPointToStart(&ifOp.getElseRegion().front());
builder.create<fir::ResultOp>(loc, reductionValue);

// In the 'then' block do the actual addition.
builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
}

hlfir::Entity element = hlfir::getElementAt(loc, builder, array, indices);
hlfir::Entity elementValue =
hlfir::loadTrivialScalar(loc, builder, element);
// NOTE: we can use "Kahan summation" same way as the runtime
// (e.g. when fast-math is not allowed), but let's start with
// the simple version.
reductionValue = genScalarAdd(loc, builder, reductionValue, elementValue);
builder.create<fir::ResultOp>(loc, reductionValue);

if (ifOp) {
builder.setInsertionPointAfter(ifOp);
builder.create<fir::ResultOp>(loc, ifOp.getResult(0));
}
return {reductionValue};
};

return hlfir::Entity{doLoop.getResult(0)};
llvm::SmallVector<mlir::Value, 1> reductionFinalValues =
hlfir::genLoopNestWithReductions(loc, builder, extents,
{reductionInitValue}, genBody,
isUnordered);
return hlfir::Entity{reductionFinalValues[0]};
};

if (isTotalReduction) {
hlfir::Entity result = genKernel(loc, builder, mlir::ValueRange{});
rewriter.replaceOp(sum, result);
return mlir::success();
}

hlfir::ElementalOp elementalOp = hlfir::genElementalOp(
loc, builder, elementType, resultShape, {}, genKernel,
/*isUnordered=*/true, /*polymorphicMold=*/nullptr,
Expand All @@ -229,20 +242,29 @@ class SumAsElementalConversion : public mlir::OpRewritePattern<hlfir::SumOp> {
}

private:
static llvm::SmallVector<mlir::Value>
genArrayExtents(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity array) {
mlir::Value inShape = hlfir::genShape(loc, builder, array);
llvm::SmallVector<mlir::Value> inExtents =
hlfir::getExplicitExtentsFromShape(inShape, builder);
if (inShape.getUses().empty())
inShape.getDefiningOp()->erase();
return inExtents;
}

// Return fir.shape specifying the shape of the result
// of a SUM reduction with DIM=dimVal. The second return value
// is the extent of the DIM dimension.
static std::tuple<mlir::Value, mlir::Value>
genResultShape(mlir::Location loc, fir::FirOpBuilder &builder,
hlfir::Entity array, int64_t dimVal) {
mlir::Value inShape = hlfir::genShape(loc, builder, array);
genResultShapeForPartialReduction(mlir::Location loc,
fir::FirOpBuilder &builder,
hlfir::Entity array, int64_t dimVal) {
llvm::SmallVector<mlir::Value> inExtents =
hlfir::getExplicitExtentsFromShape(inShape, builder);
genArrayExtents(loc, builder, array);
assert(dimVal > 0 && dimVal <= static_cast<int64_t>(inExtents.size()) &&
"DIM must be present and a positive constant not exceeding "
"the array's rank");
if (inShape.getUses().empty())
inShape.getDefiningOp()->erase();

mlir::Value dimExtent = inExtents[dimVal - 1];
inExtents.erase(inExtents.begin() + dimVal - 1);
Expand Down Expand Up @@ -355,22 +377,22 @@ class SimplifyHLFIRIntrinsics
target.addDynamicallyLegalOp<hlfir::SumOp>([](hlfir::SumOp sum) {
if (!simplifySum)
return true;
if (mlir::Value dim = sum.getDim()) {
if (auto dimVal = fir::getIntIfConstant(dim)) {
if (!fir::isa_trivial(sum.getType())) {
// Ignore the case SUM(a, DIM=X), where 'a' is a 1D array.
// It is only legal when X is 1, and it should probably be
// canonicalized into SUM(a).
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(
sum.getArray().getType()));
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
// Ignore SUMs with illegal DIM values.
// They may appear in dead code,
// and they do not have to be converted.
return false;
}
}

// Always inline total reductions.
if (hlfir::Entity{sum}.getRank() == 0)
return false;
mlir::Value dim = sum.getDim();
if (!dim)
return false;

if (auto dimVal = fir::getIntIfConstant(dim)) {
fir::SequenceType arrayTy = mlir::cast<fir::SequenceType>(
hlfir::getFortranElementOrSequenceType(sum.getArray().getType()));
if (*dimVal > 0 && *dimVal <= arrayTy.getDimension()) {
// Ignore SUMs with illegal DIM values.
// They may appear in dead code,
// and they do not have to be converted.
return false;
}
}
return true;
Expand Down
Loading
Loading