Skip to content

Commit e402009

Browse files
authored
[mlir][OpenMP] cancel(lation point) taskgroup LLVMIR (#137841)
A cancel or cancellation point for taskgroup is always nested inside of a task inside of the taskgroup. For the task which is cancelled, it is that task which needs to be cleaned up: not the owning taskgroup. Therefore the cancellation branch handler is done in the conversion of the task not in conversion of taskgroup. I added a firstprivate clause to the test for cancel taskgroup to demonstrate that the block being branched to is the same block where mandatory cleanup code is added. Cancellation point follows exactly the same code path.
1 parent ce7c196 commit e402009

File tree

5 files changed

+162
-78
lines changed

5 files changed

+162
-78
lines changed

flang/docs/OpenMPSupport.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
5151
| depend clause | P | depend clause with array sections are not supported |
5252
| declare reduction construct | N | |
5353
| atomic construct extensions | Y | |
54-
| cancel construct | N | |
55-
| cancellation point construct | N | |
54+
| cancel construct | Y | |
55+
| cancellation point construct | Y | |
5656
| parallel do simd construct | P | linear clause is not supported |
5757
| target teams construct | P | device and reduction clauses are not supported |
5858
| teams distribute construct | P | reduction and dist_schedule clauses not supported |

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 82 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,18 @@ static LogicalResult checkImplementationStatus(Operation &op) {
161161
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162162
omp::ClauseCancellationConstructType cancelledDirective =
163163
op.getCancelDirective();
164-
if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
165-
result = todo("cancel directive construct type not yet supported");
164+
// Cancelling a taskloop is not yet supported because we don't yet have LLVM
165+
// IR conversion for taskloop
166+
if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
167+
Operation *parent = op->getParentOp();
168+
while (parent) {
169+
if (parent->getDialect() == op->getDialect())
170+
break;
171+
parent = parent->getParentOp();
172+
}
173+
if (isa_and_nonnull<omp::TaskloopOp>(parent))
174+
result = todo("cancel directive inside of taskloop");
175+
}
166176
};
167177
auto checkDepend = [&todo](auto op, LogicalResult &result) {
168178
if (!op.getDependVars().empty() || op.getDependKinds())
@@ -1889,6 +1899,55 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
18891899
}
18901900
}
18911901

1902+
/// Shared implementation of a callback which adds a termiator for the new block
1903+
/// created for the branch taken when an openmp construct is cancelled. The
1904+
/// terminator is saved in \p cancelTerminators. This callback is invoked only
1905+
/// if there is cancellation inside of the taskgroup body.
1906+
/// The terminator will need to be fixed to branch to the correct block to
1907+
/// cleanup the construct.
1908+
static void
1909+
pushCancelFinalizationCB(SmallVectorImpl<llvm::BranchInst *> &cancelTerminators,
1910+
llvm::IRBuilderBase &llvmBuilder,
1911+
llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
1912+
llvm::omp::Directive cancelDirective) {
1913+
auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
1914+
llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
1915+
1916+
// ip is currently in the block branched to if cancellation occured.
1917+
// We need to create a branch to terminate that block.
1918+
llvmBuilder.restoreIP(ip);
1919+
1920+
// We must still clean up the construct after cancelling it, so we need to
1921+
// branch to the block that finalizes the taskgroup.
1922+
// That block has not been created yet so use this block as a dummy for now
1923+
// and fix this after creating the operation.
1924+
cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
1925+
return llvm::Error::success();
1926+
};
1927+
// We have to add the cleanup to the OpenMPIRBuilder before the body gets
1928+
// created in case the body contains omp.cancel (which will then expect to be
1929+
// able to find this cleanup callback).
1930+
ompBuilder.pushFinalizationCB(
1931+
{finiCB, cancelDirective, constructIsCancellable(op)});
1932+
}
1933+
1934+
/// If we cancelled the construct, we should branch to the finalization block of
1935+
/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
1936+
/// is immediately before the continuation block. Now this finalization has
1937+
/// been created we can fix the branch.
1938+
static void
1939+
popCancelFinalizationCB(const ArrayRef<llvm::BranchInst *> cancelTerminators,
1940+
llvm::OpenMPIRBuilder &ompBuilder,
1941+
const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
1942+
ompBuilder.popFinalizationCB();
1943+
llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
1944+
for (llvm::BranchInst *cancelBranch : cancelTerminators) {
1945+
assert(cancelBranch->getNumSuccessors() == 1 &&
1946+
"cancel branch should have one target");
1947+
cancelBranch->setSuccessor(0, constructFini);
1948+
}
1949+
}
1950+
18921951
namespace {
18931952
/// TaskContextStructManager takes care of creating and freeing a structure
18941953
/// containing information needed by the task body to execute.
@@ -2202,6 +2261,14 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
22022261
return llvm::Error::success();
22032262
};
22042263

