Skip to content

Commit 058fb57

Browse files
committed
Only handle single-block regions for now
1 parent e3c2884 commit 058fb57

File tree

3 files changed

+143
-122
lines changed

3 files changed

+143
-122
lines changed

flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@
1313
#include <flang/Optimizer/Dialect/FIRType.h>
1414
#include <flang/Optimizer/HLFIR/HLFIROps.h>
1515
#include <flang/Optimizer/OpenMP/Passes.h>
16+
#include <llvm/ADT/BreadthFirstIterator.h>
1617
#include <llvm/ADT/STLExtras.h>
1718
#include <llvm/ADT/SmallVectorExtras.h>
1819
#include <llvm/ADT/iterator_range.h>
1920
#include <llvm/Support/ErrorHandling.h>
2021
#include <mlir/Dialect/Arith/IR/Arith.h>
2122
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
23+
#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
2224
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
2325
#include <mlir/Dialect/SCF/IR/SCF.h>
2426
#include <mlir/IR/BuiltinOps.h>
@@ -161,7 +163,8 @@ static void cleanupBlock(Block *block) {
161163
}
162164

163165
static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
164-
IRMapping &rootMapping, Location loc) {
166+
IRMapping &rootMapping, Location loc,
167+
mlir::DominanceInfo &di) {
165168
OpBuilder rootBuilder(sourceRegion.getContext());
166169
ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
167170
OpBuilder copyFuncBuilder(m.getBodyRegion());
@@ -214,14 +217,19 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
214217
return copyPrivate;
215218
};
216219

217-
// TODO Need to handle these (clone them) in dominator tree order
218220
for (Block &block : sourceRegion) {
219-
rootBuilder.createBlock(
221+
Block *targetBlock = rootBuilder.createBlock(
220222
&targetRegion, {}, block.getArgumentTypes(),
221223
llvm::map_to_vector(block.getArguments(),
222224
[](BlockArgument arg) { return arg.getLoc(); }));
223-
Operation *terminator = block.getTerminator();
225+
rootMapping.map(&block, targetBlock);
226+
rootMapping.map(block.getArguments(), targetBlock->getArguments());
227+
}
224228

229+
auto handleOneBlock = [&](Block &block) {
230+
Block &targetBlock = *rootMapping.lookup(&block);
231+
rootBuilder.setInsertionPointToStart(&targetBlock);
232+
Operation *terminator = block.getTerminator();
225233
SmallVector<std::variant<SingleRegion, Operation *>> regions;
226234

227235
auto it = block.begin();
@@ -298,12 +306,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
298306
Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
299307
for (auto [region, clonedRegion] :
300308
llvm::zip(op->getRegions(), cloned->getRegions()))
301-
parallelizeRegion(region, clonedRegion, rootMapping, loc);
309+
parallelizeRegion(region, clonedRegion, rootMapping, loc, di);
302310
}
303311
}
304312
}
305313

306314
rootBuilder.clone(*block.getTerminator(), rootMapping);
315+
};
316+
317+
if (sourceRegion.hasOneBlock()) {
318+
handleOneBlock(sourceRegion.front());
319+
} else {
320+
auto &domTree = di.getDomTree(&sourceRegion);
321+
for (auto node : llvm::breadth_first(domTree.getRootNode())) {
322+
handleOneBlock(*node->getBlock());
323+
}
307324
}
308325

309326
for (Block &targetBlock : targetRegion)
@@ -336,47 +353,46 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
336353
///
337354
/// Note that we allocate temporary memory for values in omp.single's which need
338355
/// to be accessed in all threads in the closest omp.parallel
339-
void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
356+
LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) {
340357
Location loc = wsOp->getLoc();
341358
IRMapping rootMapping;
342359

343360
OpBuilder rootBuilder(wsOp);
344361

345-
// TODO We need something like an scf;execute here, but that is not registered
346-
// so using fir.if for now but it looks like it does not support multiple
347-
// blocks so it doesnt work for multi block case...
348-
auto ifOp = rootBuilder.create<fir::IfOp>(
349-
loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false);
350-
ifOp.getThenRegion().front().erase();
351-
352-
parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc);
353-
354-
Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator();
355-
assert(isa<omp::TerminatorOp>(terminatorOp));
356-
OpBuilder termBuilder(terminatorOp);
357-
362+
// TODO We need something like an scf.execute here, but that is not registered
363+
// so using omp.workshare as a placeholder. We need this op as our
364+
// parallelizeRegion works on regions and not blocks.
365+
omp::WorkshareOp newOp =
366+
rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands());
358367
if (!wsOp.getNowait())
359-
termBuilder.create<omp::BarrierOp>(loc);
360-
361-
termBuilder.create<fir::ResultOp>(loc, ValueRange());
362-
363-
terminatorOp->erase();
368+
rootBuilder.create<omp::BarrierOp>(loc);
369+
370+
parallelizeRegion(wsOp.getRegion(), newOp.getRegion(), rootMapping, loc, di);
371+
372+
if (wsOp.getRegion().getBlocks().size() != 1)
373+
return failure();
374+
375+
// Inline the contents of the placeholder workshare op into its parent block.
376+
Block *theBlock = &newOp.getRegion().front();
377+
Operation *term = theBlock->getTerminator();
378+
Block *parentBlock = wsOp->getBlock();
379+
parentBlock->getOperations().splice(newOp->getIterator(),
380+
theBlock->getOperations());
381+
assert(term->getNumOperands() == 0);
382+
term->erase();
383+
newOp->erase();
364384
wsOp->erase();
365-
366-
return;
385+
return success();
367386
}
368387

369388
class LowerWorksharePass
370389
: public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
371390
public:
372391
void runOnOperation() override {
373-
SmallPtrSet<Operation *, 8> parents;
392+
mlir::DominanceInfo &di = getAnalysis<mlir::DominanceInfo>();
374393
getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
375-
Operation *isolatedParent =
376-
wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>();
377-
parents.insert(isolatedParent);
378-
379-
lowerWorkshare(wsOp);
394+
if (failed(lowerWorkshare(wsOp, di)))
395+
signalPassFailure();
380396
});
381397
}
382398
};

0 commit comments

Comments
 (0)