Skip to content

Commit 5e1a9a0

Browse files
committed
[mlir][OpenMP] cancel(lation point) taskgroup LLVMIR
A cancel or cancellation point for taskgroup is always nested inside of a task inside of the taskgroup. For the task which is cancelled, it is that task which needs to be cleaned up: not the owning taskgroup. Therefore the cancellation branch handler is done in the conversion of the task not in conversion of taskgroup. I added a firstprivate clause to the test for cancel taskgroup to demonstrate that the block being branched to is the same block where mandatory cleanup code is added. Cancellation point follows exactly the same code path.
1 parent f401175 commit 5e1a9a0

File tree

5 files changed

+144
-86
lines changed

5 files changed

+144
-86
lines changed

flang/docs/OpenMPSupport.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
5151
| depend clause | P | depend clause with array sections are not supported |
5252
| declare reduction construct | N | |
5353
| atomic construct extensions | Y | |
54-
| cancel construct | N | |
55-
| cancellation point construct | N | |
54+
| cancel construct | Y | |
55+
| cancellation point construct | Y | |
5656
| parallel do simd construct | P | linear clause is not supported |
5757
| target teams construct | P | device and reduction clauses are not supported |
5858
| teams distribute construct | P | reduction and dist_schedule clauses not supported |

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 64 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
158158
if (op.getBare())
159159
result = todo("ompx_bare");
160160
};
161-
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162-
omp::ClauseCancellationConstructType cancelledDirective =
163-
op.getCancelDirective();
164-
if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
165-
result = todo("cancel directive construct type not yet supported");
166-
};
167161
auto checkDepend = [&todo](auto op, LogicalResult &result) {
168162
if (!op.getDependVars().empty() || op.getDependKinds())
169163
result = todo("depend");
@@ -254,10 +248,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
254248

255249
LogicalResult result = success();
256250
llvm::TypeSwitch<Operation &>(op)
257-
.Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
258-
.Case([&](omp::CancellationPointOp op) {
259-
checkCancelDirective(op, result);
260-
})
261251
.Case([&](omp::DistributeOp op) {
262252
checkAllocate(op, result);
263253
checkDistSchedule(op, result);
@@ -1902,6 +1892,55 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
19021892
}
19031893
}
19041894

1895+
/// Shared implementation of a callback which adds a termiator for the new block
1896+
/// created for the branch taken when an openmp construct is cancelled. The
1897+
/// terminator is saved in \p cancelTerminators. This callback is invoked only
1898+
/// if there is cancellation inside of the taskgroup body.
1899+
/// The terminator will need to be fixed to branch to the correct block to
1900+
/// cleanup the construct.
1901+
static void
1902+
pushCancelFinalizationCB(SmallVectorImpl<llvm::BranchInst *> &cancelTerminators,
1903+
llvm::IRBuilderBase &llvmBuilder,
1904+
llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
1905+
llvm::omp::Directive cancelDirective) {
1906+
auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
1907+
llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
1908+
1909+
// ip is currently in the block branched to if cancellation occured.
1910+
// We need to create a branch to terminate that block.
1911+
llvmBuilder.restoreIP(ip);
1912+
1913+
// We must still clean up the construct after cancelling it, so we need to
1914+
// branch to the block that finalizes the taskgroup.
1915+
// That block has not been created yet so use this block as a dummy for now
1916+
// and fix this after creating the operation.
1917+
cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
1918+
return llvm::Error::success();
1919+
};
1920+
// We have to add the cleanup to the OpenMPIRBuilder before the body gets
1921+
// created in case the body contains omp.cancel (which will then expect to be
1922+
// able to find this cleanup callback).
1923+
ompBuilder.pushFinalizationCB(
1924+
{finiCB, cancelDirective, constructIsCancellable(op)});
1925+
}
1926+
1927+
/// If we cancelled the construct, we should branch to the finalization block of
1928+
/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
1929+
/// is immediately before the continuation block. Now this finalization has
1930+
/// been created we can fix the branch.
1931+
static void
1932+
popCancelFinalizationCB(const ArrayRef<llvm::BranchInst *> cancelTerminators,
1933+
llvm::OpenMPIRBuilder &ompBuilder,
1934+
const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
1935+
ompBuilder.popFinalizationCB();
1936+
llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
1937+
for (llvm::BranchInst *cancelBranch : cancelTerminators) {
1938+
assert(cancelBranch->getNumSuccessors() == 1 &&
1939+
"cancel branch should have one target");
1940+
cancelBranch->setSuccessor(0, constructFini);
1941+
}
1942+
}
1943+
19051944
namespace {
19061945
/// TaskContextStructManager takes care of creating and freeing a structure
19071946
/// containing information needed by the task body to execute.
@@ -2215,6 +2254,14 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
22152254
return llvm::Error::success();
22162255
};
22172256

