Skip to content

Commit 4da93bb

Browse files
committed
[flang] Introduce ws loop nest generation for HLFIR lowering
1 parent 8068d60 commit 4da93bb

File tree

7 files changed

+69
-40
lines changed

7 files changed

+69
-40
lines changed

flang/include/flang/Optimizer/Builder/HLFIRTools.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -357,20 +357,22 @@ hlfir::ElementalOp genElementalOp(
357357

358358
/// Structure to describe a loop nest.
359359
struct LoopNest {
360-
fir::DoLoopOp outerLoop;
361-
fir::DoLoopOp innerLoop;
360+
mlir::Operation *outerOp;
361+
mlir::Block *body;
362362
llvm::SmallVector<mlir::Value> oneBasedIndices;
363363
};
364364

365365
/// Generate a fir.do_loop nest looping from 1 to extents[i].
366366
/// \p isUnordered specifies whether the loops in the loop nest
367367
/// are unordered.
368368
LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
369-
mlir::ValueRange extents, bool isUnordered = false);
369+
mlir::ValueRange extents, bool isUnordered = false,
370+
bool emitWsLoop = false);
370371
/// Convenience overload taking a shape value instead of explicit extents:
/// the shape is expanded with getIndexExtents and the resulting extents are
/// forwarded, together with \p isUnordered and \p emitWsLoop, to the
/// extents-based genLoopNest.
inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
                            mlir::Value shape, bool isUnordered = false,
                            bool emitWsLoop = false) {
  auto shapeExtents = getIndexExtents(loc, builder, shape);
  return genLoopNest(loc, builder, shapeExtents, isUnordered, emitWsLoop);
}
375377

376378
/// Inline the body of an hlfir.elemental at the current insertion point

flang/lib/Lower/ConvertCall.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2128,7 +2128,7 @@ class ElementalCallBuilder {
21282128
hlfir::genLoopNest(loc, builder, shape, !mustBeOrdered);
21292129
mlir::ValueRange oneBasedIndices = loopNest.oneBasedIndices;
21302130
auto insPt = builder.saveInsertionPoint();
2131-
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
2131+
builder.setInsertionPointToStart(loopNest.body);
21322132
callContext.stmtCtx.pushScope();
21332133
for (auto &preparedActual : loweredActuals)
21342134
if (preparedActual)

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
375375
// know this won't miss any opportunities for clever elemental inlining
376376
hlfir::LoopNest nest = hlfir::genLoopNest(
377377
loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
378-
builder.setInsertionPointToStart(nest.innerLoop.getBody());
378+
builder.setInsertionPointToStart(nest.body);
379379
mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
380380
auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
381381
loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
@@ -389,7 +389,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
389389
builder, loc, redId, refTy, lhsEle, rhsEle);
390390
builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
391391

392-
builder.setInsertionPointAfter(nest.outerLoop);
392+
builder.setInsertionPointAfter(nest.outerOp);
393393
builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
394394
}
395395

flang/lib/Optimizer/Builder/HLFIRTools.cpp

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/IRMapping.h"
2121
#include "mlir/Support/LLVM.h"
2222
#include "llvm/ADT/TypeSwitch.h"
23+
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
2324
#include <optional>
2425

