Skip to content

Commit 0afaa2c

Browse files
committed
[mlir][OpenMP] cancel(lation point) taskgroup LLVMIR
A cancel or cancellation point for taskgroup is always nested inside of a task inside of the taskgroup. For the task which is cancelled, it is that task which needs to be cleaned up: not the owning taskgroup. Therefore the cancellation branch handler is done in the conversion of the task not in conversion of taskgroup. I added a firstprivate clause to the test for cancel taskgroup to demonstrate that the block being branched to is the same block where mandatory cleanup code is added. Cancellation point follows exactly the same code path.
1 parent 8338a3c commit 0afaa2c

File tree

5 files changed

+144
-86
lines changed

5 files changed

+144
-86
lines changed

flang/docs/OpenMPSupport.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
5151
| depend clause | P | depend clause with array sections are not supported |
5252
| declare reduction construct | N | |
5353
| atomic construct extensions | Y | |
54-
| cancel construct | N | |
55-
| cancellation point construct | N | |
54+
| cancel construct | Y | |
55+
| cancellation point construct | Y | |
5656
| parallel do simd construct | P | linear clause is not supported |
5757
| target teams construct | P | device and reduction clauses are not supported |
5858
| teams distribute construct | P | reduction and dist_schedule clauses not supported |

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 64 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
158158
if (op.getBare())
159159
result = todo("ompx_bare");
160160
};
161-
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162-
omp::ClauseCancellationConstructType cancelledDirective =
163-
op.getCancelDirective();
164-
if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
165-
result = todo("cancel directive construct type not yet supported");
166-
};
167161
auto checkDepend = [&todo](auto op, LogicalResult &result) {
168162
if (!op.getDependVars().empty() || op.getDependKinds())
169163
result = todo("depend");
@@ -240,10 +234,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
240234

241235
LogicalResult result = success();
242236
llvm::TypeSwitch<Operation &>(op)
243-
.Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
244-
.Case([&](omp::CancellationPointOp op) {
245-
checkCancelDirective(op, result);
246-
})
247237
.Case([&](omp::DistributeOp op) {
248238
checkAllocate(op, result);
249239
checkDistSchedule(op, result);
@@ -1889,6 +1879,55 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
18891879
}
18901880
}
18911881

1882+
/// Shared implementation of a callback which adds a termiator for the new block
1883+
/// created for the branch taken when an openmp construct is cancelled. The
1884+
/// terminator is saved in \p cancelTerminators. This callback is invoked only
1885+
/// if there is cancellation inside of the taskgroup body.
1886+
/// The terminator will need to be fixed to branch to the correct block to
1887+
/// cleanup the construct.
1888+
static void
1889+
pushCancelFinalizationCB(SmallVectorImpl<llvm::BranchInst *> &cancelTerminators,
1890+
llvm::IRBuilderBase &llvmBuilder,
1891+
llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
1892+
llvm::omp::Directive cancelDirective) {
1893+
auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
1894+
llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
1895+
1896+
// ip is currently in the block branched to if cancellation occured.
1897+
// We need to create a branch to terminate that block.
1898+
llvmBuilder.restoreIP(ip);
1899+
1900+
// We must still clean up the construct after cancelling it, so we need to
1901+
// branch to the block that finalizes the taskgroup.
1902+
// That block has not been created yet so use this block as a dummy for now
1903+
// and fix this after creating the operation.
1904+
cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
1905+
return llvm::Error::success();
1906+
};
1907+
// We have to add the cleanup to the OpenMPIRBuilder before the body gets
1908+
// created in case the body contains omp.cancel (which will then expect to be
1909+
// able to find this cleanup callback).
1910+
ompBuilder.pushFinalizationCB(
1911+
{finiCB, cancelDirective, constructIsCancellable(op)});
1912+
}
1913+
1914+
/// If we cancelled the construct, we should branch to the finalization block of
1915+
/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
1916+
/// is immediately before the continuation block. Now this finalization has
1917+
/// been created we can fix the branch.
1918+
static void
1919+
popCancelFinalizationCB(const ArrayRef<llvm::BranchInst *> cancelTerminators,
1920+
llvm::OpenMPIRBuilder &ompBuilder,
1921+
const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
1922+
ompBuilder.popFinalizationCB();
1923+
llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
1924+
for (llvm::BranchInst *cancelBranch : cancelTerminators) {
1925+
assert(cancelBranch->getNumSuccessors() == 1 &&
1926+
"cancel branch should have one target");
1927+
cancelBranch->setSuccessor(0, constructFini);
1928+
}
1929+
}
1930+
18921931
namespace {
18931932
/// TaskContextStructManager takes care of creating and freeing a structure
18941933
/// containing information needed by the task body to execute.
@@ -2202,6 +2241,14 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
22022241
return llvm::Error::success();
22032242
};
22042243