2257+
llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
2258+
SmallVector<llvm::BranchInst *> cancelTerminators;
2259+
// The directive to match here is OMPD_taskgroup because it is the taskgroup
2260+
// which is canceled. This is handled here because it is the task's cleanup
2261+
// block which should be branched to.
2262+
pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
2263+
llvm::omp::Directive::OMPD_taskgroup);
2264+
22182265
SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
22192266
buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
22202267
moduleTranslation, dds);
@@ -2232,6 +2279,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
22322279
if (failed(handleError(afterIP, *taskOp)))
22332280
return failure();
22342281

2282+
// Set the correct branch target for task cancellation
2283+
popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
2284+
22352285
builder.restoreIP(*afterIP);
22362286
return success();
22372287
}
@@ -2362,28 +2412,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
23622412
: llvm::omp::WorksharingLoopType::ForStaticLoop;
23632413

23642414
SmallVector<llvm::BranchInst *> cancelTerminators;
2365-
// This callback is invoked only if there is cancellation inside of the wsloop
2366-
// body.
2367-
auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2368-
llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder;
2369-
llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
2370-
2371-
// ip is currently in the block branched to if cancellation occured.
2372-
// We need to create a branch to terminate that block.
2373-
llvmBuilder.restoreIP(ip);
2374-
2375-
// We must still clean up the wsloop after cancelling it, so we need to
2376-
// branch to the block that finalizes the wsloop.
2377-
// That block has not been created yet so use this block as a dummy for now
2378-
// and fix this after creating the wsloop.
2379-
cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
2380-
return llvm::Error::success();
2381-
};
2382-
// We have to add the cleanup to the OpenMPIRBuilder before the body gets
2383-
// created in case the body contains omp.cancel (which will then expect to be
2384-
// able to find this cleanup callback).
2385-
ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for,
2386-
constructIsCancellable(wsloopOp)});
2415+
pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
2416+
llvm::omp::Directive::OMPD_for);
23872417

23882418
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
23892419
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
@@ -2406,18 +2436,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
24062436
if (failed(handleError(wsloopIP, opInst)))
24072437
return failure();
24082438

2409-
ompBuilder->popFinalizationCB();
2410-
if (!cancelTerminators.empty()) {
2411-
// If we cancelled the loop, we should branch to the finalization block of
2412-
// the wsloop (which is always immediately before the loop continuation
2413-
// block). Now the finalization has been created, we can fix the branch.
2414-
llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor();
2415-
for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2416-
assert(cancelBranch->getNumSuccessors() == 1 &&
2417-
"cancel branch should have one target");
2418-
cancelBranch->setSuccessor(0, wsloopFini);
2419-
}
2420-
}
2439+
// Set the correct branch target for task cancellation
2440+
popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
24212441