2526
// Return explicit extents. If the base is a fir.box, this won't read it to
@@ -855,26 +856,51 @@ mlir::Value hlfir::inlineElementalOp(
855856

856857
/// Generate a loop nest looping from 1 to extents[i] (inclusive) for each
/// extent, and return a description of it.
///
/// \param loc location attached to the generated operations.
/// \param builder builder used to create the operations. Its insertion point
///        is restored on return, so callers must explicitly move into
///        LoopNest::body to emit the loop body.
/// \param extents one upper bound per dimension; must not be empty.
/// \param isUnordered forwarded to the generated fir.do_loop operations.
/// \param emitWsLoop when true, an omp.wsloop wrapping a single omp.loop_nest
///        is generated instead of nested fir.do_loop operations.
/// \return the outermost operation, the innermost body block, and the
///         one-based indices such that oneBasedIndices[i] iterates over
///         extents[i].
hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
                                   fir::FirOpBuilder &builder,
                                   mlir::ValueRange extents, bool isUnordered,
                                   bool emitWsLoop) {
  hlfir::LoopNest loopNest;
  // The pointer members are tested below (`!loopNest.outerOp`); make sure
  // they hold a well-defined value regardless of how LoopNest declares them.
  loopNest.outerOp = nullptr;
  loopNest.body = nullptr;
  assert(!extents.empty() && "must have at least one extent");
  mlir::OpBuilder::InsertionGuard guard(builder);
  loopNest.oneBasedIndices.assign(extents.size(), mlir::Value{});
  // Build loop nest from column to row.
  auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
  mlir::Type indexType = builder.getIndexType();
  if (emitWsLoop) {
    // NOTE(review): isUnordered is not consulted on this path - confirm that
    // callers only request a worksharing loop for unordered nests.
    //
    // The lower bounds and steps are index constants, so the upper bounds are
    // converted to index as well (as done for fir.do_loop below). The
    // conversions are emitted before the wrapper is created so that the
    // wrapper region only holds the loop nest.
    llvm::SmallVector<mlir::Value> upperBounds;
    for (auto extent : llvm::reverse(extents))
      upperBounds.push_back(builder.createConvert(loc, indexType, extent));

    auto wsloop = builder.create<mlir::omp::WsloopOp>(
        loc, mlir::ArrayRef<mlir::NamedAttribute>());
    loopNest.outerOp = wsloop;
    builder.createBlock(&wsloop.getRegion());
    mlir::omp::LoopNestOperands lnops;
    lnops.loopInclusive = builder.getUnitAttr();
    for (mlir::Value ub : upperBounds) {
      lnops.loopLowerBounds.push_back(one);
      lnops.loopUpperBounds.push_back(ub);
      lnops.loopSteps.push_back(one);
    }
    auto lnOp = builder.create<mlir::omp::LoopNestOp>(loc, lnops);
    // Terminate the wrapper block right after the loop nest operation.
    builder.create<mlir::omp::TerminatorOp>(loc);
    mlir::Block *block = builder.createBlock(&lnOp.getRegion());
    for (mlir::Value ub : upperBounds)
      block->addArgument(indexType, ub.getLoc());
    loopNest.body = block;
    builder.create<mlir::omp::YieldOp>(loc);
    // The loop nest arguments follow the reversed extents; reverse them back
    // so the indices are in column-major order.
    for (unsigned dim = 0; dim < extents.size(); dim++)
      loopNest.oneBasedIndices[extents.size() - dim - 1] =
          lnOp.getRegion().front().getArgument(dim);
  } else {
    unsigned dim = extents.size() - 1;
    for (auto extent : llvm::reverse(extents)) {
      auto ub = builder.createConvert(loc, indexType, extent);
      auto doLoop =
          builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
      // After the last iteration, body is the innermost loop body.
      loopNest.body = doLoop.getBody();
      builder.setInsertionPointToStart(loopNest.body);
      // Reverse the indices so they are in column-major order.
      loopNest.oneBasedIndices[dim--] = doLoop.getInductionVar();
      // Only the first loop created is the outermost one.
      if (!loopNest.outerOp)
        loopNest.outerOp = doLoop;
    }
  }
  return loopNest;
}
880906

flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "mlir/Pass/Pass.h"
3232
#include "mlir/Pass/PassManager.h"
3333
#include "mlir/Transforms/DialectConversion.h"
34+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
3435
#include "llvm/ADT/TypeSwitch.h"
3536

3637
namespace hlfir {
@@ -793,7 +794,7 @@ struct ElementalOpConversion
793794
hlfir::LoopNest loopNest =
794795
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
795796
auto insPt = builder.saveInsertionPoint();
796-
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
797+
builder.setInsertionPointToStart(loopNest.body);
797798
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
798799
loopNest.oneBasedIndices);
799800
hlfir::Entity elementValue(yield.getElementValue());

flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
464464
// if the LHS is not).
465465
mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
466466
elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
467-
builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
467+
builder.setInsertionPointToStart(elementalLoopNest->body);
468468
lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
469469
elementalLoopNest->oneBasedIndices);
470470
rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
@@ -484,7 +484,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
484484
for (auto &cleanupConversion : argConversionCleanups)
485485
cleanupConversion();
486486
if (elementalLoopNest)
487-
builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
487+
builder.setInsertionPointAfter(elementalLoopNest->outerOp);
488488
} else {
489489
// TODO: preserve allocatable assignment aspects for forall once
490490
// they are conveyed in hlfir.region_assign.
@@ -493,7 +493,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
493493
generateCleanupIfAny(loweredLhs.elementalCleanup);
494494
if (loweredLhs.vectorSubscriptLoopNest)
495495
builder.setInsertionPointAfter(
496-
loweredLhs.vectorSubscriptLoopNest->outerLoop);
496+
loweredLhs.vectorSubscriptLoopNest->outerOp);
497497
generateCleanupIfAny(oldRhsYield);
498498
generateCleanupIfAny(loweredLhs.nonElementalCleanup);
499499
}
@@ -518,16 +518,16 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
518518
hlfir::Entity savedMask{maybeSaved->first};
519519
mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
520520
whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
521-
constructStack.push_back(whereLoopNest->outerLoop.getOperation());
522-
builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
521+
constructStack.push_back(whereLoopNest->outerOp);
522+
builder.setInsertionPointToStart(whereLoopNest->body);
523523
mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
524524
whereLoopNest->oneBasedIndices);
525525
generateMaskIfOp(cdt);
526526
if (maybeSaved->second) {
527527
// If this is the same run as the one that saved the value, the clean-up
528528
// was left-over to be done now.
529529
auto insertionPoint = builder.saveInsertionPoint();
530-
builder.setInsertionPointAfter(whereLoopNest->outerLoop);
530+
builder.setInsertionPointAfter(whereLoopNest->outerOp);
531531
generateCleanupIfAny(maybeSaved->second);
532532
builder.restoreInsertionPoint(insertionPoint);
533533
}
@@ -539,8 +539,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
539539
mask.generateNoneElementalPart(builder, mapper);
540540
mlir::Value shape = mask.generateShape(builder, mapper);
541541
whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
542-
constructStack.push_back(whereLoopNest->outerLoop.getOperation());
543-
builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
542+
constructStack.push_back(whereLoopNest->outerOp);
543+
builder.setInsertionPointToStart(whereLoopNest->body);
544544
mlir::Value cdt = generateMaskedEntity(mask);
545545
generateMaskIfOp(cdt);
546546
return;
@@ -754,7 +754,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
754754
loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
755755
loc, builder, loweredLhs.vectorSubscriptShape.value());
756756
builder.setInsertionPointToStart(
757-
loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
757+
loweredLhs.vectorSubscriptLoopNest->body);
758758
}
759759
loweredLhs.lhs = temp->second.fetch(loc, builder);
760760
return loweredLhs;
@@ -772,7 +772,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
772772
hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
773773
!elementalAddrLhs.isOrdered());
774774
builder.setInsertionPointToStart(
775-
loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
775+
loweredLhs.vectorSubscriptLoopNest->body);
776776
mapper.map(elementalAddrLhs.getIndices(),
777777
loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
778778
for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
@@ -798,11 +798,11 @@ OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
798798
if (!maskedExpr.noneElementalPartWasGenerated) {
799799
// Generate none elemental part before the where loops (but inside the
800800
// current forall loops if any).
801-
builder.setInsertionPoint(whereLoopNest->outerLoop);
801+
builder.setInsertionPoint(whereLoopNest->outerOp);
802802
maskedExpr.generateNoneElementalPart(builder, mapper);
803803
}
804804
// Generate the none elemental part cleanup after the where loops.
805-
builder.setInsertionPointAfter(whereLoopNest->outerLoop);
805+
builder.setInsertionPointAfter(whereLoopNest->outerOp);
806806
maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
807807
// Generate the value of the current element for the masked expression
808808
// at the current insertion point (inside the where loops, and any fir.if
@@ -1242,7 +1242,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
12421242
LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
12431243
fir::factory::TemporaryStorage *temp = nullptr;
12441244
if (loweredLhs.vectorSubscriptLoopNest)
1245-
constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
1245+
constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
12461246
if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
12471247
// Vector subscripted entity for which the shape must also be saved on top
12481248
// of the element addresses (e.g. the shape may change in each forall
@@ -1265,7 +1265,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
12651265
// subscripted LHS.
12661266
auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
12671267
auto insertionPoint = builder.saveInsertionPoint();
1268-
builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
1268+
builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
12691269
vectorTmp.pushShape(loc, builder, shape);
12701270
builder.restoreInsertionPoint(insertionPoint);
12711271
} else {
@@ -1291,7 +1291,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
12911291
if (loweredLhs.vectorSubscriptLoopNest) {
12921292
constructStack.pop_back();
12931293
builder.setInsertionPointAfter(
1294-
loweredLhs.vectorSubscriptLoopNest->outerLoop);
1294+
loweredLhs.vectorSubscriptLoopNest->outerOp);
12951295
}
12961296
}
12971297

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
483483
// hlfir.elemental region inside the inner loop
484484
hlfir::LoopNest loopNest =
485485
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
486-
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
486+
builder.setInsertionPointToStart(loopNest.body);
487487
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
488488
loopNest.oneBasedIndices);
489489
hlfir::Entity elementValue{yield.getElementValue()};
@@ -554,7 +554,7 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
554554
hlfir::getIndexExtents(loc, builder, shape);
555555
hlfir::LoopNest loopNest =
556556
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
557-
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
557+
builder.setInsertionPointToStart(loopNest.body);
558558
auto arrayElement =
559559
hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
560560
builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
@@ -649,7 +649,7 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
649649
hlfir::getIndexExtents(loc, builder, shape);
650650
hlfir::LoopNest loopNest =
651651
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
652-
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
652+
builder.setInsertionPointToStart(loopNest.body);
653653
auto rhsArrayElement =
654654
hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
655655
rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);

0 commit comments

Comments
 (0)