Skip to content

[flang] Introduce custom loop nest generation for loops in workshare construct #101445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions flang/include/flang/Optimizer/Builder/HLFIRTools.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,20 +357,22 @@ hlfir::ElementalOp genElementalOp(

/// Structure to describe a loop nest.
struct LoopNest {
/// NOTE(review): this text is a rendered diff with the +/- markers stripped;
/// outerLoop and innerLoop look like the pre-change members that outerOp and
/// body replace -- confirm against the real header before relying on them.
fir::DoLoopOp outerLoop;
fir::DoLoopOp innerLoop;
/// Outermost operation of the nest. genLoopNest stores here either the first
/// fir::DoLoopOp it creates or, when a workshare loop is emitted, the
/// mlir::omp::WorkshareLoopWrapperOp. Callers use it to place the builder
/// before or after the whole nest.
mlir::Operation *outerOp = nullptr;
/// Block that receives the loop body (the innermost loop's block). Callers
/// set the insertion point to its start before emitting per-element code.
mlir::Block *body = nullptr;
/// One index value per extent, stored so that oneBasedIndices[i] is the
/// induction value of the loop running from 1 to extents[i].
llvm::SmallVector<mlir::Value> oneBasedIndices;
};

/// Generate a fir.do_loop nest looping from 1 to extents[i].
/// \p isUnordered specifies whether the loops in the loop nest
/// are unordered.
LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::ValueRange extents, bool isUnordered = false);
mlir::ValueRange extents, bool isUnordered = false,
bool emitWorkshareLoop = false);
inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
mlir::Value shape, bool isUnordered = false) {
mlir::Value shape, bool isUnordered = false,
bool emitWorkshareLoop = false) {
return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape),
isUnordered);
isUnordered, emitWorkshareLoop);
}

/// Inline the body of an hlfir.elemental at the current insertion point
Expand Down
2 changes: 1 addition & 1 deletion flang/lib/Lower/ConvertCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2135,7 +2135,7 @@ class ElementalCallBuilder {
hlfir::genLoopNest(loc, builder, shape, !mustBeOrdered);
mlir::ValueRange oneBasedIndices = loopNest.oneBasedIndices;
auto insPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
builder.setInsertionPointToStart(loopNest.body);
callContext.stmtCtx.pushScope();
for (auto &preparedActual : loweredActuals)
if (preparedActual)
Expand Down
4 changes: 2 additions & 2 deletions flang/lib/Lower/OpenMP/ReductionProcessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
// know this won't miss any opportunities for clever elemental inlining
hlfir::LoopNest nest = hlfir::genLoopNest(
loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
builder.setInsertionPointToStart(nest.innerLoop.getBody());
builder.setInsertionPointToStart(nest.body);
mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
Expand All @@ -388,7 +388,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
builder, loc, redId, refTy, lhsEle, rhsEle);
builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);

builder.setInsertionPointAfter(nest.outerLoop);
builder.setInsertionPointAfter(nest.outerOp);
builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
}

Expand Down
51 changes: 38 additions & 13 deletions flang/lib/Optimizer/Builder/HLFIRTools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
#include <optional>