2264+
llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2265+
SmallVector<llvm::BranchInst *> cancelTerminators;
2266+
// The directive to match here is OMPD_taskgroup because it is the taskgroup
2267+
// which is canceled. This is handled here because it is the task's cleanup
2268+
// block which should be branched to.
2269+
pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
2270+
llvm::omp::Directive::OMPD_taskgroup);
2271+
22052272
SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
22062273
buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
22072274
moduleTranslation, dds);
@@ -2219,6 +2286,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
22192286
if (failed(handleError(afterIP, *taskOp)))
22202287
return failure();
22212288

2289+
// Set the correct branch target for task cancellation
2290+
popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2291+
22222292
builder.restoreIP(*afterIP);
22232293
return success();
22242294
}
@@ -2349,28 +2419,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23492419
: llvm::omp::WorksharingLoopType::ForStaticLoop;
23502420

23512421
SmallVector<llvm::BranchInst *> cancelTerminators;
2352-
// This callback is invoked only if there is cancellation inside of the wsloop
2353-
// body.
2354-
auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2355-
llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder;
2356-
llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2357-
2358-
// ip is currently in the block branched to if cancellation occured.
2359-
// We need to create a branch to terminate that block.
2360-
llvmBuilder.restoreIP(ip);
2361-
2362-
// We must still clean up the wsloop after cancelling it, so we need to
2363-
// branch to the block that finalizes the wsloop.
2364-
// That block has not been created yet so use this block as a dummy for now
2365-
// and fix this after creating the wsloop.
2366-
cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2367-
return llvm::Error::success();
2368-
};
2369-
// We have to add the cleanup to the OpenMPIRBuilder before the body gets
2370-
// created in case the body contains omp.cancel (which will then expect to be
2371-
// able to find this cleanup callback).
2372-
ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for,
2373-
constructIsCancellable(wsloopOp)});
2422+
pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
2423+
llvm::omp::Directive::OMPD_for);
23742424

23752425
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
23762426
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
@@ -2393,18 +2443,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23932443
if (failed(handleError(wsloopIP, opInst)))
23942444
return failure();
23952445

2396-
ompBuilder->popFinalizationCB();
2397-
if (!cancelTerminators.empty()) {
2398-
// If we cancelled the loop, we should branch to the finalization block of
2399-
// the wsloop (which is always immediately before the loop continuation
2400-
// block). Now the finalization has been created, we can fix the branch.
2401-
llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor();
2402-
for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2403-
assert(cancelBranch->getNumSuccessors() == 1 &&
2404-
"cancel branch should have one target");
2405-
cancelBranch->setSuccessor(0, wsloopFini);
2406-
}
2407-
}
2446+
// Set the correct branch target for task cancellation
2447+
popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
24082448

