@@ -158,12 +158,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
158
158
if (op.getBare ())
159
159
result = todo (" ompx_bare" );
160
160
};
161
- auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162
- omp::ClauseCancellationConstructType cancelledDirective =
163
- op.getCancelDirective ();
164
- if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
165
- result = todo (" cancel directive construct type not yet supported" );
166
- };
167
161
auto checkDepend = [&todo](auto op, LogicalResult &result) {
168
162
if (!op.getDependVars ().empty () || op.getDependKinds ())
169
163
result = todo (" depend" );
@@ -240,10 +234,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
240
234
241
235
LogicalResult result = success ();
242
236
llvm::TypeSwitch<Operation &>(op)
243
- .Case ([&](omp::CancelOp op) { checkCancelDirective (op, result); })
244
- .Case ([&](omp::CancellationPointOp op) {
245
- checkCancelDirective (op, result);
246
- })
247
237
.Case ([&](omp::DistributeOp op) {
248
238
checkAllocate (op, result);
249
239
checkDistSchedule (op, result);
@@ -1889,6 +1879,55 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
1889
1879
}
1890
1880
}
1891
1881
1882
+ // / Shared implementation of a callback which adds a termiator for the new block
1883
+ // / created for the branch taken when an openmp construct is cancelled. The
1884
+ // / terminator is saved in \p cancelTerminators. This callback is invoked only
1885
+ // / if there is cancellation inside of the taskgroup body.
1886
+ // / The terminator will need to be fixed to branch to the correct block to
1887
+ // / cleanup the construct.
1888
+ static void
1889
+ pushCancelFinalizationCB (SmallVectorImpl<llvm::BranchInst *> &cancelTerminators,
1890
+ llvm::IRBuilderBase &llvmBuilder,
1891
+ llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
1892
+ llvm::omp::Directive cancelDirective) {
1893
+ auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
1894
+ llvm::IRBuilderBase::InsertPointGuard guard (llvmBuilder);
1895
+
1896
+ // ip is currently in the block branched to if cancellation occured.
1897
+ // We need to create a branch to terminate that block.
1898
+ llvmBuilder.restoreIP (ip);
1899
+
1900
+ // We must still clean up the construct after cancelling it, so we need to
1901
+ // branch to the block that finalizes the taskgroup.
1902
+ // That block has not been created yet so use this block as a dummy for now
1903
+ // and fix this after creating the operation.
1904
+ cancelTerminators.push_back (llvmBuilder.CreateBr (ip.getBlock ()));
1905
+ return llvm::Error::success ();
1906
+ };
1907
+ // We have to add the cleanup to the OpenMPIRBuilder before the body gets
1908
+ // created in case the body contains omp.cancel (which will then expect to be
1909
+ // able to find this cleanup callback).
1910
+ ompBuilder.pushFinalizationCB (
1911
+ {finiCB, cancelDirective, constructIsCancellable (op)});
1912
+ }
1913
+
1914
+ // / If we cancelled the construct, we should branch to the finalization block of
1915
+ // / that construct. OMPIRBuilder structures the CFG such that the cleanup block
1916
+ // / is immediately before the continuation block. Now this finalization has
1917
+ // / been created we can fix the branch.
1918
+ static void
1919
+ popCancelFinalizationCB (const ArrayRef<llvm::BranchInst *> cancelTerminators,
1920
+ llvm::OpenMPIRBuilder &ompBuilder,
1921
+ const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
1922
+ ompBuilder.popFinalizationCB ();
1923
+ llvm::BasicBlock *constructFini = afterIP.getBlock ()->getSinglePredecessor ();
1924
+ for (llvm::BranchInst *cancelBranch : cancelTerminators) {
1925
+ assert (cancelBranch->getNumSuccessors () == 1 &&
1926
+ " cancel branch should have one target" );
1927
+ cancelBranch->setSuccessor (0 , constructFini);
1928
+ }
1929
+ }
1930
+
1892
1931
namespace {
1893
1932
// / TaskContextStructManager takes care of creating and freeing a structure
1894
1933
// / containing information needed by the task body to execute.
@@ -2202,6 +2241,14 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2202
2241
return llvm::Error::success ();
2203
2242
};
2204
2243
2244
+ llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder ();
2245
+ SmallVector<llvm::BranchInst *> cancelTerminators;
2246
+ // The directive to match here is OMPD_taskgroup because it is the taskgroup
2247
+ // which is canceled. This is handled here because it is the task's cleanup
2248
+ // block which should be branched to.
2249
+ pushCancelFinalizationCB (cancelTerminators, builder, ompBuilder, taskOp,
2250
+ llvm::omp::Directive::OMPD_taskgroup);
2251
+
2205
2252
SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
2206
2253
buildDependData (taskOp.getDependKinds (), taskOp.getDependVars (),
2207
2254
moduleTranslation, dds);
@@ -2219,6 +2266,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2219
2266
if (failed (handleError (afterIP, *taskOp)))
2220
2267
return failure ();
2221
2268
2269
+ // Set the correct branch target for task cancellation
2270
+ popCancelFinalizationCB (cancelTerminators, ompBuilder, afterIP.get ());
2271
+
2222
2272
builder.restoreIP (*afterIP);
2223
2273
return success ();
2224
2274
}
@@ -2349,28 +2399,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2349
2399
: llvm::omp::WorksharingLoopType::ForStaticLoop;
2350
2400
2351
2401
SmallVector<llvm::BranchInst *> cancelTerminators;
2352
- // This callback is invoked only if there is cancellation inside of the wsloop
2353
- // body.
2354
- auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2355
- llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder ;
2356
- llvm::IRBuilderBase::InsertPointGuard guard (llvmBuilder);
2357
-
2358
- // ip is currently in the block branched to if cancellation occured.
2359
- // We need to create a branch to terminate that block.
2360
- llvmBuilder.restoreIP (ip);
2361
-
2362
- // We must still clean up the wsloop after cancelling it, so we need to
2363
- // branch to the block that finalizes the wsloop.
2364
- // That block has not been created yet so use this block as a dummy for now
2365
- // and fix this after creating the wsloop.
2366
- cancelTerminators.push_back (llvmBuilder.CreateBr (ip.getBlock ()));
2367
- return llvm::Error::success ();
2368
- };
2369
- // We have to add the cleanup to the OpenMPIRBuilder before the body gets
2370
- // created in case the body contains omp.cancel (which will then expect to be
2371
- // able to find this cleanup callback).
2372
- ompBuilder->pushFinalizationCB ({finiCB, llvm::omp::Directive::OMPD_for,
2373
- constructIsCancellable (wsloopOp)});
2402
+ pushCancelFinalizationCB (cancelTerminators, builder, *ompBuilder, wsloopOp,
2403
+ llvm::omp::Directive::OMPD_for);
2374
2404
2375
2405
llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
2376
2406
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions (
@@ -2393,18 +2423,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2393
2423
if (failed (handleError (wsloopIP, opInst)))
2394
2424
return failure ();
2395
2425
2396
- ompBuilder->popFinalizationCB ();
2397
- if (!cancelTerminators.empty ()) {
2398
- // If we cancelled the loop, we should branch to the finalization block of
2399
- // the wsloop (which is always immediately before the loop continuation
2400
- // block). Now the finalization has been created, we can fix the branch.
2401
- llvm::BasicBlock *wsloopFini = wsloopIP->getBlock ()->getSinglePredecessor ();
2402
- for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2403
- assert (cancelBranch->getNumSuccessors () == 1 &&
2404
- " cancel branch should have one target" );
2405
- cancelBranch->setSuccessor (0 , wsloopFini);
2406
- }
2407
- }
2426
+ // Set the correct branch target for task cancellation
2427
+ popCancelFinalizationCB (cancelTerminators, *ompBuilder, wsloopIP.get ());
2408
2428
2409
2429
// Process the reductions if required.
2410
2430
if (failed (createReductionsAndCleanup (
@@ -3063,9 +3083,6 @@ convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
3063
3083
llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3064
3084
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3065
3085
3066
- if (failed (checkImplementationStatus (*op.getOperation ())))
3067
- return failure ();
3068
-
3069
3086
llvm::Value *ifCond = nullptr ;
3070
3087
if (Value ifVar = op.getIfExpr ())
3071
3088
ifCond = moduleTranslation.lookupValue (ifVar);
@@ -3091,9 +3108,6 @@ convertOmpCancellationPoint(omp::CancellationPointOp op,
3091
3108
llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3092
3109
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3093
3110
3094
- if (failed (checkImplementationStatus (*op.getOperation ())))
3095
- return failure ();
3096
-
3097
3111
llvm::omp::Directive cancelledDirective =
3098
3112
convertCancellationConstructType (op.getCancelDirective ());
3099
3113
0 commit comments