|
25 | 25 | #include "llvm/ADT/SetVector.h"
|
26 | 26 | #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
|
27 | 27 | #include "llvm/IR/BasicBlock.h"
|
| 28 | +#include "llvm/IR/CFG.h" |
28 | 29 | #include "llvm/IR/Constants.h"
|
29 | 30 | #include "llvm/IR/DerivedTypes.h"
|
30 | 31 | #include "llvm/IR/IRBuilder.h"
|
31 | 32 | #include "llvm/IR/LLVMContext.h"
|
32 | 33 | #include "llvm/IR/Module.h"
|
| 34 | +#include "llvm/Transforms/Utils/BasicBlockUtils.h" |
33 | 35 | #include "llvm/Transforms/Utils/Cloning.h"
|
34 | 36 |
|
35 | 37 | using namespace mlir;
|
@@ -304,7 +306,160 @@ ModuleTranslation::ModuleTranslation(Operation *module,
|
304 | 306 | assert(satisfiesLLVMModule(mlirModule) &&
|
305 | 307 | "mlirModule should honor LLVM's module semantics.");
|
306 | 308 | }
|
307 |
| -ModuleTranslation::~ModuleTranslation() {} |
| 309 | +ModuleTranslation::~ModuleTranslation() { |
| 310 | + if (ompBuilder) |
| 311 | + ompBuilder->finalize(); |
| 312 | +} |
| 313 | + |
| 314 | +/// Get the SSA value passed to the current block from the terminator operation |
| 315 | +/// of its predecessor. |
| 316 | +static Value getPHISourceValue(Block *current, Block *pred, |
| 317 | + unsigned numArguments, unsigned index) { |
| 318 | + Operation &terminator = *pred->getTerminator(); |
| 319 | + if (isa<LLVM::BrOp>(terminator)) |
| 320 | + return terminator.getOperand(index); |
| 321 | + |
| 322 | + // For conditional branches, we need to check if the current block is reached |
| 323 | + // through the "true" or the "false" branch and take the relevant operands. |
| 324 | + auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator); |
| 325 | + assert(condBranchOp && |
| 326 | + "only branch operations can be terminators of a block that " |
| 327 | + "has successors"); |
| 328 | + assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) && |
| 329 | + "successors with arguments in LLVM conditional branches must be " |
| 330 | + "different blocks"); |
| 331 | + |
| 332 | + return condBranchOp.getSuccessor(0) == current |
| 333 | + ? condBranchOp.trueDestOperands()[index] |
| 334 | + : condBranchOp.falseDestOperands()[index]; |
| 335 | +} |
| 336 | + |
| 337 | +/// Connect the PHI nodes to the results of preceding blocks. |
| 338 | +template <typename T> |
| 339 | +static void |
| 340 | +connectPHINodes(T &func, const DenseMap<Value, llvm::Value *> &valueMapping, |
| 341 | + const DenseMap<Block *, llvm::BasicBlock *> &blockMapping) { |
| 342 | + // Skip the first block, it cannot be branched to and its arguments correspond |
| 343 | + // to the arguments of the LLVM function. |
| 344 | + for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) { |
| 345 | + Block *bb = &*it; |
| 346 | + llvm::BasicBlock *llvmBB = blockMapping.lookup(bb); |
| 347 | + auto phis = llvmBB->phis(); |
| 348 | + auto numArguments = bb->getNumArguments(); |
| 349 | + assert(numArguments == std::distance(phis.begin(), phis.end())); |
| 350 | + for (auto &numberedPhiNode : llvm::enumerate(phis)) { |
| 351 | + auto &phiNode = numberedPhiNode.value(); |
| 352 | + unsigned index = numberedPhiNode.index(); |
| 353 | + for (auto *pred : bb->getPredecessors()) { |
| 354 | + phiNode.addIncoming(valueMapping.lookup(getPHISourceValue( |
| 355 | + bb, pred, numArguments, index)), |
| 356 | + blockMapping.lookup(pred)); |
| 357 | + } |
| 358 | + } |
| 359 | + } |
| 360 | +} |
| 361 | + |
| 362 | +// TODO: implement an iterative version |
| 363 | +static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) { |
| 364 | + blocks.insert(b); |
| 365 | + for (Block *bb : b->getSuccessors()) { |
| 366 | + if (blocks.count(bb) == 0) |
| 367 | + topologicalSortImpl(blocks, bb); |
| 368 | + } |
| 369 | +} |
| 370 | + |
| 371 | +/// Sort function blocks topologically. |
| 372 | +template <typename T> |
| 373 | +static llvm::SetVector<Block *> topologicalSort(T &f) { |
| 374 | + // For each blocks that has not been visited yet (i.e. that has no |
| 375 | + // predecessors), add it to the list and traverse its successors in DFS |
| 376 | + // preorder. |
| 377 | + llvm::SetVector<Block *> blocks; |
| 378 | + for (Block &b : f) { |
| 379 | + if (blocks.count(&b) == 0) |
| 380 | + topologicalSortImpl(blocks, &b); |
| 381 | + } |
| 382 | + assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted"); |
| 383 | + |
| 384 | + return blocks; |
| 385 | +} |
| 386 | + |
| 387 | +/// Convert the OpenMP parallel Operation to LLVM IR. |
| 388 | +LogicalResult |
| 389 | +ModuleTranslation::convertOmpParallel(Operation &opInst, |
| 390 | + llvm::IRBuilder<> &builder) { |
| 391 | + using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; |
| 392 | + |
| 393 | + auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, |
| 394 | + llvm::BasicBlock &continuationIP) { |
| 395 | + llvm::LLVMContext &llvmContext = llvmModule->getContext(); |
| 396 | + |
| 397 | + llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock(); |
| 398 | + llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator(); |
| 399 | + |
| 400 | + builder.SetInsertPoint(codeGenIPBB); |
| 401 | + |
| 402 | + for (auto ®ion : opInst.getRegions()) { |
| 403 | + for (auto &bb : region) { |
| 404 | + auto *llvmBB = llvm::BasicBlock::Create( |
| 405 | + llvmContext, "omp.par.region", codeGenIP.getBlock()->getParent()); |
| 406 | + blockMapping[&bb] = llvmBB; |
| 407 | + } |
| 408 | + |
| 409 | + // Then, convert blocks one by one in topological order to ensure |
| 410 | + // defs are converted before uses. |
| 411 | + llvm::SetVector<Block *> blocks = topologicalSort(region); |
| 412 | + for (auto indexedBB : llvm::enumerate(blocks)) { |
| 413 | + Block *bb = indexedBB.value(); |
| 414 | + llvm::BasicBlock *curLLVMBB = blockMapping[bb]; |
| 415 | + if (bb->isEntryBlock()) |
| 416 | + codeGenIPBBTI->setSuccessor(0, curLLVMBB); |
| 417 | + |
| 418 | + // TODO: Error not returned up the hierarchy |
| 419 | + if (failed( |
| 420 | + convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0))) |
| 421 | + return; |
| 422 | + |
| 423 | + // If this block has the terminator then add a jump to |
| 424 | + // continuation bb |
| 425 | + for (auto &op : *bb) { |
| 426 | + if (isa<omp::TerminatorOp>(op)) { |
| 427 | + builder.SetInsertPoint(curLLVMBB); |
| 428 | + builder.CreateBr(&continuationIP); |
| 429 | + } |
| 430 | + } |
| 431 | + } |
| 432 | + // Finally, after all blocks have been traversed and values mapped, |
| 433 | + // connect the PHI nodes to the results of preceding blocks. |
| 434 | + connectPHINodes(region, valueMapping, blockMapping); |
| 435 | + } |
| 436 | + }; |
| 437 | + |
| 438 | + // TODO: Perform appropriate actions according to the data-sharing |
| 439 | + // attribute (shared, private, firstprivate, ...) of variables. |
| 440 | + // Currently defaults to shared. |
| 441 | + auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP, |
| 442 | + llvm::Value &vPtr, |
| 443 | + llvm::Value *&replacementValue) -> InsertPointTy { |
| 444 | + replacementValue = &vPtr; |
| 445 | + |
| 446 | + return codeGenIP; |
| 447 | + }; |
| 448 | + |
| 449 | + // TODO: Perform finalization actions for variables. This has to be |
| 450 | + // called for variables which have destructors/finalizers. |
| 451 | + auto finiCB = [&](InsertPointTy codeGenIP) {}; |
| 452 | + |
| 453 | + // TODO: The various operands of parallel operation are not handled. |
| 454 | + // Parallel operation is created with some default options for now. |
| 455 | + llvm::Value *ifCond = nullptr; |
| 456 | + llvm::Value *numThreads = nullptr; |
| 457 | + bool isCancellable = false; |
| 458 | + builder.restoreIP(ompBuilder->CreateParallel( |
| 459 | + builder, bodyGenCB, privCB, finiCB, ifCond, numThreads, |
| 460 | + llvm::omp::OMP_PROC_BIND_default, isCancellable)); |
| 461 | + return success(); |
| 462 | +} |
308 | 463 |
|
309 | 464 | /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
|
310 | 465 | /// (including OpenMP runtime calls).
|
@@ -340,6 +495,9 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
|
340 | 495 | ompBuilder->CreateFlush(builder.saveIP());
|
341 | 496 | return success();
|
342 | 497 | })
|
| 498 | + .Case([&](omp::TerminatorOp) { return success(); }) |
| 499 | + .Case( |
| 500 | + [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); }) |
343 | 501 | .Default([&](Operation *inst) {
|
344 | 502 | return inst->emitError("unsupported OpenMP operation: ")
|
345 | 503 | << inst->getName();
|
@@ -556,75 +714,6 @@ LogicalResult ModuleTranslation::convertGlobals() {
|
556 | 714 | return success();
|
557 | 715 | }
|
558 | 716 |
|
559 |
| -/// Get the SSA value passed to the current block from the terminator operation |
560 |
| -/// of its predecessor. |
561 |
| -static Value getPHISourceValue(Block *current, Block *pred, |
562 |
| - unsigned numArguments, unsigned index) { |
563 |
| - auto &terminator = *pred->getTerminator(); |
564 |
| - if (isa<LLVM::BrOp>(terminator)) { |
565 |
| - return terminator.getOperand(index); |
566 |
| - } |
567 |
| - |
568 |
| - // For conditional branches, we need to check if the current block is reached |
569 |
| - // through the "true" or the "false" branch and take the relevant operands. |
570 |
| - auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator); |
571 |
| - assert(condBranchOp && |
572 |
| - "only branch operations can be terminators of a block that " |
573 |
| - "has successors"); |
574 |
| - assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) && |
575 |
| - "successors with arguments in LLVM conditional branches must be " |
576 |
| - "different blocks"); |
577 |
| - |
578 |
| - return condBranchOp.getSuccessor(0) == current |
579 |
| - ? condBranchOp.trueDestOperands()[index] |
580 |
| - : condBranchOp.falseDestOperands()[index]; |
581 |
| -} |
582 |
| - |
583 |
| -void ModuleTranslation::connectPHINodes(LLVMFuncOp func) { |
584 |
| - // Skip the first block, it cannot be branched to and its arguments correspond |
585 |
| - // to the arguments of the LLVM function. |
586 |
| - for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) { |
587 |
| - Block *bb = &*it; |
588 |
| - llvm::BasicBlock *llvmBB = blockMapping.lookup(bb); |
589 |
| - auto phis = llvmBB->phis(); |
590 |
| - auto numArguments = bb->getNumArguments(); |
591 |
| - assert(numArguments == std::distance(phis.begin(), phis.end())); |
592 |
| - for (auto &numberedPhiNode : llvm::enumerate(phis)) { |
593 |
| - auto &phiNode = numberedPhiNode.value(); |
594 |
| - unsigned index = numberedPhiNode.index(); |
595 |
| - for (auto *pred : bb->getPredecessors()) { |
596 |
| - phiNode.addIncoming(valueMapping.lookup(getPHISourceValue( |
597 |
| - bb, pred, numArguments, index)), |
598 |
| - blockMapping.lookup(pred)); |
599 |
| - } |
600 |
| - } |
601 |
| - } |
602 |
| -} |
603 |
| - |
604 |
| -// TODO: implement an iterative version |
605 |
| -static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) { |
606 |
| - blocks.insert(b); |
607 |
| - for (Block *bb : b->getSuccessors()) { |
608 |
| - if (blocks.count(bb) == 0) |
609 |
| - topologicalSortImpl(blocks, bb); |
610 |
| - } |
611 |
| -} |
612 |
| - |
613 |
| -/// Sort function blocks topologically. |
614 |
| -static llvm::SetVector<Block *> topologicalSort(LLVMFuncOp f) { |
615 |
| - // For each blocks that has not been visited yet (i.e. that has no |
616 |
| - // predecessors), add it to the list and traverse its successors in DFS |
617 |
| - // preorder. |
618 |
| - llvm::SetVector<Block *> blocks; |
619 |
| - for (Block &b : f) { |
620 |
| - if (blocks.count(&b) == 0) |
621 |
| - topologicalSortImpl(blocks, &b); |
622 |
| - } |
623 |
| - assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted"); |
624 |
| - |
625 |
| - return blocks; |
626 |
| -} |
627 |
| - |
628 | 717 | /// Attempts to add an attribute identified by `key`, optionally with the given
|
629 | 718 | /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
|
630 | 719 | /// attribute has a kind known to LLVM IR, create the attribute of this kind,
|
@@ -772,7 +861,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
|
772 | 861 |
|
773 | 862 | // Finally, after all blocks have been traversed and values mapped, connect
|
774 | 863 | // the PHI nodes to the results of preceding blocks.
|
775 |
| - connectPHINodes(func); |
| 864 | + connectPHINodes(func, valueMapping, blockMapping); |
776 | 865 | return success();
|
777 | 866 | }
|
778 | 867 |
|
|
0 commit comments