Skip to content

Commit e1e20bb

Browse files
author
git apple-llvm automerger
committed
Merge commit '9fc58cc39077' from apple/master into swift/master-next
2 parents 71502de + 9fc58cc commit e1e20bb

File tree

4 files changed

+210
-75
lines changed

4 files changed

+210
-75
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ def OpenMP_Dialect : Dialect {
2424
class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
2525
Op<OpenMP_Dialect, mnemonic, traits>;
2626

27-
2827
//===----------------------------------------------------------------------===//
2928
// 2.6 parallel Construct
3029
//===----------------------------------------------------------------------===//
@@ -81,8 +80,8 @@ def ParallelOp : OpenMP_Op<"parallel", [AttrSizedOperandSegments]> {
8180
of the parallel region.
8281
}];
8382

84-
let arguments = (ins Optional<I1>:$if_expr_var,
85-
Optional<AnyInteger>:$num_threads_var,
83+
let arguments = (ins Optional<AnyType>:$if_expr_var,
84+
Optional<AnyType>:$num_threads_var,
8685
OptionalAttr<ClauseDefault>:$default_val,
8786
Variadic<AnyType>:$private_vars,
8887
Variadic<AnyType>:$firstprivate_vars,

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ class ModuleTranslation {
8787
llvm::IRBuilder<> &builder);
8888
virtual LogicalResult convertOmpOperation(Operation &op,
8989
llvm::IRBuilder<> &builder);
90+
virtual LogicalResult convertOmpParallel(Operation &op,
91+
llvm::IRBuilder<> &builder);
9092
static std::unique_ptr<llvm::Module> prepareLLVMModule(Operation *m);
9193

9294
/// A helper to look up remapped operands in the value remapping table.
@@ -100,7 +102,6 @@ class ModuleTranslation {
100102
LogicalResult convertFunctions();
101103
LogicalResult convertGlobals();
102104
LogicalResult convertOneFunction(LLVMFuncOp func);
103-
void connectPHINodes(LLVMFuncOp func);
104105
LogicalResult convertBlock(Block &bb, bool ignoreArguments);
105106

106107
llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 160 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,13 @@
2525
#include "llvm/ADT/SetVector.h"
2626
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
2727
#include "llvm/IR/BasicBlock.h"
28+
#include "llvm/IR/CFG.h"
2829
#include "llvm/IR/Constants.h"
2930
#include "llvm/IR/DerivedTypes.h"
3031
#include "llvm/IR/IRBuilder.h"
3132
#include "llvm/IR/LLVMContext.h"
3233
#include "llvm/IR/Module.h"
34+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
3335
#include "llvm/Transforms/Utils/Cloning.h"
3436

3537
using namespace mlir;
@@ -304,7 +306,160 @@ ModuleTranslation::ModuleTranslation(Operation *module,
304306
assert(satisfiesLLVMModule(mlirModule) &&
305307
"mlirModule should honor LLVM's module semantics.");
306308
}
307-
ModuleTranslation::~ModuleTranslation() {}
309+
ModuleTranslation::~ModuleTranslation() {
310+
if (ompBuilder)
311+
ompBuilder->finalize();
312+
}
313+
314+
/// Get the SSA value passed to the current block from the terminator operation
315+
/// of its predecessor.
316+
static Value getPHISourceValue(Block *current, Block *pred,
317+
unsigned numArguments, unsigned index) {
318+
Operation &terminator = *pred->getTerminator();
319+
if (isa<LLVM::BrOp>(terminator))
320+
return terminator.getOperand(index);
321+
322+
// For conditional branches, we need to check if the current block is reached
323+
// through the "true" or the "false" branch and take the relevant operands.
324+
auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
325+
assert(condBranchOp &&
326+
"only branch operations can be terminators of a block that "
327+
"has successors");
328+
assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
329+
"successors with arguments in LLVM conditional branches must be "
330+
"different blocks");
331+
332+
return condBranchOp.getSuccessor(0) == current
333+
? condBranchOp.trueDestOperands()[index]
334+
: condBranchOp.falseDestOperands()[index];
335+
}
336+
337+
/// Connect the PHI nodes to the results of preceding blocks.
338+
template <typename T>
339+
static void
340+
connectPHINodes(T &func, const DenseMap<Value, llvm::Value *> &valueMapping,
341+
const DenseMap<Block *, llvm::BasicBlock *> &blockMapping) {
342+
// Skip the first block, it cannot be branched to and its arguments correspond
343+
// to the arguments of the LLVM function.
344+
for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
345+
Block *bb = &*it;
346+
llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
347+
auto phis = llvmBB->phis();
348+
auto numArguments = bb->getNumArguments();
349+
assert(numArguments == std::distance(phis.begin(), phis.end()));
350+
for (auto &numberedPhiNode : llvm::enumerate(phis)) {
351+
auto &phiNode = numberedPhiNode.value();
352+
unsigned index = numberedPhiNode.index();
353+
for (auto *pred : bb->getPredecessors()) {
354+
phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
355+
bb, pred, numArguments, index)),
356+
blockMapping.lookup(pred));
357+
}
358+
}
359+
}
360+
}
361+
362+
// TODO: implement an iterative version
363+
static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
364+
blocks.insert(b);
365+
for (Block *bb : b->getSuccessors()) {
366+
if (blocks.count(bb) == 0)
367+
topologicalSortImpl(blocks, bb);
368+
}
369+
}
370+
371+
/// Sort function blocks topologically.
372+
template <typename T>
373+
static llvm::SetVector<Block *> topologicalSort(T &f) {
374+
// For each blocks that has not been visited yet (i.e. that has no
375+
// predecessors), add it to the list and traverse its successors in DFS
376+
// preorder.
377+
llvm::SetVector<Block *> blocks;
378+
for (Block &b : f) {
379+
if (blocks.count(&b) == 0)
380+
topologicalSortImpl(blocks, &b);
381+
}
382+
assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
383+
384+
return blocks;
385+
}
386+
387+
/// Convert the OpenMP parallel Operation to LLVM IR.
388+
LogicalResult
389+
ModuleTranslation::convertOmpParallel(Operation &opInst,
390+
llvm::IRBuilder<> &builder) {
391+
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
392+
393+
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
394+
llvm::BasicBlock &continuationIP) {
395+
llvm::LLVMContext &llvmContext = llvmModule->getContext();
396+
397+
llvm::BasicBlock *codeGenIPBB = codeGenIP.getBlock();
398+
llvm::Instruction *codeGenIPBBTI = codeGenIPBB->getTerminator();
399+
400+
builder.SetInsertPoint(codeGenIPBB);
401+
402+
for (auto &region : opInst.getRegions()) {
403+
for (auto &bb : region) {
404+
auto *llvmBB = llvm::BasicBlock::Create(
405+
llvmContext, "omp.par.region", codeGenIP.getBlock()->getParent());
406+
blockMapping[&bb] = llvmBB;
407+
}
408+
409+
// Then, convert blocks one by one in topological order to ensure
410+
// defs are converted before uses.
411+
llvm::SetVector<Block *> blocks = topologicalSort(region);
412+
for (auto indexedBB : llvm::enumerate(blocks)) {
413+
Block *bb = indexedBB.value();
414+
llvm::BasicBlock *curLLVMBB = blockMapping[bb];
415+
if (bb->isEntryBlock())
416+
codeGenIPBBTI->setSuccessor(0, curLLVMBB);
417+
418+
// TODO: Error not returned up the hierarchy
419+
if (failed(
420+
convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0)))
421+
return;
422+
423+
// If this block has the terminator then add a jump to
424+
// continuation bb
425+
for (auto &op : *bb) {
426+
if (isa<omp::TerminatorOp>(op)) {
427+
builder.SetInsertPoint(curLLVMBB);
428+
builder.CreateBr(&continuationIP);
429+
}
430+
}
431+
}
432+
// Finally, after all blocks have been traversed and values mapped,
433+
// connect the PHI nodes to the results of preceding blocks.
434+
connectPHINodes(region, valueMapping, blockMapping);
435+
}
436+
};
437+
438+
// TODO: Perform appropriate actions according to the data-sharing
439+
// attribute (shared, private, firstprivate, ...) of variables.
440+
// Currently defaults to shared.
441+
auto privCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP,
442+
llvm::Value &vPtr,
443+
llvm::Value *&replacementValue) -> InsertPointTy {
444+
replacementValue = &vPtr;
445+
446+
return codeGenIP;
447+
};
448+
449+
// TODO: Perform finalization actions for variables. This has to be
450+
// called for variables which have destructors/finalizers.
451+
auto finiCB = [&](InsertPointTy codeGenIP) {};
452+
453+
// TODO: The various operands of parallel operation are not handled.
454+
// Parallel operation is created with some default options for now.
455+
llvm::Value *ifCond = nullptr;
456+
llvm::Value *numThreads = nullptr;
457+
bool isCancellable = false;
458+
builder.restoreIP(ompBuilder->CreateParallel(
459+
builder, bodyGenCB, privCB, finiCB, ifCond, numThreads,
460+
llvm::omp::OMP_PROC_BIND_default, isCancellable));
461+
return success();
462+
}
308463