2244+
llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2245+
SmallVector<llvm::BranchInst *> cancelTerminators;
2246+
// The directive to match here is OMPD_taskgroup because it is the taskgroup
2247+
// which is canceled. This is handled here because it is the task's cleanup
2248+
// block which should be branched to.
2249+
pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
2250+
llvm::omp::Directive::OMPD_taskgroup);
2251+
22052252
SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
22062253
buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
22072254
moduleTranslation, dds);
@@ -2219,6 +2266,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
22192266
if (failed(handleError(afterIP, *taskOp)))
22202267
return failure();
22212268

2269+
// Set the correct branch target for task cancellation
2270+
popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2271+
22222272
builder.restoreIP(*afterIP);
22232273
return success();
22242274
}
@@ -2349,28 +2399,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23492399
: llvm::omp::WorksharingLoopType::ForStaticLoop;
23502400

23512401
SmallVector<llvm::BranchInst *> cancelTerminators;
2352-
// This callback is invoked only if there is cancellation inside of the wsloop
2353-
// body.
2354-
auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2355-
llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder;
2356-
llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2357-
2358-
// ip is currently in the block branched to if cancellation occured.
2359-
// We need to create a branch to terminate that block.
2360-
llvmBuilder.restoreIP(ip);
2361-
2362-
// We must still clean up the wsloop after cancelling it, so we need to
2363-
// branch to the block that finalizes the wsloop.
2364-
// That block has not been created yet so use this block as a dummy for now
2365-
// and fix this after creating the wsloop.
2366-
cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2367-
return llvm::Error::success();
2368-
};
2369-
// We have to add the cleanup to the OpenMPIRBuilder before the body gets
2370-
// created in case the body contains omp.cancel (which will then expect to be
2371-
// able to find this cleanup callback).
2372-
ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for,
2373-
constructIsCancellable(wsloopOp)});
2402+
pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
2403+
llvm::omp::Directive::OMPD_for);
23742404

23752405
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
23762406
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
@@ -2393,18 +2423,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23932423
if (failed(handleError(wsloopIP, opInst)))
23942424
return failure();
23952425

2396-
ompBuilder->popFinalizationCB();
2397-
if (!cancelTerminators.empty()) {
2398-
// If we cancelled the loop, we should branch to the finalization block of
2399-
// the wsloop (which is always immediately before the loop continuation
2400-
// block). Now the finalization has been created, we can fix the branch.
2401-
llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor();
2402-
for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2403-
assert(cancelBranch->getNumSuccessors() == 1 &&
2404-
"cancel branch should have one target");
2405-
cancelBranch->setSuccessor(0, wsloopFini);
2406-
}
2407-
}
2426+
// Set the correct branch target for task cancellation
2427+
popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
24082428

