Skip to content

Commit d5fbe9c

Browse files
committed
[flang] Introduce ws loop nest generation for HLFIR lowering
Emit loop nests in a custom wrapper Only emit unordered loops as omp loops Fix uninitialized memory bug in genLoopNest
1 parent de32599 commit d5fbe9c

File tree

7 files changed

+69
-43
lines changed

7 files changed

+69
-43
lines changed

flang/include/flang/Optimizer/Builder/HLFIRTools.h

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -357,20 +357,22 @@ hlfir::ElementalOp genElementalOp(
357357

358358
/// Structure to describe a loop nest.
359359
struct LoopNest {
360-
fir::DoLoopOp outerLoop;
361-
fir::DoLoopOp innerLoop;
360+
mlir::Operation *outerOp = nullptr;
361+
mlir::Block *body = nullptr;
362362
llvm::SmallVector<mlir::Value> oneBasedIndices;
363363
};
364364

365365
/// Generate a fir.do_loop nest looping from 1 to extents[i].
366366
/// \p isUnordered specifies whether the loops in the loop nest
367367
/// are unordered.
368368
LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
369-
mlir::ValueRange extents, bool isUnordered = false);
369+
mlir::ValueRange extents, bool isUnordered = false,
370+
bool emitWorkshareLoop = false);
370371
inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
371-
mlir::Value shape, bool isUnordered = false) {
372+
mlir::Value shape, bool isUnordered = false,
373+
bool emitWorkshareLoop = false) {
372374
return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape),
373-
isUnordered);
375+
isUnordered, emitWorkshareLoop);
374376
}
375377

376378
/// Inline the body of an hlfir.elemental at the current insertion point

flang/lib/Lower/ConvertCall.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2128,7 +2128,7 @@ class ElementalCallBuilder {
21282128
hlfir::genLoopNest(loc, builder, shape, !mustBeOrdered);
21292129
mlir::ValueRange oneBasedIndices = loopNest.oneBasedIndices;
21302130
auto insPt = builder.saveInsertionPoint();
2131-
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
2131+
builder.setInsertionPointToStart(loopNest.body);
21322132
callContext.stmtCtx.pushScope();
21332133
for (auto &preparedActual : loweredActuals)
21342134
if (preparedActual)

flang/lib/Lower/OpenMP/ReductionProcessor.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
375375
// know this won't miss any opportuinties for clever elemental inlining
376376
hlfir::LoopNest nest = hlfir::genLoopNest(
377377
loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
378-
builder.setInsertionPointToStart(nest.innerLoop.getBody());
378+
builder.setInsertionPointToStart(nest.body);
379379
mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
380380
auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
381381
loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
@@ -389,7 +389,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
389389
builder, loc, redId, refTy, lhsEle, rhsEle);
390390
builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
391391

392-
builder.setInsertionPointAfter(nest.outerLoop);
392+
builder.setInsertionPointAfter(nest.outerOp);
393393
builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
394394
}
395395

flang/lib/Optimizer/Builder/HLFIRTools.cpp

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "mlir/IR/IRMapping.h"
2121
#include "mlir/Support/LLVM.h"
2222
#include "llvm/ADT/TypeSwitch.h"
23+
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
2324
#include <optional>
2425

