@@ -158,12 +158,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
158
158
if (op.getBare ())
159
159
result = todo (" ompx_bare" );
160
160
};
161
- auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162
- omp::ClauseCancellationConstructType cancelledDirective =
163
- op.getCancelDirective ();
164
- if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
165
- result = todo (" cancel directive construct type not yet supported" );
166
- };
167
161
auto checkDepend = [&todo](auto op, LogicalResult &result) {
168
162
if (!op.getDependVars ().empty () || op.getDependKinds ())
169
163
result = todo (" depend" );
@@ -254,10 +248,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
254
248
255
249
LogicalResult result = success ();
256
250
llvm::TypeSwitch<Operation &>(op)
257
- .Case ([&](omp::CancelOp op) { checkCancelDirective (op, result); })
258
- .Case ([&](omp::CancellationPointOp op) {
259
- checkCancelDirective (op, result);
260
- })
261
251
.Case ([&](omp::DistributeOp op) {
262
252
checkAllocate (op, result);
263
253
checkDistSchedule (op, result);
@@ -1902,6 +1892,55 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
1902
1892
}
1903
1893
}
1904
1894
1895
+ // / Shared implementation of a callback which adds a termiator for the new block
1896
+ // / created for the branch taken when an openmp construct is cancelled. The
1897
+ // / terminator is saved in \p cancelTerminators. This callback is invoked only
1898
+ // / if there is cancellation inside of the taskgroup body.
1899
+ // / The terminator will need to be fixed to branch to the correct block to
1900
+ // / cleanup the construct.
1901
+ static void
1902
+ pushCancelFinalizationCB (SmallVectorImpl<llvm::BranchInst *> &cancelTerminators,
1903
+ llvm::IRBuilderBase &llvmBuilder,
1904
+ llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
1905
+ llvm::omp::Directive cancelDirective) {
1906
+ auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
1907
+ llvm::IRBuilderBase::InsertPointGuard guard (llvmBuilder);
1908
+
1909
+ // ip is currently in the block branched to if cancellation occured.
1910
+ // We need to create a branch to terminate that block.
1911
+ llvmBuilder.restoreIP (ip);
1912
+
1913
+ // We must still clean up the construct after cancelling it, so we need to
1914
+ // branch to the block that finalizes the taskgroup.
1915
+ // That block has not been created yet so use this block as a dummy for now
1916
+ // and fix this after creating the operation.
1917
+ cancelTerminators.push_back (llvmBuilder.CreateBr (ip.getBlock ()));
1918
+ return llvm::Error::success ();
1919
+ };
1920
+ // We have to add the cleanup to the OpenMPIRBuilder before the body gets
1921
+ // created in case the body contains omp.cancel (which will then expect to be
1922
+ // able to find this cleanup callback).
1923
+ ompBuilder.pushFinalizationCB (
1924
+ {finiCB, cancelDirective, constructIsCancellable (op)});
1925
+ }
1926
+
1927
+ // / If we cancelled the construct, we should branch to the finalization block of
1928
+ // / that construct. OMPIRBuilder structures the CFG such that the cleanup block
1929
+ // / is immediately before the continuation block. Now this finalization has
1930
+ // / been created we can fix the branch.
1931
+ static void
1932
+ popCancelFinalizationCB (const ArrayRef<llvm::BranchInst *> cancelTerminators,
1933
+ llvm::OpenMPIRBuilder &ompBuilder,
1934
+ const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
1935
+ ompBuilder.popFinalizationCB ();
1936
+ llvm::BasicBlock *constructFini = afterIP.getBlock ()->getSinglePredecessor ();
1937
+ for (llvm::BranchInst *cancelBranch : cancelTerminators) {
1938
+ assert (cancelBranch->getNumSuccessors () == 1 &&
1939
+ " cancel branch should have one target" );
1940
+ cancelBranch->setSuccessor (0 , constructFini);
1941
+ }
1942
+ }
1943
+
1905
1944
namespace {
1906
1945
// / TaskContextStructManager takes care of creating and freeing a structure
1907
1946
// / containing information needed by the task body to execute.
@@ -2215,6 +2254,14 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2215
2254
return llvm::Error::success ();
2216
2255
};
2217
2256
2257
+ llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder ();
2258
+ SmallVector<llvm::BranchInst *> cancelTerminators;
2259
+ // The directive to match here is OMPD_taskgroup because it is the taskgroup
2260
+ // which is canceled. This is handled here because it is the task's cleanup
2261
+ // block which should be branched to.
2262
+ pushCancelFinalizationCB (cancelTerminators, builder, ompBuilder, taskOp,
2263
+ llvm::omp::Directive::OMPD_taskgroup);
2264
+
2218
2265
SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
2219
2266
buildDependData (taskOp.getDependKinds (), taskOp.getDependVars (),
2220
2267
moduleTranslation, dds);
@@ -2232,6 +2279,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2232
2279
if (failed (handleError (afterIP, *taskOp)))
2233
2280
return failure ();
2234
2281
2282
+ // Set the correct branch target for task cancellation
2283
+ popCancelFinalizationCB (cancelTerminators, ompBuilder, afterIP.get ());
2284
+
2235
2285
builder.restoreIP (*afterIP);
2236
2286
return success ();
2237
2287
}
@@ -2362,28 +2412,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2362
2412
: llvm::omp::WorksharingLoopType::ForStaticLoop;
2363
2413
2364
2414
SmallVector<llvm::BranchInst *> cancelTerminators;
2365
- // This callback is invoked only if there is cancellation inside of the wsloop
2366
- // body.
2367
- auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2368
- llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder ;
2369
- llvm::IRBuilderBase::InsertPointGuard guard (llvmBuilder);
2370
-
2371
- // ip is currently in the block branched to if cancellation occured.
2372
- // We need to create a branch to terminate that block.
2373
- llvmBuilder.restoreIP (ip);
2374
-
2375
- // We must still clean up the wsloop after cancelling it, so we need to
2376
- // branch to the block that finalizes the wsloop.
2377
- // That block has not been created yet so use this block as a dummy for now
2378
- // and fix this after creating the wsloop.
2379
- cancelTerminators.push_back (llvmBuilder.CreateBr (ip.getBlock ()));
2380
- return llvm::Error::success ();
2381
- };
2382
- // We have to add the cleanup to the OpenMPIRBuilder before the body gets
2383
- // created in case the body contains omp.cancel (which will then expect to be
2384
- // able to find this cleanup callback).
2385
- ompBuilder->pushFinalizationCB ({finiCB, llvm::omp::Directive::OMPD_for,
2386
- constructIsCancellable (wsloopOp)});
2415
+ pushCancelFinalizationCB (cancelTerminators, builder, *ompBuilder, wsloopOp,
2416
+ llvm::omp::Directive::OMPD_for);
2387
2417
2388
2418
llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
2389
2419
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions (
@@ -2406,18 +2436,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2406
2436
if (failed (handleError (wsloopIP, opInst)))
2407
2437
return failure ();
2408
2438
2409
- ompBuilder->popFinalizationCB ();
2410
- if (!cancelTerminators.empty ()) {
2411
- // If we cancelled the loop, we should branch to the finalization block of
2412
- // the wsloop (which is always immediately before the loop continuation
2413
- // block). Now the finalization has been created, we can fix the branch.
2414
- llvm::BasicBlock *wsloopFini = wsloopIP->getBlock ()->getSinglePredecessor ();
2415
- for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2416
- assert (cancelBranch->getNumSuccessors () == 1 &&
2417
- " cancel branch should have one target" );
2418
- cancelBranch->setSuccessor (0 , wsloopFini);
2419
- }
2420
- }
2439
+ // Set the correct branch target for task cancellation
2440
+ popCancelFinalizationCB (cancelTerminators, *ompBuilder, wsloopIP.get ());
2421
2441
2422
2442
// Process the reductions if required.
2423
2443
if (failed (createReductionsAndCleanup (
@@ -3075,9 +3095,6 @@ convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
3075
3095
llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3076
3096
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3077
3097
3078
- if (failed (checkImplementationStatus (*op.getOperation ())))
3079
- return failure ();
3080
-
3081
3098
llvm::Value *ifCond = nullptr ;
3082
3099
if (Value ifVar = op.getIfExpr ())
3083
3100
ifCond = moduleTranslation.lookupValue (ifVar);
@@ -3103,9 +3120,6 @@ convertOmpCancellationPoint(omp::CancellationPointOp op,
3103
3120
llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3104
3121
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3105
3122
3106
- if (failed (checkImplementationStatus (*op.getOperation ())))
3107
- return failure ();
3108
-
3109
3123
llvm::omp::Directive cancelledDirective =
3110
3124
convertCancellationConstructType (op.getCancelDirective ());
3111
3125
0 commit comments