24092429
// Process the reductions if required.
24102430
if (failed(createReductionsAndCleanup(
@@ -3063,9 +3083,6 @@ convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
30633083
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
30643084
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
30653085

3066-
if (failed(checkImplementationStatus(*op.getOperation())))
3067-
return failure();
3068-
30693086
llvm::Value *ifCond = nullptr;
30703087
if (Value ifVar = op.getIfExpr())
30713088
ifCond = moduleTranslation.lookupValue(ifVar);
@@ -3091,9 +3108,6 @@ convertOmpCancellationPoint(omp::CancellationPointOp op,
30913108
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
30923109
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
30933110

3094-
if (failed(checkImplementationStatus(*op.getOperation())))
3095-
return failure();
3096-
30973111
llvm::omp::Directive cancelledDirective =
30983112
convertCancellationConstructType(op.getCancelDirective());
30993113

mlir/test/Target/LLVMIR/openmp-cancel.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,51 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
243243
// CHECK: ret void
244244
// CHECK: .cncl: ; preds = %[[VAL_44]]
245245
// CHECK: br label %[[VAL_38]]
246+
247+
omp.private {type = firstprivate} @i32_priv : i32 copy {
248+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
249+
%0 = llvm.load %arg0 : !llvm.ptr -> i32
250+
llvm.store %0, %arg1 : i32, !llvm.ptr
251+
omp.yield(%arg1 : !llvm.ptr)
252+
}
253+
254+
llvm.func @do_something(!llvm.ptr)
255+
256+
llvm.func @cancel_taskgroup(%arg0: !llvm.ptr) {
257+
omp.taskgroup {
258+
// Using firstprivate clause so we have some end of task cleanup to branch to
259+
// after the cancellation.
260+
omp.task private(@i32_priv %arg0 -> %arg1 : !llvm.ptr) {
261+
omp.cancel cancellation_construct_type(taskgroup)
262+
llvm.call @do_something(%arg1) : (!llvm.ptr) -> ()
263+
omp.terminator
264+
}
265+
omp.terminator
266+
}
267+
llvm.return
268+
}
269+
// CHECK-LABEL: define internal void @cancel_taskgroup..omp_par(
270+
// CHECK: task.alloca:
271+
// CHECK: %[[VAL_21:.*]] = load ptr, ptr %[[VAL_22:.*]], align 8
272+
// CHECK: %[[VAL_23:.*]] = getelementptr { ptr }, ptr %[[VAL_21]], i32 0, i32 0
273+
// CHECK: %[[VAL_24:.*]] = load ptr, ptr %[[VAL_23]], align 8, !align !1
274+
// CHECK: br label %[[VAL_25:.*]]
275+
// CHECK: task.body: ; preds = %[[VAL_26:.*]]
276+
// CHECK: %[[VAL_27:.*]] = getelementptr { i32 }, ptr %[[VAL_24]], i32 0, i32 0
277+
// CHECK: br label %[[VAL_28:.*]]
278+
// CHECK: omp.task.region: ; preds = %[[VAL_25]]
279+
// CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
280+
// CHECK: %[[VAL_30:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_29]], i32 4)
281+
// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0
282+
// CHECK: br i1 %[[VAL_31]], label %omp.task.region.split, label %omp.task.region.cncl
283+
// CHECK: omp.task.region.cncl:
284+
// CHECK: br label %omp.region.cont2
285+
// CHECK: omp.region.cont2:
286+
// Both cancellation and normal paths reach the end-of-task cleanup:
287+
// CHECK: tail call void @free(ptr %[[VAL_24]])
288+
// CHECK: br label %task.exit.exitStub
289+
// CHECK: omp.task.region.split:
290+
// CHECK: call void @do_something(ptr %[[VAL_27]])
291+
// CHECK: br label %omp.region.cont2
292+
// CHECK: task.exit.exitStub:
293+
// CHECK: ret void

mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,33 @@ llvm.func @cancellation_point_wsloop(%lb : i32, %ub : i32, %step : i32) {
186186
// CHECK: ret void
187187
// CHECK: omp.loop_nest.region.cncl: ; preds = %[[VAL_100]]
188188
// CHECK: br label %[[VAL_96]]
189+
190+
191+
llvm.func @cancellation_point_taskgroup() {
192+
omp.taskgroup {
193+
omp.task {
194+
omp.cancellation_point cancellation_construct_type(taskgroup)
195+
omp.terminator
196+
}
197+
omp.terminator
198+
}
199+
llvm.return
200+
}
201+
// CHECK-LABEL: define internal void @cancellation_point_taskgroup..omp_par(
202+
// CHECK: task.alloca:
203+
// CHECK: br label %[[VAL_50:.*]]
204+
// CHECK: task.body: ; preds = %[[VAL_51:.*]]
205+
// CHECK: br label %[[VAL_52:.*]]
206+
// CHECK: omp.task.region: ; preds = %[[VAL_50]]
207+
// CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
208+
// CHECK: %[[VAL_54:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[VAL_53]], i32 4)
209+
// CHECK: %[[VAL_55:.*]] = icmp eq i32 %[[VAL_54]], 0
210+
// CHECK: br i1 %[[VAL_55]], label %omp.task.region.split, label %omp.task.region.cncl
211+
// CHECK: omp.task.region.cncl:
212+
// CHECK: br label %omp.region.cont1
213+
// CHECK: omp.region.cont1:
214+
// CHECK: br label %task.exit.exitStub
215+
// CHECK: omp.task.region.split:
216+
// CHECK: br label %omp.region.cont1
217+
// CHECK: task.exit.exitStub:
218+
// CHECK: ret void

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,40 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {
2626

2727
// -----
2828

29-
llvm.func @cancel_taskgroup() {
30-
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
31-
omp.taskgroup {
32-
// expected-error@below {{LLVM Translation failed for operation: omp.task}}
33-
omp.task {
34-
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancel operation}}
35-
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
36-
omp.cancel cancellation_construct_type(taskgroup)
37-
omp.terminator
38-
}
39-
omp.terminator
40-
}
41-
llvm.return
42-
}
43-
44-
// -----
45-
46-
llvm.func @cancellation_point_taskgroup() {
47-
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
48-
omp.taskgroup {
49-
// expected-error@below {{LLVM Translation failed for operation: omp.task}}
50-
omp.task {
51-
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancellation_point operation}}
52-
// expected-error@below {{LLVM Translation failed for operation: omp.cancellation_point}}
53-
omp.cancellation_point cancellation_construct_type(taskgroup)
54-
omp.terminator
55-
}
56-
omp.terminator
57-
}
58-
llvm.return
59-
}
60-
61-
// -----
62-
6329
llvm.func @do_simd(%lb : i32, %ub : i32, %step : i32) {
6430
omp.wsloop {
6531
// expected-warning@below {{simd information on composite construct discarded}}

0 commit comments

Comments
 (0)