@@ -161,8 +161,18 @@ static LogicalResult checkImplementationStatus(Operation &op) {
161
161
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
162
162
omp::ClauseCancellationConstructType cancelledDirective =
163
163
op.getCancelDirective ();
164
- if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
165
- result = todo (" cancel directive construct type not yet supported" );
164
+ // Cancelling a taskloop is not yet supported because we don't yet have LLVM
165
+ // IR conversion for taskloop
166
+ if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
167
+ Operation *parent = op->getParentOp ();
168
+ while (parent) {
169
+ if (parent->getDialect () == op->getDialect ())
170
+ break ;
171
+ parent = parent->getParentOp ();
172
+ }
173
+ if (isa_and_nonnull<omp::TaskloopOp>(parent))
174
+ result = todo (" cancel directive inside of taskloop" );
175
+ }
166
176
};
167
177
auto checkDepend = [&todo](auto op, LogicalResult &result) {
168
178
if (!op.getDependVars ().empty () || op.getDependKinds ())
@@ -1889,6 +1899,55 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
1889
1899
}
1890
1900
}
1891
1901
1902
+ // / Shared implementation of a callback which adds a termiator for the new block
1903
+ // / created for the branch taken when an openmp construct is cancelled. The
1904
+ // / terminator is saved in \p cancelTerminators. This callback is invoked only
1905
+ // / if there is cancellation inside of the taskgroup body.
1906
+ // / The terminator will need to be fixed to branch to the correct block to
1907
+ // / cleanup the construct.
1908
+ static void
1909
+ pushCancelFinalizationCB (SmallVectorImpl<llvm::BranchInst *> &cancelTerminators,
1910
+ llvm::IRBuilderBase &llvmBuilder,
1911
+ llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
1912
+ llvm::omp::Directive cancelDirective) {
1913
+ auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
1914
+ llvm::IRBuilderBase::InsertPointGuard guard (llvmBuilder);
1915
+
1916
+ // ip is currently in the block branched to if cancellation occured.
1917
+ // We need to create a branch to terminate that block.
1918
+ llvmBuilder.restoreIP (ip);
1919
+
1920
+ // We must still clean up the construct after cancelling it, so we need to
1921
+ // branch to the block that finalizes the taskgroup.
1922
+ // That block has not been created yet so use this block as a dummy for now
1923
+ // and fix this after creating the operation.
1924
+ cancelTerminators.push_back (llvmBuilder.CreateBr (ip.getBlock ()));
1925
+ return llvm::Error::success ();
1926
+ };
1927
+ // We have to add the cleanup to the OpenMPIRBuilder before the body gets
1928
+ // created in case the body contains omp.cancel (which will then expect to be
1929
+ // able to find this cleanup callback).
1930
+ ompBuilder.pushFinalizationCB (
1931
+ {finiCB, cancelDirective, constructIsCancellable (op)});
1932
+ }
1933
+
1934
+ // / If we cancelled the construct, we should branch to the finalization block of
1935
+ // / that construct. OMPIRBuilder structures the CFG such that the cleanup block
1936
+ // / is immediately before the continuation block. Now this finalization has
1937
+ // / been created we can fix the branch.
1938
+ static void
1939
+ popCancelFinalizationCB (const ArrayRef<llvm::BranchInst *> cancelTerminators,
1940
+ llvm::OpenMPIRBuilder &ompBuilder,
1941
+ const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
1942
+ ompBuilder.popFinalizationCB ();
1943
+ llvm::BasicBlock *constructFini = afterIP.getBlock ()->getSinglePredecessor ();
1944
+ for (llvm::BranchInst *cancelBranch : cancelTerminators) {
1945
+ assert (cancelBranch->getNumSuccessors () == 1 &&
1946
+ " cancel branch should have one target" );
1947
+ cancelBranch->setSuccessor (0 , constructFini);
1948
+ }
1949
+ }
1950
+
1892
1951
namespace {
1893
1952
// / TaskContextStructManager takes care of creating and freeing a structure
1894
1953
// / containing information needed by the task body to execute.
@@ -2202,6 +2261,14 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2202
2261
return llvm::Error::success ();
2203
2262
};
2204
2263
2264
+ llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder ();
2265
+ SmallVector<llvm::BranchInst *> cancelTerminators;
2266
+ // The directive to match here is OMPD_taskgroup because it is the taskgroup
2267
+ // which is canceled. This is handled here because it is the task's cleanup
2268
+ // block which should be branched to.
2269
+ pushCancelFinalizationCB (cancelTerminators, builder, ompBuilder, taskOp,
2270
+ llvm::omp::Directive::OMPD_taskgroup);
2271
+
2205
2272
SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
2206
2273
buildDependData (taskOp.getDependKinds (), taskOp.getDependVars (),
2207
2274
moduleTranslation, dds);
@@ -2219,6 +2286,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
2219
2286
if (failed (handleError (afterIP, *taskOp)))
2220
2287
return failure ();
2221
2288
2289
+ // Set the correct branch target for task cancellation
2290
+ popCancelFinalizationCB (cancelTerminators, ompBuilder, afterIP.get ());
2291
+
2222
2292
builder.restoreIP (*afterIP);
2223
2293
return success ();
2224
2294
}
@@ -2349,28 +2419,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2349
2419
: llvm::omp::WorksharingLoopType::ForStaticLoop;
2350
2420
2351
2421
SmallVector<llvm::BranchInst *> cancelTerminators;
2352
- // This callback is invoked only if there is cancellation inside of the wsloop
2353
- // body.
2354
- auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
2355
- llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder ;
2356
- llvm::IRBuilderBase::InsertPointGuard guard (llvmBuilder);
2357
-
2358
- // ip is currently in the block branched to if cancellation occured.
2359
- // We need to create a branch to terminate that block.
2360
- llvmBuilder.restoreIP (ip);
2361
-
2362
- // We must still clean up the wsloop after cancelling it, so we need to
2363
- // branch to the block that finalizes the wsloop.
2364
- // That block has not been created yet so use this block as a dummy for now
2365
- // and fix this after creating the wsloop.
2366
- cancelTerminators.push_back (llvmBuilder.CreateBr (ip.getBlock ()));
2367
- return llvm::Error::success ();
2368
- };
2369
- // We have to add the cleanup to the OpenMPIRBuilder before the body gets
2370
- // created in case the body contains omp.cancel (which will then expect to be
2371
- // able to find this cleanup callback).
2372
- ompBuilder->pushFinalizationCB ({finiCB, llvm::omp::Directive::OMPD_for,
2373
- constructIsCancellable (wsloopOp)});
2422
+ pushCancelFinalizationCB (cancelTerminators, builder, *ompBuilder, wsloopOp,
2423
+ llvm::omp::Directive::OMPD_for);
2374
2424
2375
2425
llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
2376
2426
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions (
@@ -2393,18 +2443,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
2393
2443
if (failed (handleError (wsloopIP, opInst)))
2394
2444
return failure ();
2395
2445
2396
- ompBuilder->popFinalizationCB ();
2397
- if (!cancelTerminators.empty ()) {
2398
- // If we cancelled the loop, we should branch to the finalization block of
2399
- // the wsloop (which is always immediately before the loop continuation
2400
- // block). Now the finalization has been created, we can fix the branch.
2401
- llvm::BasicBlock *wsloopFini = wsloopIP->getBlock ()->getSinglePredecessor ();
2402
- for (llvm::BranchInst *cancelBranch : cancelTerminators) {
2403
- assert (cancelBranch->getNumSuccessors () == 1 &&
2404
- " cancel branch should have one target" );
2405
- cancelBranch->setSuccessor (0 , wsloopFini);
2406
- }
2407
- }
2446
+ // Set the correct branch target for task cancellation
2447
+ popCancelFinalizationCB (cancelTerminators, *ompBuilder, wsloopIP.get ());
2408
2448
2409
2449
// Process the reductions if required.
2410
2450
if (failed (createReductionsAndCleanup (
@@ -3060,12 +3100,12 @@ static llvm::omp::Directive convertCancellationConstructType(
3060
3100
static LogicalResult
3061
3101
convertOmpCancel (omp::CancelOp op, llvm::IRBuilderBase &builder,
3062
3102
LLVM::ModuleTranslation &moduleTranslation) {
3063
- llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3064
- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3065
-
3066
3103
if (failed (checkImplementationStatus (*op.getOperation ())))
3067
3104
return failure ();
3068
3105
3106
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3107
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3108
+
3069
3109
llvm::Value *ifCond = nullptr ;
3070
3110
if (Value ifVar = op.getIfExpr ())
3071
3111
ifCond = moduleTranslation.lookupValue (ifVar);
@@ -3088,12 +3128,12 @@ static LogicalResult
3088
3128
convertOmpCancellationPoint (omp::CancellationPointOp op,
3089
3129
llvm::IRBuilderBase &builder,
3090
3130
LLVM::ModuleTranslation &moduleTranslation) {
3091
- llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3092
- llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3093
-
3094
3131
if (failed (checkImplementationStatus (*op.getOperation ())))
3095
3132
return failure ();
3096
3133
3134
+ llvm::OpenMPIRBuilder::LocationDescription ompLoc (builder);
3135
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder ();
3136
+
3097
3137
llvm::omp::Directive cancelledDirective =
3098
3138
convertCancellationConstructType (op.getCancelDirective ());
3099
3139
0 commit comments