24092449
// Process the reductions if required.
24102450
if (failed(createReductionsAndCleanup(
@@ -3060,12 +3100,12 @@ static llvm::omp::Directive convertCancellationConstructType(
30603100
static LogicalResult
30613101
convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
30623102
LLVM::ModuleTranslation &moduleTranslation) {
3063-
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3064-
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3065-
30663103
if (failed(checkImplementationStatus(*op.getOperation())))
30673104
return failure();
30683105

3106+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3107+
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3108+
30693109
llvm::Value *ifCond = nullptr;
30703110
if (Value ifVar = op.getIfExpr())
30713111
ifCond = moduleTranslation.lookupValue(ifVar);
@@ -3088,12 +3128,12 @@ static LogicalResult
30883128
convertOmpCancellationPoint(omp::CancellationPointOp op,
30893129
llvm::IRBuilderBase &builder,
30903130
LLVM::ModuleTranslation &moduleTranslation) {
3091-
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3092-
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3093-
30943131
if (failed(checkImplementationStatus(*op.getOperation())))
30953132
return failure();
30963133

3134+
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
3135+
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3136+
30973137
llvm::omp::Directive cancelledDirective =
30983138
convertCancellationConstructType(op.getCancelDirective());
30993139

mlir/test/Target/LLVMIR/openmp-cancel.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,51 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
243243
// CHECK: ret void
244244
// CHECK: .cncl: ; preds = %[[VAL_44]]
245245
// CHECK: br label %[[VAL_38]]
246+
247+
omp.private {type = firstprivate} @i32_priv : i32 copy {
248+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
249+
%0 = llvm.load %arg0 : !llvm.ptr -> i32
250+
llvm.store %0, %arg1 : i32, !llvm.ptr
251+
omp.yield(%arg1 : !llvm.ptr)
252+
}
253+
254+
llvm.func @do_something(!llvm.ptr)
255+
256+
llvm.func @cancel_taskgroup(%arg0: !llvm.ptr) {
257+
omp.taskgroup {
258+
// Using firstprivate clause so we have some end of task cleanup to branch to
259+
// after the cancellation.
260+
omp.task private(@i32_priv %arg0 -> %arg1 : !llvm.ptr) {
261+
omp.cancel cancellation_construct_type(taskgroup)
262+
llvm.call @do_something(%arg1) : (!llvm.ptr) -> ()
263+
omp.terminator
264+
}
265+
omp.terminator
266+
}
267+
llvm.return
268+
}
269+
// CHECK-LABEL: define internal void @cancel_taskgroup..omp_par(
270+
// CHECK: task.alloca:
271+
// CHECK: %[[VAL_21:.*]] = load ptr, ptr %[[VAL_22:.*]], align 8
272+
// CHECK: %[[VAL_23:.*]] = getelementptr { ptr }, ptr %[[VAL_21]], i32 0, i32 0
273+
// CHECK: %[[VAL_24:.*]] = load ptr, ptr %[[VAL_23]], align 8, !align !1
274+
// CHECK: br label %[[VAL_25:.*]]
275+
// CHECK: task.body: ; preds = %[[VAL_26:.*]]
276+
// CHECK: %[[VAL_27:.*]] = getelementptr { i32 }, ptr %[[VAL_24]], i32 0, i32 0
277+
// CHECK: br label %[[VAL_28:.*]]
278+
// CHECK: omp.task.region: ; preds = %[[VAL_25]]
279+
// CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
280+
// CHECK: %[[VAL_30:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_29]], i32 4)
281+
// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0
282+
// CHECK: br i1 %[[VAL_31]], label %omp.task.region.split, label %omp.task.region.cncl
283+
// CHECK: omp.task.region.cncl:
284+
// CHECK: br label %omp.region.cont2
285+
// CHECK: omp.region.cont2:
286+
// Both cancellation and normal paths reach the end-of-task cleanup:
287+
// CHECK: tail call void @free(ptr %[[VAL_24]])
288+
// CHECK: br label %task.exit.exitStub
289+
// CHECK: omp.task.region.split:
290+
// CHECK: call void @do_something(ptr %[[VAL_27]])
291+
// CHECK: br label %omp.region.cont2
292+
// CHECK: task.exit.exitStub:
293+
// CHECK: ret void

mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,33 @@ llvm.func @cancellation_point_wsloop(%lb : i32, %ub : i32, %step : i32) {
186186
// CHECK: ret void
187187
// CHECK: omp.loop_nest.region.cncl: ; preds = %[[VAL_100]]
188188
// CHECK: br label %[[VAL_96]]
189+
190+
191+
llvm.func @cancellation_point_taskgroup() {
192+
omp.taskgroup {
193+
omp.task {
194+
omp.cancellation_point cancellation_construct_type(taskgroup)
195+
omp.terminator
196+
}
197+
omp.terminator
198+
}
199+
llvm.return
200+
}
201+
// CHECK-LABEL: define internal void @cancellation_point_taskgroup..omp_par(
202+
// CHECK: task.alloca:
203+
// CHECK: br label %[[VAL_50:.*]]
204+
// CHECK: task.body: ; preds = %[[VAL_51:.*]]
205+
// CHECK: br label %[[VAL_52:.*]]
206+
// CHECK: omp.task.region: ; preds = %[[VAL_50]]
207+
// CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
208+
// CHECK: %[[VAL_54:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[VAL_53]], i32 4)
209+
// CHECK: %[[VAL_55:.*]] = icmp eq i32 %[[VAL_54]], 0
210+
// CHECK: br i1 %[[VAL_55]], label %omp.task.region.split, label %omp.task.region.cncl
211+
// CHECK: omp.task.region.cncl:
212+
// CHECK: br label %omp.region.cont1
213+
// CHECK: omp.region.cont1:
214+
// CHECK: br label %task.exit.exitStub
215+
// CHECK: omp.task.region.split:
216+
// CHECK: br label %omp.region.cont1
217+
// CHECK: task.exit.exitStub:
218+
// CHECK: ret void

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,40 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {
2626

2727
// -----
2828

29-
llvm.func @cancel_taskgroup() {
30-
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
31-
omp.taskgroup {
32-
// expected-error@below {{LLVM Translation failed for operation: omp.task}}
33-
omp.task {
34-
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancel operation}}
35-
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
36-
omp.cancel cancellation_construct_type(taskgroup)
37-
omp.terminator
38-
}
39-
omp.terminator
40-
}
41-
llvm.return
42-
}
43-
44-
// -----
45-
46-
llvm.func @cancellation_point_taskgroup() {
47-
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
48-
omp.taskgroup {
49-
// expected-error@below {{LLVM Translation failed for operation: omp.task}}
50-
omp.task {
51-
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancellation_point operation}}
52-
// expected-error@below {{LLVM Translation failed for operation: omp.cancellation_point}}
53-
omp.cancellation_point cancellation_construct_type(taskgroup)
54-
omp.terminator
55-
}
56-
omp.terminator
57-
}
58-
llvm.return
59-
}
60-
61-
// -----
62-
6329
llvm.func @do_simd(%lb : i32, %ub : i32, %step : i32) {
6430
omp.wsloop {
6531
// expected-warning@below {{simd information on composite construct discarded}}

0 commit comments

Comments
 (0)