309464
/// Given an OpenMP MLIR operation, create the corresponding LLVM IR
310465
/// (including OpenMP runtime calls).
@@ -340,6 +495,9 @@ ModuleTranslation::convertOmpOperation(Operation &opInst,
340495
ompBuilder->CreateFlush(builder.saveIP());
341496
return success();
342497
})
498+
.Case([&](omp::TerminatorOp) { return success(); })
499+
.Case(
500+
[&](omp::ParallelOp) { return convertOmpParallel(opInst, builder); })
343501
.Default([&](Operation *inst) {
344502
return inst->emitError("unsupported OpenMP operation: ")
345503
<< inst->getName();
@@ -556,75 +714,6 @@ LogicalResult ModuleTranslation::convertGlobals() {
556714
return success();
557715
}
558716

559-
/// Get the SSA value passed to the current block from the terminator operation
560-
/// of its predecessor.
561-
static Value getPHISourceValue(Block *current, Block *pred,
562-
unsigned numArguments, unsigned index) {
563-
auto &terminator = *pred->getTerminator();
564-
if (isa<LLVM::BrOp>(terminator)) {
565-
return terminator.getOperand(index);
566-
}
567-
568-
// For conditional branches, we need to check if the current block is reached
569-
// through the "true" or the "false" branch and take the relevant operands.
570-
auto condBranchOp = dyn_cast<LLVM::CondBrOp>(terminator);
571-
assert(condBranchOp &&
572-
"only branch operations can be terminators of a block that "
573-
"has successors");
574-
assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
575-
"successors with arguments in LLVM conditional branches must be "
576-
"different blocks");
577-
578-
return condBranchOp.getSuccessor(0) == current
579-
? condBranchOp.trueDestOperands()[index]
580-
: condBranchOp.falseDestOperands()[index];
581-
}
582-
583-
void ModuleTranslation::connectPHINodes(LLVMFuncOp func) {
584-
// Skip the first block, it cannot be branched to and its arguments correspond
585-
// to the arguments of the LLVM function.
586-
for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
587-
Block *bb = &*it;
588-
llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
589-
auto phis = llvmBB->phis();
590-
auto numArguments = bb->getNumArguments();
591-
assert(numArguments == std::distance(phis.begin(), phis.end()));
592-
for (auto &numberedPhiNode : llvm::enumerate(phis)) {
593-
auto &phiNode = numberedPhiNode.value();
594-
unsigned index = numberedPhiNode.index();
595-
for (auto *pred : bb->getPredecessors()) {
596-
phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
597-
bb, pred, numArguments, index)),
598-
blockMapping.lookup(pred));
599-
}
600-
}
601-
}
602-
}
603-
604-
// TODO: implement an iterative version
605-
static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
606-
blocks.insert(b);
607-
for (Block *bb : b->getSuccessors()) {
608-
if (blocks.count(bb) == 0)
609-
topologicalSortImpl(blocks, bb);
610-
}
611-
}
612-
613-
/// Sort function blocks topologically.
614-
static llvm::SetVector<Block *> topologicalSort(LLVMFuncOp f) {
615-
// For each blocks that has not been visited yet (i.e. that has no
616-
// predecessors), add it to the list and traverse its successors in DFS
617-
// preorder.
618-
llvm::SetVector<Block *> blocks;
619-
for (Block &b : f) {
620-
if (blocks.count(&b) == 0)
621-
topologicalSortImpl(blocks, &b);
622-
}
623-
assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
624-
625-
return blocks;
626-
}
627-
628717
/// Attempts to add an attribute identified by `key`, optionally with the given
629718
/// `value` to LLVM function `llvmFunc`. Reports errors at `loc` if any. If the
630719
/// attribute has a kind known to LLVM IR, create the attribute of this kind,
@@ -772,7 +861,7 @@ LogicalResult ModuleTranslation::convertOneFunction(LLVMFuncOp func) {
772861

773862
// Finally, after all blocks have been traversed and values mapped, connect
774863
// the PHI nodes to the results of preceding blocks.
775-
connectPHINodes(func);
864+
connectPHINodes(func, valueMapping, blockMapping);
776865
return success();
777866
}
778867