24222442
// Process the reductions if required.
24232443
if (failed(createReductionsAndCleanup(
@@ -3075,9 +3095,6 @@ convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
30753095
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
30763096
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
30773097

3078-
if (failed(checkImplementationStatus(*op.getOperation())))
3079-
return failure();
3080-
30813098
llvm::Value *ifCond = nullptr;
30823099
if (Value ifVar = op.getIfExpr())
30833100
ifCond = moduleTranslation.lookupValue(ifVar);
@@ -3103,9 +3120,6 @@ convertOmpCancellationPoint(omp::CancellationPointOp op,
31033120
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
31043121
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
31053122

3106-
if (failed(checkImplementationStatus(*op.getOperation())))
3107-
return failure();
3108-
31093123
llvm::omp::Directive cancelledDirective =
31103124
convertCancellationConstructType(op.getCancelDirective());
31113125

mlir/test/Target/LLVMIR/openmp-cancel.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,3 +243,51 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
243243
// CHECK: ret void
244244
// CHECK: .cncl: ; preds = %[[VAL_44]]
245245
// CHECK: br label %[[VAL_38]]
246+
247+
omp.private {type = firstprivate} @i32_priv : i32 copy {
248+
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
249+
%0 = llvm.load %arg0 : !llvm.ptr -> i32
250+
llvm.store %0, %arg1 : i32, !llvm.ptr
251+
omp.yield(%arg1 : !llvm.ptr)
252+
}
253+
254+
llvm.func @do_something(!llvm.ptr)
255+
256+
llvm.func @cancel_taskgroup(%arg0: !llvm.ptr) {
257+
omp.taskgroup {
258+
// Using firstprivate clause so we have some end of task cleanup to branch to
259+
// after the cancellation.
260+
omp.task private(@i32_priv %arg0 -> %arg1 : !llvm.ptr) {
261+
omp.cancel cancellation_construct_type(taskgroup)
262+
llvm.call @do_something(%arg1) : (!llvm.ptr) -> ()
263+
omp.terminator
264+
}
265+
omp.terminator
266+
}
267+
llvm.return
268+
}
269+
// CHECK-LABEL: define internal void @cancel_taskgroup..omp_par(
270+
// CHECK: task.alloca:
271+
// CHECK: %[[VAL_21:.*]] = load ptr, ptr %[[VAL_22:.*]], align 8
272+
// CHECK: %[[VAL_23:.*]] = getelementptr { ptr }, ptr %[[VAL_21]], i32 0, i32 0
273+
// CHECK: %[[VAL_24:.*]] = load ptr, ptr %[[VAL_23]], align 8, !align !1
274+
// CHECK: br label %[[VAL_25:.*]]
275+
// CHECK: task.body: ; preds = %[[VAL_26:.*]]
276+
// CHECK: %[[VAL_27:.*]] = getelementptr { i32 }, ptr %[[VAL_24]], i32 0, i32 0
277+
// CHECK: br label %[[VAL_28:.*]]
278+
// CHECK: omp.task.region: ; preds = %[[VAL_25]]
279+
// CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
280+
// CHECK: %[[VAL_30:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_29]], i32 4)
281+
// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0
282+
// CHECK: br i1 %[[VAL_31]], label %omp.task.region.split, label %omp.task.region.cncl
283+
// CHECK: omp.task.region.cncl:
284+
// CHECK: br label %omp.region.cont2
285+
// CHECK: omp.region.cont2:
286+
// Both cancellation and normal paths reach the end-of-task cleanup:
287+
// CHECK: tail call void @free(ptr %[[VAL_24]])
288+
// CHECK: br label %task.exit.exitStub
289+
// CHECK: omp.task.region.split:
290+
// CHECK: call void @do_something(ptr %[[VAL_27]])
291+
// CHECK: br label %omp.region.cont2
292+
// CHECK: task.exit.exitStub:
293+
// CHECK: ret void

mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,33 @@ llvm.func @cancellation_point_wsloop(%lb : i32, %ub : i32, %step : i32) {
186186
// CHECK: ret void
187187
// CHECK: omp.loop_nest.region.cncl: ; preds = %[[VAL_100]]
188188
// CHECK: br label %[[VAL_96]]
189+
190+
191+
llvm.func @cancellation_point_taskgroup() {
192+
omp.taskgroup {
193+
omp.task {
194+
omp.cancellation_point cancellation_construct_type(taskgroup)
195+
omp.terminator
196+
}
197+
omp.terminator
198+
}
199+
llvm.return
200+
}
201+
// CHECK-LABEL: define internal void @cancellation_point_taskgroup..omp_par(
202+
// CHECK: task.alloca:
203+
// CHECK: br label %[[VAL_50:.*]]
204+
// CHECK: task.body: ; preds = %[[VAL_51:.*]]
205+
// CHECK: br label %[[VAL_52:.*]]
206+
// CHECK: omp.task.region: ; preds = %[[VAL_50]]
207+
// CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
208+
// CHECK: %[[VAL_54:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[VAL_53]], i32 4)
209+
// CHECK: %[[VAL_55:.*]] = icmp eq i32 %[[VAL_54]], 0
210+
// CHECK: br i1 %[[VAL_55]], label %omp.task.region.split, label %omp.task.region.cncl
211+
// CHECK: omp.task.region.cncl:
212+
// CHECK: br label %omp.region.cont1
213+
// CHECK: omp.region.cont1:
214+
// CHECK: br label %task.exit.exitStub
215+
// CHECK: omp.task.region.split:
216+
// CHECK: br label %omp.region.cont1
217+
// CHECK: task.exit.exitStub:
218+
// CHECK: ret void

mlir/test/Target/LLVMIR/openmp-todo.mlir

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -26,40 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {
2626

2727
// -----
2828

29-
llvm.func @cancel_taskgroup() {
30-
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
31-
omp.taskgroup {
32-
// expected-error@below {{LLVM Translation failed for operation: omp.task}}
33-
omp.task {
34-
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancel operation}}
35-
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
36-
omp.cancel cancellation_construct_type(taskgroup)
37-
omp.terminator
38-
}
39-
omp.terminator
40-
}
41-
llvm.return
42-
}
43-
44-
// -----
45-
46-
llvm.func @cancellation_point_taskgroup() {
47-
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
48-
omp.taskgroup {
49-
// expected-error@below {{LLVM Translation failed for operation: omp.task}}
50-
omp.task {
51-
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancellation_point operation}}
52-
// expected-error@below {{LLVM Translation failed for operation: omp.cancellation_point}}
53-
omp.cancellation_point cancellation_construct_type(taskgroup)
54-
omp.terminator
55-
}
56-
omp.terminator
57-
}
58-
llvm.return
59-
}
60-
61-
// -----
62-
6329
llvm.func @do_simd(%lb : i32, %ub : i32, %step : i32) {
6430
omp.wsloop {
6531
// expected-warning@below {{simd information on composite construct discarded}}

0 commit comments

Comments
 (0)