|
13 | 13 | #include <flang/Optimizer/Dialect/FIRType.h>
|
14 | 14 | #include <flang/Optimizer/HLFIR/HLFIROps.h>
|
15 | 15 | #include <flang/Optimizer/OpenMP/Passes.h>
|
| 16 | +#include <llvm/ADT/BreadthFirstIterator.h> |
16 | 17 | #include <llvm/ADT/STLExtras.h>
|
17 | 18 | #include <llvm/ADT/SmallVectorExtras.h>
|
18 | 19 | #include <llvm/ADT/iterator_range.h>
|
19 | 20 | #include <llvm/Support/ErrorHandling.h>
|
20 | 21 | #include <mlir/Dialect/Arith/IR/Arith.h>
|
21 | 22 | #include <mlir/Dialect/LLVMIR/LLVMTypes.h>
|
| 23 | +#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h> |
22 | 24 | #include <mlir/Dialect/OpenMP/OpenMPDialect.h>
|
23 | 25 | #include <mlir/Dialect/SCF/IR/SCF.h>
|
24 | 26 | #include <mlir/IR/BuiltinOps.h>
|
@@ -161,7 +163,8 @@ static void cleanupBlock(Block *block) {
|
161 | 163 | }
|
162 | 164 |
|
163 | 165 | static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
|
164 |
| - IRMapping &rootMapping, Location loc) { |
| 166 | + IRMapping &rootMapping, Location loc, |
| 167 | + mlir::DominanceInfo &di) { |
165 | 168 | OpBuilder rootBuilder(sourceRegion.getContext());
|
166 | 169 | ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
|
167 | 170 | OpBuilder copyFuncBuilder(m.getBodyRegion());
|
@@ -214,14 +217,19 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
|
214 | 217 | return copyPrivate;
|
215 | 218 | };
|
216 | 219 |
|
217 |
| - // TODO Need to handle these (clone them) in dominator tree order |
218 | 220 | for (Block &block : sourceRegion) {
|
219 |
| - rootBuilder.createBlock( |
| 221 | + Block *targetBlock = rootBuilder.createBlock( |
220 | 222 | &targetRegion, {}, block.getArgumentTypes(),
|
221 | 223 | llvm::map_to_vector(block.getArguments(),
|
222 | 224 | [](BlockArgument arg) { return arg.getLoc(); }));
|
223 |
| - Operation *terminator = block.getTerminator(); |
| 225 | + rootMapping.map(&block, targetBlock); |
| 226 | + rootMapping.map(block.getArguments(), targetBlock->getArguments()); |
| 227 | + } |
224 | 228 |
|
| 229 | + auto handleOneBlock = [&](Block &block) { |
| 230 | + Block &targetBlock = *rootMapping.lookup(&block); |
| 231 | + rootBuilder.setInsertionPointToStart(&targetBlock); |
| 232 | + Operation *terminator = block.getTerminator(); |
225 | 233 | SmallVector<std::variant<SingleRegion, Operation *>> regions;
|
226 | 234 |
|
227 | 235 | auto it = block.begin();
|
@@ -298,12 +306,21 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
|
298 | 306 | Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
|
299 | 307 | for (auto [region, clonedRegion] :
|
300 | 308 | llvm::zip(op->getRegions(), cloned->getRegions()))
|
301 |
| - parallelizeRegion(region, clonedRegion, rootMapping, loc); |
| 309 | + parallelizeRegion(region, clonedRegion, rootMapping, loc, di); |
302 | 310 | }
|
303 | 311 | }
|
304 | 312 | }
|
305 | 313 |
|
306 | 314 | rootBuilder.clone(*block.getTerminator(), rootMapping);
|
| 315 | + }; |
| 316 | + |
| 317 | + if (sourceRegion.hasOneBlock()) { |
| 318 | + handleOneBlock(sourceRegion.front()); |
| 319 | + } else { |
| 320 | + auto &domTree = di.getDomTree(&sourceRegion); |
| 321 | + for (auto node : llvm::breadth_first(domTree.getRootNode())) { |
| 322 | + handleOneBlock(*node->getBlock()); |
| 323 | + } |
307 | 324 | }
|
308 | 325 |
|
309 | 326 | for (Block &targetBlock : targetRegion)
|
@@ -336,47 +353,46 @@ static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
|
336 | 353 | ///
|
337 | 354 | /// Note that we allocate temporary memory for values in omp.single ops which
|
338 | 355 | /// need to be accessed by all threads in the closest omp.parallel.
|
339 |
| -void lowerWorkshare(mlir::omp::WorkshareOp wsOp) { |
| 356 | +LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) { |
340 | 357 | Location loc = wsOp->getLoc();
|
341 | 358 | IRMapping rootMapping;
|
342 | 359 |
|
343 | 360 | OpBuilder rootBuilder(wsOp);
|
344 | 361 |
|
345 |
| - // TODO We need something like an scf;execute here, but that is not registered |
346 |
| - // so using fir.if for now but it looks like it does not support multiple |
347 |
| - // blocks so it doesnt work for multi block case... |
348 |
| - auto ifOp = rootBuilder.create<fir::IfOp>( |
349 |
| - loc, rootBuilder.create<arith::ConstantIntOp>(loc, 1, 1), false); |
350 |
| - ifOp.getThenRegion().front().erase(); |
351 |
| - |
352 |
| - parallelizeRegion(wsOp.getRegion(), ifOp.getThenRegion(), rootMapping, loc); |
353 |
| - |
354 |
| - Operation *terminatorOp = ifOp.getThenRegion().back().getTerminator(); |
355 |
| - assert(isa<omp::TerminatorOp>(terminatorOp)); |
356 |
| - OpBuilder termBuilder(terminatorOp); |
357 |
| - |
| 362 | + // TODO: We need something like scf.execute_region here, but that is not |
| 363 | + // registered, so omp.workshare is used as a placeholder. A region-holding op |
| 364 | + // is needed because parallelizeRegion works on regions and not blocks. |
| 365 | + omp::WorkshareOp newOp = |
| 366 | + rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands()); |
358 | 367 | if (!wsOp.getNowait())
|
359 |
| - termBuilder.create<omp::BarrierOp>(loc); |
360 |
| - |
361 |
| - termBuilder.create<fir::ResultOp>(loc, ValueRange()); |
362 |
| - |
363 |
| - terminatorOp->erase(); |
| 368 | + rootBuilder.create<omp::BarrierOp>(loc); |
| 369 | + |
| 370 | + parallelizeRegion(wsOp.getRegion(), newOp.getRegion(), rootMapping, loc, di); |
| 371 | + |
| 372 | + if (wsOp.getRegion().getBlocks().size() != 1) |
| 373 | + return failure(); |
| 374 | + |
| 375 | + // Inline the contents of the placeholder workshare op into its parent block. |
| 376 | + Block *theBlock = &newOp.getRegion().front(); |
| 377 | + Operation *term = theBlock->getTerminator(); |
| 378 | + Block *parentBlock = wsOp->getBlock(); |
| 379 | + parentBlock->getOperations().splice(newOp->getIterator(), |
| 380 | + theBlock->getOperations()); |
| 381 | + assert(term->getNumOperands() == 0); |
| 382 | + term->erase(); |
| 383 | + newOp->erase(); |
364 | 384 | wsOp->erase();
|
365 |
| - |
366 |
| - return; |
| 385 | + return success(); |
367 | 386 | }
|
368 | 387 |
|
369 | 388 | class LowerWorksharePass
|
370 | 389 | : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
|
371 | 390 | public:
|
372 | 391 | void runOnOperation() override {
|
373 |
| - SmallPtrSet<Operation *, 8> parents; |
| 392 | + mlir::DominanceInfo &di = getAnalysis<mlir::DominanceInfo>(); |
374 | 393 | getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
|
375 |
| - Operation *isolatedParent = |
376 |
| - wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>(); |
377 |
| - parents.insert(isolatedParent); |
378 |
| - |
379 |
| - lowerWorkshare(wsOp); |
| 394 | + if (failed(lowerWorkshare(wsOp, di))) |
| 395 | + signalPassFailure(); |
380 | 396 | });
|
381 | 397 | }
|
382 | 398 | };
|
|
0 commit comments