mlir/test/Target/openmp-llvm.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,49 @@ llvm.func @test_flush_construct(%arg0: !llvm.i32) {
3232
// CHECK-NEXT: ret void
3333
llvm.return
3434
}
35+
36+
// CHECK-LABEL: define void @test_omp_parallel_1()
37+
llvm.func @test_omp_parallel_1() -> () {
38+
// CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_1:.*]] to {{.*}}
39+
omp.parallel {
40+
omp.barrier
41+
omp.terminator
42+
}
43+
44+
llvm.return
45+
}
46+
47+
// CHECK: define internal void @[[OMP_OUTLINED_FN_1]]
48+
// CHECK: call void @__kmpc_barrier
49+
50+
llvm.func @body(!llvm.i64)
51+
52+
// CHECK-LABEL: define void @test_omp_parallel_2()
53+
llvm.func @test_omp_parallel_2() -> () {
54+
// CHECK: call void{{.*}}@__kmpc_fork_call{{.*}}@[[OMP_OUTLINED_FN_2:.*]] to {{.*}}
55+
omp.parallel {
56+
^bb0:
57+
%0 = llvm.mlir.constant(1 : index) : !llvm.i64
58+
%1 = llvm.mlir.constant(42 : index) : !llvm.i64
59+
llvm.call @body(%0) : (!llvm.i64) -> ()
60+
llvm.call @body(%1) : (!llvm.i64) -> ()
61+
llvm.br ^bb1
62+
63+
^bb1:
64+
%2 = llvm.add %0, %1 : !llvm.i64
65+
llvm.call @body(%2) : (!llvm.i64) -> ()
66+
omp.terminator
67+
}
68+
llvm.return
69+
}
70+
71+
// CHECK: define internal void @[[OMP_OUTLINED_FN_2]]
72+
// CHECK-LABEL: omp.par.region:
73+
// CHECK: br label %omp.par.region1
74+
// CHECK-LABEL: omp.par.region1:
75+
// CHECK: call void @body(i64 1)
76+
// CHECK: call void @body(i64 42)
77+
// CHECK: br label %omp.par.region2
78+
// CHECK-LABEL: omp.par.region2:
79+
// CHECK: call void @body(i64 43)
80+
// CHECK: br label %omp.par.pre_finalize

0 commit comments

Comments
 (0)