2526
// Return explicit extents. If the base is a fir.box, this won't read it to
@@ -855,26 +856,51 @@ mlir::Value hlfir::inlineElementalOp(
855856

856857
hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
857858
fir::FirOpBuilder &builder,
858-
mlir::ValueRange extents, bool isUnordered) {
859+
mlir::ValueRange extents, bool isUnordered,
860+
bool emitWorkshareLoop) {
861+
emitWorkshareLoop = emitWorkshareLoop && isUnordered;
859862
hlfir::LoopNest loopNest;
860863
assert(!extents.empty() && "must have at least one extent");
861-
auto insPt = builder.saveInsertionPoint();
864+
mlir::OpBuilder::InsertionGuard guard(builder);
862865
loopNest.oneBasedIndices.assign(extents.size(), mlir::Value{});
863866
// Build loop nest from column to row.
864867
auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
865868
mlir::Type indexType = builder.getIndexType();
866-
unsigned dim = extents.size() - 1;
867-
for (auto extent : llvm::reverse(extents)) {
868-
auto ub = builder.createConvert(loc, indexType, extent);
869-
loopNest.innerLoop =
870-
builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
871-
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
872-
// Reverse the indices so they are in column-major order.
873-
loopNest.oneBasedIndices[dim--] = loopNest.innerLoop.getInductionVar();
874-
if (!loopNest.outerLoop)
875-
loopNest.outerLoop = loopNest.innerLoop;
869+
if (emitWorkshareLoop) {
870+
auto wslw = builder.create<mlir::omp::WorkshareLoopWrapperOp>(loc);
871+
loopNest.outerOp = wslw;
872+
builder.createBlock(&wslw.getRegion());
873+
mlir::omp::LoopNestOperands lnops;
874+
lnops.loopInclusive = builder.getUnitAttr();
875+
for (auto extent : llvm::reverse(extents)) {
876+
lnops.loopLowerBounds.push_back(one);
877+
lnops.loopUpperBounds.push_back(extent);
878+
lnops.loopSteps.push_back(one);
879+
}
880+
auto lnOp = builder.create<mlir::omp::LoopNestOp>(loc, lnops);
881+
builder.create<mlir::omp::TerminatorOp>(loc);
882+
mlir::Block *block = builder.createBlock(&lnOp.getRegion());
883+
for (auto extent : llvm::reverse(extents))
884+
block->addArgument(extent.getType(), extent.getLoc());
885+
loopNest.body = block;
886+
builder.create<mlir::omp::YieldOp>(loc);
887+
for (unsigned dim = 0; dim < extents.size(); dim++)
888+
loopNest.oneBasedIndices[extents.size() - dim - 1] =
889+
lnOp.getRegion().front().getArgument(dim);
890+
} else {
891+
unsigned dim = extents.size() - 1;
892+
for (auto extent : llvm::reverse(extents)) {
893+
auto ub = builder.createConvert(loc, indexType, extent);
894+
auto doLoop =
895+
builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
896+
loopNest.body = doLoop.getBody();
897+
builder.setInsertionPointToStart(loopNest.body);
898+
// Reverse the indices so they are in column-major order.
899+
loopNest.oneBasedIndices[dim--] = doLoop.getInductionVar();
900+
if (!loopNest.outerOp)
901+
loopNest.outerOp = doLoop;
902+
}
876903
}
877-
builder.restoreInsertionPoint(insPt);
878904
return loopNest;
879905
}
880906

flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
2727
#include "flang/Optimizer/HLFIR/HLFIROps.h"
2828
#include "flang/Optimizer/HLFIR/Passes.h"
29+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
2930
#include "mlir/IR/Dominance.h"
3031
#include "mlir/IR/PatternMatch.h"
3132
#include "mlir/Pass/Pass.h"
@@ -793,7 +794,7 @@ struct ElementalOpConversion
793794
hlfir::LoopNest loopNest =
794795
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
795796
auto insPt = builder.saveInsertionPoint();
796-
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
797+
builder.setInsertionPointToStart(loopNest.body);
797798
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
798799
loopNest.oneBasedIndices);
799800
hlfir::Entity elementValue(yield.getElementValue());

flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -464,7 +464,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
464464
// if the LHS is not).
465465
mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
466466
elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
467-
builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
467+
builder.setInsertionPointToStart(elementalLoopNest->body);
468468
lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
469469
elementalLoopNest->oneBasedIndices);
470470
rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
@@ -484,16 +484,15 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
484484
for (auto &cleanupConversion : argConversionCleanups)
485485
cleanupConversion();
486486
if (elementalLoopNest)
487-
builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
487+
builder.setInsertionPointAfter(elementalLoopNest->outerOp);
488488
} else {
489489
// TODO: preserve allocatable assignment aspects for forall once
490490
// they are conveyed in hlfir.region_assign.
491491
builder.create<hlfir::AssignOp>(loc, rhsEntity, lhsEntity);
492492
}
493493
generateCleanupIfAny(loweredLhs.elementalCleanup);
494494
if (loweredLhs.vectorSubscriptLoopNest)
495-
builder.setInsertionPointAfter(
496-
loweredLhs.vectorSubscriptLoopNest->outerLoop);
495+
builder.setInsertionPointAfter(loweredLhs.vectorSubscriptLoopNest->outerOp);
497496
generateCleanupIfAny(oldRhsYield);
498497
generateCleanupIfAny(loweredLhs.nonElementalCleanup);
499498
}
@@ -518,16 +517,16 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
518517
hlfir::Entity savedMask{maybeSaved->first};
519518
mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
520519
whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
521-
constructStack.push_back(whereLoopNest->outerLoop.getOperation());
522-
builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
520+
constructStack.push_back(whereLoopNest->outerOp);
521+
builder.setInsertionPointToStart(whereLoopNest->body);
523522
mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
524523
whereLoopNest->oneBasedIndices);
525524
generateMaskIfOp(cdt);
526525
if (maybeSaved->second) {
527526
// If this is the same run as the one that saved the value, the clean-up
528527
// was left-over to be done now.
529528
auto insertionPoint = builder.saveInsertionPoint();
530-
builder.setInsertionPointAfter(whereLoopNest->outerLoop);
529+
builder.setInsertionPointAfter(whereLoopNest->outerOp);
531530
generateCleanupIfAny(maybeSaved->second);
532531
builder.restoreInsertionPoint(insertionPoint);
533532
}
@@ -539,8 +538,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
539538
mask.generateNoneElementalPart(builder, mapper);
540539
mlir::Value shape = mask.generateShape(builder, mapper);
541540
whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
542-
constructStack.push_back(whereLoopNest->outerLoop.getOperation());
543-
builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
541+
constructStack.push_back(whereLoopNest->outerOp);
542+
builder.setInsertionPointToStart(whereLoopNest->body);
544543
mlir::Value cdt = generateMaskedEntity(mask);
545544
generateMaskIfOp(cdt);
546545
return;
@@ -754,7 +753,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
754753
loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
755754
loc, builder, loweredLhs.vectorSubscriptShape.value());
756755
builder.setInsertionPointToStart(
757-
loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
756+
loweredLhs.vectorSubscriptLoopNest->body);
758757
}
759758
loweredLhs.lhs = temp->second.fetch(loc, builder);
760759
return loweredLhs;
@@ -771,8 +770,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
771770
loweredLhs.vectorSubscriptLoopNest =
772771
hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
773772
!elementalAddrLhs.isOrdered());
774-
builder.setInsertionPointToStart(
775-
loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
773+
builder.setInsertionPointToStart(loweredLhs.vectorSubscriptLoopNest->body);
776774
mapper.map(elementalAddrLhs.getIndices(),
777775
loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
778776
for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
@@ -798,11 +796,11 @@ OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
798796
if (!maskedExpr.noneElementalPartWasGenerated) {
799797
// Generate none elemental part before the where loops (but inside the
800798
// current forall loops if any).
801-
builder.setInsertionPoint(whereLoopNest->outerLoop);
799+
builder.setInsertionPoint(whereLoopNest->outerOp);
802800
maskedExpr.generateNoneElementalPart(builder, mapper);
803801
}
804802
// Generate the none elemental part cleanup after the where loops.
805-
builder.setInsertionPointAfter(whereLoopNest->outerLoop);
803+
builder.setInsertionPointAfter(whereLoopNest->outerOp);
806804
maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
807805
// Generate the value of the current element for the masked expression
808806
// at the current insertion point (inside the where loops, and any fir.if
@@ -1242,7 +1240,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
12421240
LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
12431241
fir::factory::TemporaryStorage *temp = nullptr;
12441242
if (loweredLhs.vectorSubscriptLoopNest)
1245-
constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
1243+
constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
12461244
if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
12471245
// Vector subscripted entity for which the shape must also be saved on top
12481246
// of the element addresses (e.g. the shape may change in each forall
@@ -1265,7 +1263,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
12651263
// subscripted LHS.
12661264
auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
12671265
auto insertionPoint = builder.saveInsertionPoint();
1268-
builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
1266+
builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
12691267
vectorTmp.pushShape(loc, builder, shape);
12701268
builder.restoreInsertionPoint(insertionPoint);
12711269
} else {
@@ -1290,8 +1288,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
12901288
generateCleanupIfAny(loweredLhs.elementalCleanup);
12911289
if (loweredLhs.vectorSubscriptLoopNest) {
12921290
constructStack.pop_back();
1293-
builder.setInsertionPointAfter(
1294-
loweredLhs.vectorSubscriptLoopNest->outerLoop);
1291+
builder.setInsertionPointAfter(loweredLhs.vectorSubscriptLoopNest->outerOp);
12951292
}
12961293
}
12971294

flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
483483
// hlfir.elemental region inside the inner loop
484484
hlfir::LoopNest loopNest =
485485
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
486-
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
486+
builder.setInsertionPointToStart(loopNest.body);
487487
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
488488
loopNest.oneBasedIndices);
489489
hlfir::Entity elementValue{yield.getElementValue()};
@@ -554,7 +554,7 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
554554
hlfir::getIndexExtents(loc, builder, shape);
555555
hlfir::LoopNest loopNest =
556556
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
557-
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
557+
builder.setInsertionPointToStart(loopNest.body);
558558
auto arrayElement =
559559
hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
560560
builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
@@ -649,7 +649,7 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
649649
hlfir::getIndexExtents(loc, builder, shape);
650650
hlfir::LoopNest loopNest =
651651
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
652-
builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
652+
builder.setInsertionPointToStart(loopNest.body);
653653
auto rhsArrayElement =
654654
hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
655655
rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);

0 commit comments

Comments
 (0)