
Commit d9067dc

Lowering of OpenMP Parallel operation to LLVM IR 1/n
This patch introduces lowering of the OpenMP parallel operation to LLVM IR using the OpenMPIRBuilder. The topologicalSort and connectPhiNodes helpers are generalised so that they also work with operations other than LLVM functions, and connectPhiNodes is made a static free function. Lowering works for a parallel region with multiple blocks. Clauses and arguments of the OpenMP parallel operation are not handled yet.

Reviewed By: rriddle, anchu-rajendran

Differential Revision: https://reviews.llvm.org/D81660
1 parent 004bf35 commit d9067dc
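
For orientation, here is a minimal sketch of the translation this patch enables, condensed from the test added in this commit; @simple_parallel is an illustrative function name, and the outlined function's actual name is generated by the OpenMPIRBuilder:

// Input (MLIR, LLVM + OpenMP dialects): a parallel region containing a barrier.
llvm.func @simple_parallel() {
  omp.parallel {
    omp.barrier
    omp.terminator
  }
  llvm.return
}

// Expected LLVM IR shape after translation, per the FileCheck lines in the new
// test: the region body is outlined into an internal function invoked through
// the OpenMP runtime entry point __kmpc_fork_call, and omp.barrier becomes a
// call to __kmpc_barrier inside that outlined function.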

File tree: 4 files changed, +210 −75 lines


mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 2 additions & 3 deletions
@@ -24,7 +24,6 @@ def OpenMP_Dialect : Dialect {
 class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
         Op<OpenMP_Dialect, mnemonic, traits>;

-
 //===----------------------------------------------------------------------===//
 // 2.6 parallel Construct
 //===----------------------------------------------------------------------===//
@@ -81,8 +80,8 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
     of the parallel region.
   }];

-  let arguments = (ins Optional<I1>:$if_expr_var,
-                       Optional<AnyInteger>:$num_threads_var,
+  let arguments = (ins Optional<AnyType>:$if_expr_var,
+                       Optional<AnyType>:$num_threads_var,
                        OptionalAttr<ClauseDefault>:$default_val,
                        Variadic<AnyType>:$private_vars,
                        Variadic<AnyType>:$firstprivate_vars,

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 2 additions & 1 deletion
@@ -87,6 +87,8 @@ class ModuleTranslation {
                                             llvm::IRBuilder<> &builder);
   virtual LogicalResult convertOmpOperation(Operation &op,
                                             llvm::IRBuilder<> &builder);
+  virtual LogicalResult convertOmpParallel(Operation &op,
+                                           llvm::IRBuilder<> &builder);
   static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);

   /// A helper to look up remapped operands in the value remapping table.
@@ -100,7 +102,6 @@ class ModuleTranslation {
   LogicalResult convertFunctions();
   LogicalResult convertGlobals();
   LogicalResult convertOneFunction(LLVMFuncOp func);
-  void connectPHINodes(LLVMFuncOp func);
   LogicalResult convertBlock(Block &bb, bool ignoreArguments);

   llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 160 additions & 71 deletions
@@ -25,11 +25,13 @@
 #include "llvm/ADT/SetVector.h"
 #include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
 #include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
 #include "llvm/Transforms/Utils/Cloning.h"

 using namespace mlir;
@@ -304,7 +306,160 @@ ModuleTranslation::ModuleTranslation(Operation *module,
   assert(satisfiesLLVMModule(mlirModule) &&
          "mlirModule should honor LLVM's module semantics.");
 }
-ModuleTranslation::~ModuleTranslation() {}
+ModuleTranslation::~ModuleTranslation() {
+  if (ompBuilder)
+    ompBuilder->finalize();
+}
+
+/// Get the SSA value passed to the current block from the terminator operation
+/// of its predecessor.
+static Value getPHISourceValue(Block *current, Block *pred,
+                               unsigned numArguments, unsigned index) {
+  Operation &terminator = *pred->getTerminator();
+  if (isa<LLVM::BrOp>(terminator))
+    return terminator.getOperand(index);
+
+  // For conditional branches, we need to check if the current block is reached
+  // through the "true" or the "false" branch and take the relevant operands.
+  auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
+  assert(condBranchOp &&
+         "only branch operations can be terminators of a block that "
+         "has successors");
+  assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
+         "successors with arguments in LLVM conditional branches must be "
+         "different blocks");
+
+  return condBranchOp.getSuccessor(0) == current
+             ? condBranchOp.trueDestOperands()[index]
+             : condBranchOp.falseDestOperands()[index];
+}
+
+/// Connect the PHI nodes to the results of preceding blocks.
+template <typename T>
+static void
+connectPHINodes(T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
+                const DenseMap<Block *, llvm::BasicBlock *> &blockMapping) {
+  // Skip the first block, it cannot be branched to and its arguments correspond
+  // to the arguments of the LLVM function.
+  for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
+    Block *bb = &*it;
+    llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
+    auto phis = llvmBB->phis();
+    auto numArguments = bb->getNumArguments();
+    assert(numArguments == std::distance(phis.begin(), phis.end()));
+    for (auto &numberedPhiNode : llvm::enumerate(phis)) {
+      auto &phiNode = numberedPhiNode.value();
+      unsigned index = numberedPhiNode.index();
+      for (auto *pred : bb->getPredecessors()) {
+        phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
+                                bb, pred, numArguments, index)),
+                            blockMapping.lookup(pred));
+      }
+    }
+  }
+}
+
+// TODO: implement an iterative version
+static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
+  blocks.insert(b);
+  for (Block *bb : b->getSuccessors()) {
+    if (blocks.count(bb) == 0)
+      topologicalSortImpl(blocks, bb);
+  }
+}
+
+/// Sort function blocks topologically.
+template <typename T>
+static llvm::SetVector<Block *> topologicalSort(T &f) {
+  // For each blocks that has not been visited yet (i.e. that has no
+  // predecessors), add it to the list and traverse its successors in DFS
+  // preorder.
+  llvm::SetVector<Block *> blocks;
+  for (Block &b : f) {
+    if (blocks.count(&b) == 0)
+      topologicalSortImpl(blocks, &b);
+  }
+  assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
+
+  return blocks;
+}
+
+/// Convert the OpenMP parallel Operation to LLVM IR.
+LogicalResult
+ModuleTranslation::convertOmpParallel(Operation &opInst,
+                                      llvm::IRBuilder<> &builder) {
+  using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
+
+  auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
+                       llvm::BasicBlock &continuationIP) {
+    llvm::LLVMContext &llvmContext = llvmModule->getContext();
+
+    llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock();
+    llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator();
+
+    builder.SetInsertPoint(codeGenIPBB);
+
+    for (auto &region : opInst.getRegions()) {
+      for (auto &bb : region) {
+        auto *llvmBB = llvm::BasicBlock::Create(
+            llvmContext, "omp.par.region", codeGenIP.getBlock()->getParent());
+        blockMapping[&bb] = llvmBB;
+      }
+
+      // Then, convert blocks one by one in topological order to ensure
+      // defs are converted before uses.
+      llvm::SetVector<Block *> blocks = topologicalSort(region);
+      for (auto indexedBB : llvm::enumerate(blocks)) {
+        Block *bb = indexedBB.value();
+        llvm::BasicBlock *curLLVMBB = blockMapping[bb];
+        if (bb->isEntryBlock())
+          codeGenIPBBTI->setSuccessor(0, curLLVMBB);
+
+        // TODO: Error not returned up the hierarchy
+        if (failed(
+                convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
+          return;
+
+        // If this block has the terminator then add a jump to
+        // continuation bb
+        for (auto &op : *bb) {
+          if (isa<omp::TerminatorOp>(op)) {
+            builder.SetInsertPoint(curLLVMBB);
+            builder.CreateBr(&continuationIP);
+          }
+        }
+      }
+      // Finally, after all blocks have been traversed and values mapped,
+      // connect the PHI nodes to the results of preceding blocks.
+      connectPHINodes(region, valueMapping, blockMapping);
+    }
+  };
+
+  // TODO: Perform appropriate actions according to the data-sharing
+  // attribute (shared, private, firstprivate, ...) of variables.
+  // Currently defaults to shared.
+  auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
+                    llvm::Value &vPtr,
+                    llvm::Value *&replacementValue) -> InsertPointTy {
+    replacementValue = &vPtr;
+
+    return codeGenIP;
+  };
+
+  // TODO: Perform finalization actions for variables. This has to be
+  // called for variables which have destructors/finalizers.
+  auto finiCB = [&](InsertPointTy codeGenIP) {};
+
+  // TODO: The various operands of parallel operation are not handled.
+  // Parallel operation is created with some default options for now.
+  llvm::Value *ifCond = nullptr;
+  llvm::Value *numThreads = nullptr;
+  bool isCancellable = false;
+  builder.restoreIP(ompBuilder->CreateParallel(
+      builder, bodyGenCB, privCB, finiCB, ifCond, numThreads,
+      llvm::omp::OMP_PROC_BIND_default, isCancellable));
+  return success();
+}

 /// Given an OpenMP MLIR operation, create the corresponding LLVM IR
 /// (including OpenMP runtime calls).
@@ -340,6 +495,9 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
         ompBuilder->CreateFlush(builder.saveIP());
         return success();
       })
+      .Case([&](omp::TerminatorOp) { return success(); })
+      .Case(
+          [&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
      .Default([&](Operation *inst) {
        return inst->emitError("unsupported OpenMP operation: ")
               << inst->getName();
@@ -556,75 +714,6 @@ LogicalResult ModuleTranslation::convertGlobals() {
   return success();
 }

-/// Get the SSA value passed to the current block from the terminator operation
-/// of its predecessor.
-static Value getPHISourceValue(Block *current, Block *pred,
-                               unsigned numArguments, unsigned index) {
-  auto &terminator = *pred->getTerminator();
-  if (isa<LLVM::BrOp>(terminator)) {
-    return terminator.getOperand(index);
-  }
-
-  // For conditional branches, we need to check if the current block is reached
-  // through the "true" or the "false" branch and take the relevant operands.
-  auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
-  assert(condBranchOp &&
-         "only branch operations can be terminators of a block that "
-         "has successors");
-  assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
-         "successors with arguments in LLVM conditional branches must be "
-         "different blocks");
-
-  return condBranchOp.getSuccessor(0) == current
-             ? condBranchOp.trueDestOperands()[index]
-             : condBranchOp.falseDestOperands()[index];
-}
-
-void ModuleTranslation::connectPHINodes(LLVMFuncOp func) {
-  // Skip the first block, it cannot be branched to and its arguments correspond
-  // to the arguments of the LLVM function.
-  for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
-    Block *bb = &*it;
-    llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
-    auto phis = llvmBB->phis();
-    auto numArguments = bb->getNumArguments();
-    assert(numArguments == std::distance(phis.begin(), phis.end()));
-    for (auto &numberedPhiNode : llvm::enumerate(phis)) {
-      auto &phiNode = numberedPhiNode.value();
-      unsigned index = numberedPhiNode.index();
-      for (auto *pred : bb->getPredecessors()) {
-        phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
-                                bb, pred, numArguments, index)),
-                            blockMapping.lookup(pred));
-      }
-    }
-  }
-}
-
-// TODO: implement an iterative version
-static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
-  blocks.insert(b);
-  for (Block *bb : b->getSuccessors()) {
-    if (blocks.count(bb) == 0)
-      topologicalSortImpl(blocks, bb);
-  }
-}
-
-/// Sort function blocks topologically.
-static llvm::SetVector<Block *> topologicalSort(LLVMFuncOp f) {
-  // For each blocks that has not been visited yet (i.e. that has no
-  // predecessors), add it to the list and traverse its successors in DFS
-  // preorder.
-  llvm::SetVector<Block *> blocks;
-  for (Block &b : f) {
-    if (blocks.count(&b) == 0)
-      topologicalSortImpl(blocks, &b);
-  }
-  assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
-
-  return blocks;
-}
-
 /// Attempts to add an attribute identified by `key`, optionally with the given
 /// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
 /// attribute has a kind known to LLVM IR, create the attribute of this kind,
@@ -772,7 +861,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {

   // Finally, after all blocks have been traversed and values mapped, connect
   // the PHI nodes to the results of preceding blocks.
-  connectPHINodes(func);
+  connectPHINodes(func, valueMapping, blockMapping);
   return success();
 }


mlir/test/Target/openmp-llvm.mlir

Lines changed: 46 additions & 0 deletions
@@ -32,3 +32,49 @@ llvm.func @test_flush_construct(%arg0: !llvm.i32) {
   // CHECK-NEXT: ret void
   llvm.return
 }
+
+// CHECK-LABEL: define void @test_omp_parallel_1()
+llvm.func @test_omp_parallel_1() -> () {
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_1:.*]] to {{.*}}
+  omp.parallel {
+    omp.barrier
+    omp.terminator
+  }
+
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_1]]
+// CHECK: call void @__kmpc_barrier
+
+llvm.func @body(!llvm.i64)
+
+// CHECK-LABEL: define void @test_omp_parallel_2()
+llvm.func @test_omp_parallel_2() -> () {
+  // CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_2:.*]] to {{.*}}
+  omp.parallel {
+  ^bb0:
+    %0 = llvm.mlir.constant(1 : index) : !llvm.i64
+    %1 = llvm.mlir.constant(42 : index) : !llvm.i64
+    llvm.call @body(%0) : (!llvm.i64) -> ()
+    llvm.call @body(%1) : (!llvm.i64) -> ()
+    llvm.br ^bb1
+
+  ^bb1:
+    %2 = llvm.add %0, %1 : !llvm.i64
+    llvm.call @body(%2) : (!llvm.i64) -> ()
+    omp.terminator
+  }
+  llvm.return
+}
+
+// CHECK: define internal void @[[OMP_OUTLINED_FN_2]]
+// CHECK-LABEL: omp.par.region:
+// CHECK: br label %omp.par.region1
+// CHECK-LABEL: omp.par.region1:
+// CHECK: call void @body(i64 1)
+// CHECK: call void @body(i64 42)
+// CHECK: br label %omp.par.region2
+// CHECK-LABEL: omp.par.region2:
+// CHECK: call void @body(i64 43)
+// CHECK: br label %omp.par.pre_finalize
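
The tests above only use blocks without block arguments. As a hedged illustration of what the generalised connectPHINodes is needed for, a parallel region whose blocks pass a value through a block argument might look like the sketch below; it is not part of this commit's tests, @parallel_with_phi is an illustrative name, and @body refers to the declaration in the test above:

llvm.func @parallel_with_phi() {
  omp.parallel {
    %0 = llvm.mlir.constant(1 : index) : !llvm.i64
    %1 = llvm.mlir.constant(42 : index) : !llvm.i64
    %cond = llvm.icmp "slt" %0, %1 : !llvm.i64
    llvm.cond_br %cond, ^bb1, ^bb2
  ^bb1:
    llvm.br ^bb3(%0 : !llvm.i64)
  ^bb2:
    llvm.br ^bb3(%1 : !llvm.i64)
  // The block argument below is what connectPHINodes turns into an LLVM IR phi
  // node over the values carried by the two llvm.br terminators.
  ^bb3(%arg : !llvm.i64):
    llvm.call @body(%arg) : (!llvm.i64) -> ()
    omp.terminator
  }
  llvm.return
}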