// Return explicit extents. If the base is a fir.box, this won't read it to
Expand Down Expand Up @@ -855,26 +856,50 @@ mlir::Value hlfir::inlineElementalOp(

hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
fir::FirOpBuilder &builder,
mlir::ValueRange extents, bool isUnordered) {
mlir::ValueRange extents, bool isUnordered,
bool emitWorkshareLoop) {
emitWorkshareLoop = emitWorkshareLoop && isUnordered;
hlfir::LoopNest loopNest;
assert(!extents.empty() && "must have at least one extent");
auto insPt = builder.saveInsertionPoint();
mlir::OpBuilder::InsertionGuard guard(builder);
loopNest.oneBasedIndices.assign(extents.size(), mlir::Value{});
// Build loop nest from column to row.
auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
mlir::Type indexType = builder.getIndexType();
unsigned dim = extents.size() - 1;
for (auto extent : llvm::reverse(extents)) {
auto ub = builder.createConvert(loc, indexType, extent);
loopNest.innerLoop =
builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
// Reverse the indices so they are in column-major order.
loopNest.oneBasedIndices[dim--] = loopNest.innerLoop.getInductionVar();
if (!loopNest.outerLoop)
loopNest.outerLoop = loopNest.innerLoop;
if (emitWorkshareLoop) {
auto wslw = builder.create<mlir::omp::WorkshareLoopWrapperOp>(loc);
loopNest.outerOp = wslw;
builder.createBlock(&wslw.getRegion());
mlir::omp::LoopNestOperands lnops;
lnops.loopInclusive = builder.getUnitAttr();
for (auto extent : llvm::reverse(extents)) {
lnops.loopLowerBounds.push_back(one);
lnops.loopUpperBounds.push_back(extent);
lnops.loopSteps.push_back(one);
}
auto lnOp = builder.create<mlir::omp::LoopNestOp>(loc, lnops);
mlir::Block *block = builder.createBlock(&lnOp.getRegion());
for (auto extent : llvm::reverse(extents))
block->addArgument(extent.getType(), extent.getLoc());
loopNest.body = block;
builder.create<mlir::omp::YieldOp>(loc);
for (unsigned dim = 0; dim < extents.size(); dim++)
loopNest.oneBasedIndices[extents.size() - dim - 1] =
lnOp.getRegion().front().getArgument(dim);
} else {
unsigned dim = extents.size() - 1;
for (auto extent : llvm::reverse(extents)) {
auto ub = builder.createConvert(loc, indexType, extent);
auto doLoop =
builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
loopNest.body = doLoop.getBody();
builder.setInsertionPointToStart(loopNest.body);
// Reverse the indices so they are in column-major order.
loopNest.oneBasedIndices[dim--] = doLoop.getInductionVar();
if (!loopNest.outerOp)
loopNest.outerOp = doLoop;
}
}
builder.restoreInsertionPoint(insPt);
return loopNest;
}

Expand Down
3 changes: 2 additions & 1 deletion flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
Expand Down Expand Up @@ -793,7 +794,7 @@ struct ElementalOpConversion
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
auto insPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
builder.setInsertionPointToStart(loopNest.body);
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
loopNest.oneBasedIndices);
hlfir::Entity elementValue(yield.getElementValue());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
// if the LHS is not).
mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
builder.setInsertionPointToStart(elementalLoopNest->body);
lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
elementalLoopNest->oneBasedIndices);
rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
Expand All @@ -484,16 +484,15 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
for (auto &cleanupConversion : argConversionCleanups)
cleanupConversion();
if (elementalLoopNest)
builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
builder.setInsertionPointAfter(elementalLoopNest->outerOp);
} else {
// TODO: preserve allocatable assignment aspects for forall once
// they are conveyed in hlfir.region_assign.
builder.create<hlfir::AssignOp>(loc, rhsEntity, lhsEntity);
}
generateCleanupIfAny(loweredLhs.elementalCleanup);
if (loweredLhs.vectorSubscriptLoopNest)
builder.setInsertionPointAfter(
loweredLhs.vectorSubscriptLoopNest->outerLoop);
builder.setInsertionPointAfter(loweredLhs.vectorSubscriptLoopNest->outerOp);
generateCleanupIfAny(oldRhsYield);
generateCleanupIfAny(loweredLhs.nonElementalCleanup);
}
Expand All @@ -518,16 +517,16 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
hlfir::Entity savedMask{maybeSaved->first};
mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
constructStack.push_back(whereLoopNest->outerLoop.getOperation());
builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
constructStack.push_back(whereLoopNest->outerOp);
builder.setInsertionPointToStart(whereLoopNest->body);
mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
whereLoopNest->oneBasedIndices);
generateMaskIfOp(cdt);
if (maybeSaved->second) {
// If this is the same run as the one that saved the value, the clean-up
// was left-over to be done now.
auto insertionPoint = builder.saveInsertionPoint();
builder.setInsertionPointAfter(whereLoopNest->outerLoop);
builder.setInsertionPointAfter(whereLoopNest->outerOp);
generateCleanupIfAny(maybeSaved->second);
builder.restoreInsertionPoint(insertionPoint);
}
Expand All @@ -539,8 +538,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
mask.generateNoneElementalPart(builder, mapper);
mlir::Value shape = mask.generateShape(builder, mapper);
whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
constructStack.push_back(whereLoopNest->outerLoop.getOperation());
builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
constructStack.push_back(whereLoopNest->outerOp);
builder.setInsertionPointToStart(whereLoopNest->body);
mlir::Value cdt = generateMaskedEntity(mask);
generateMaskIfOp(cdt);
return;
Expand Down Expand Up @@ -754,7 +753,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
loc, builder, loweredLhs.vectorSubscriptShape.value());
builder.setInsertionPointToStart(
loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
loweredLhs.vectorSubscriptLoopNest->body);
}
loweredLhs.lhs = temp->second.fetch(loc, builder);
return loweredLhs;
Expand All @@ -771,8 +770,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
loweredLhs.vectorSubscriptLoopNest =
hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
!elementalAddrLhs.isOrdered());
builder.setInsertionPointToStart(
loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
builder.setInsertionPointToStart(loweredLhs.vectorSubscriptLoopNest->body);
mapper.map(elementalAddrLhs.getIndices(),
loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
Expand All @@ -798,11 +796,11 @@ OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
if (!maskedExpr.noneElementalPartWasGenerated) {
// Generate none elemental part before the where loops (but inside the
// current forall loops if any).
builder.setInsertionPoint(whereLoopNest->outerLoop);
builder.setInsertionPoint(whereLoopNest->outerOp);
maskedExpr.generateNoneElementalPart(builder, mapper);
}
// Generate the none elemental part cleanup after the where loops.
builder.setInsertionPointAfter(whereLoopNest->outerLoop);
builder.setInsertionPointAfter(whereLoopNest->outerOp);
maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
// Generate the value of the current element for the masked expression
// at the current insertion point (inside the where loops, and any fir.if
Expand Down Expand Up @@ -1242,7 +1240,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
fir::factory::TemporaryStorage *temp = nullptr;
if (loweredLhs.vectorSubscriptLoopNest)
constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
// Vector subscripted entity for which the shape must also be saved on top
// of the element addresses (e.g. the shape may change in each forall
Expand All @@ -1265,7 +1263,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
// subscripted LHS.
auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
auto insertionPoint = builder.saveInsertionPoint();
builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
vectorTmp.pushShape(loc, builder, shape);
builder.restoreInsertionPoint(insertionPoint);
} else {
Expand All @@ -1290,8 +1288,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
generateCleanupIfAny(loweredLhs.elementalCleanup);
if (loweredLhs.vectorSubscriptLoopNest) {
constructStack.pop_back();
builder.setInsertionPointAfter(
loweredLhs.vectorSubscriptLoopNest->outerLoop);
builder.setInsertionPointAfter(loweredLhs.vectorSubscriptLoopNest->outerOp);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
// hlfir.elemental region inside the inner loop
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
builder.setInsertionPointToStart(loopNest.body);
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
loopNest.oneBasedIndices);
hlfir::Entity elementValue{yield.getElementValue()};
Expand Down Expand Up @@ -554,7 +554,7 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
hlfir::getIndexExtents(loc, builder, shape);
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
builder.setInsertionPointToStart(loopNest.body);
auto arrayElement =
hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
Expand Down Expand Up @@ -652,7 +652,7 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
hlfir::getIndexExtents(loc, builder, shape);
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
builder.setInsertionPointToStart(loopNest.body);
auto rhsArrayElement =
hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);
Expand Down
Loading