Skip to content

[mlir][OpenMP] cancel(lation point) taskgroup LLVMIR #137841

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flang/docs/OpenMPSupport.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
| depend clause | P | depend clause with array sections are not supported |
| declare reduction construct | N | |
| atomic construct extensions | Y | |
| cancel construct | N | |
| cancellation point construct | N | |
| cancel construct | Y | |
| cancellation point construct | Y | |
| parallel do simd construct | P | linear clause is not supported |
| target teams construct | P | device and reduction clauses are not supported |
| teams distribute construct | P | reduction and dist_schedule clauses not supported |
Expand Down
124 changes: 82 additions & 42 deletions mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,18 @@ static LogicalResult checkImplementationStatus(Operation &op) {
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
omp::ClauseCancellationConstructType cancelledDirective =
op.getCancelDirective();
if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
result = todo("cancel directive construct type not yet supported");
// Cancelling a taskloop is not yet supported because we don't yet have LLVM
// IR conversion for taskloop
if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup) {
Operation *parent = op->getParentOp();
while (parent) {
if (parent->getDialect() == op->getDialect())
break;
parent = parent->getParentOp();
}
if (isa_and_nonnull<omp::TaskloopOp>(parent))
result = todo("cancel directive inside of taskloop");
}
};
auto checkDepend = [&todo](auto op, LogicalResult &result) {
if (!op.getDependVars().empty() || op.getDependKinds())
Expand Down Expand Up @@ -1889,6 +1899,55 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
}
}

/// Shared implementation of a callback which adds a termiator for the new block
/// created for the branch taken when an openmp construct is cancelled. The
/// terminator is saved in \p cancelTerminators. This callback is invoked only
/// if there is cancellation inside of the taskgroup body.
/// The terminator will need to be fixed to branch to the correct block to
/// cleanup the construct.
static void
pushCancelFinalizationCB(SmallVectorImpl<llvm::BranchInst *> &cancelTerminators,
llvm::IRBuilderBase &llvmBuilder,
llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
llvm::omp::Directive cancelDirective) {
auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);

// ip is currently in the block branched to if cancellation occured.
// We need to create a branch to terminate that block.
llvmBuilder.restoreIP(ip);

// We must still clean up the construct after cancelling it, so we need to
// branch to the block that finalizes the taskgroup.
// That block has not been created yet so use this block as a dummy for now
// and fix this after creating the operation.
cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
return llvm::Error::success();
};
// We have to add the cleanup to the OpenMPIRBuilder before the body gets
// created in case the body contains omp.cancel (which will then expect to be
// able to find this cleanup callback).
ompBuilder.pushFinalizationCB(
{finiCB, cancelDirective, constructIsCancellable(op)});
}

/// If we cancelled the construct, we should branch to the finalization block of
/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
/// is immediately before the continuation block. Now this finalization has
/// been created we can fix the branch.
static void
popCancelFinalizationCB(const ArrayRef<llvm::BranchInst *> cancelTerminators,
llvm::OpenMPIRBuilder &ompBuilder,
const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
ompBuilder.popFinalizationCB();
llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
for (llvm::BranchInst *cancelBranch : cancelTerminators) {
assert(cancelBranch->getNumSuccessors() == 1 &&
"cancel branch should have one target");
cancelBranch->setSuccessor(0, constructFini);
}
}

namespace {
/// TaskContextStructManager takes care of creating and freeing a structure
/// containing information needed by the task body to execute.
Expand Down Expand Up @@ -2202,6 +2261,14 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
return llvm::Error::success();
};

llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
SmallVector<llvm::BranchInst *> cancelTerminators;
// The directive to match here is OMPD_taskgroup because it is the taskgroup
// which is canceled. This is handled here because it is the task's cleanup
// block which should be branched to.
pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
llvm::omp::Directive::OMPD_taskgroup);

SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
moduleTranslation, dds);
Expand All @@ -2219,6 +2286,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
if (failed(handleError(afterIP, *taskOp)))
return failure();

// Set the correct branch target for task cancellation
popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());

builder.restoreIP(*afterIP);
return success();
}
Expand Down Expand Up @@ -2349,28 +2419,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
: llvm::omp::WorksharingLoopType::ForStaticLoop;

SmallVector<llvm::BranchInst *> cancelTerminators;
// This callback is invoked only if there is cancellation inside of the wsloop
// body.
auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder;
llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);

// ip is currently in the block branched to if cancellation occured.
// We need to create a branch to terminate that block.
llvmBuilder.restoreIP(ip);

// We must still clean up the wsloop after cancelling it, so we need to
// branch to the block that finalizes the wsloop.
// That block has not been created yet so use this block as a dummy for now
// and fix this after creating the wsloop.
cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
return llvm::Error::success();
};
// We have to add the cleanup to the OpenMPIRBuilder before the body gets
// created in case the body contains omp.cancel (which will then expect to be
// able to find this cleanup callback).
ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for,
constructIsCancellable(wsloopOp)});
pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
llvm::omp::Directive::OMPD_for);

llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
Expand All @@ -2393,18 +2443,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(wsloopIP, opInst)))
return failure();

ompBuilder->popFinalizationCB();
if (!cancelTerminators.empty()) {
// If we cancelled the loop, we should branch to the finalization block of
// the wsloop (which is always immediately before the loop continuation
// block). Now the finalization has been created, we can fix the branch.
llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor();
for (llvm::BranchInst *cancelBranch : cancelTerminators) {
assert(cancelBranch->getNumSuccessors() == 1 &&
"cancel branch should have one target");
cancelBranch->setSuccessor(0, wsloopFini);
}
}
// Set the correct branch target for task cancellation
popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());

// Process the reductions if required.
if (failed(createReductionsAndCleanup(
Expand Down Expand Up @@ -3060,12 +3100,12 @@ static llvm::omp::Directive convertCancellationConstructType(
static LogicalResult
convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

if (failed(checkImplementationStatus(*op.getOperation())))
return failure();

llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

llvm::Value *ifCond = nullptr;
if (Value ifVar = op.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
Expand All @@ -3088,12 +3128,12 @@ static LogicalResult
convertOmpCancellationPoint(omp::CancellationPointOp op,
llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

if (failed(checkImplementationStatus(*op.getOperation())))
return failure();

llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();

llvm::omp::Directive cancelledDirective =
convertCancellationConstructType(op.getCancelDirective());

Expand Down
48 changes: 48 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-cancel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,51 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
// CHECK: ret void
// CHECK: .cncl: ; preds = %[[VAL_44]]
// CHECK: br label %[[VAL_38]]

omp.private {type = firstprivate} @i32_priv : i32 copy {
^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
%0 = llvm.load %arg0 : !llvm.ptr -> i32
llvm.store %0, %arg1 : i32, !llvm.ptr
omp.yield(%arg1 : !llvm.ptr)
}

llvm.func @do_something(!llvm.ptr)

llvm.func @cancel_taskgroup(%arg0: !llvm.ptr) {
omp.taskgroup {
// Using firstprivate clause so we have some end of task cleanup to branch to
// after the cancellation.
omp.task private(@i32_priv %arg0 -> %arg1 : !llvm.ptr) {
omp.cancel cancellation_construct_type(taskgroup)
llvm.call @do_something(%arg1) : (!llvm.ptr) -> ()
omp.terminator
}
omp.terminator
}
llvm.return
}
// CHECK-LABEL: define internal void @cancel_taskgroup..omp_par(
// CHECK: task.alloca:
// CHECK: %[[VAL_21:.*]] = load ptr, ptr %[[VAL_22:.*]], align 8
// CHECK: %[[VAL_23:.*]] = getelementptr { ptr }, ptr %[[VAL_21]], i32 0, i32 0
// CHECK: %[[VAL_24:.*]] = load ptr, ptr %[[VAL_23]], align 8, !align !1
// CHECK: br label %[[VAL_25:.*]]
// CHECK: task.body: ; preds = %[[VAL_26:.*]]
// CHECK: %[[VAL_27:.*]] = getelementptr { i32 }, ptr %[[VAL_24]], i32 0, i32 0
// CHECK: br label %[[VAL_28:.*]]
// CHECK: omp.task.region: ; preds = %[[VAL_25]]
// CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: %[[VAL_30:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_29]], i32 4)
// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0
// CHECK: br i1 %[[VAL_31]], label %omp.task.region.split, label %omp.task.region.cncl
// CHECK: omp.task.region.cncl:
// CHECK: br label %omp.region.cont2
// CHECK: omp.region.cont2:
// Both cancellation and normal paths reach the end-of-task cleanup:
// CHECK: tail call void @free(ptr %[[VAL_24]])
// CHECK: br label %task.exit.exitStub
// CHECK: omp.task.region.split:
// CHECK: call void @do_something(ptr %[[VAL_27]])
// CHECK: br label %omp.region.cont2
// CHECK: task.exit.exitStub:
// CHECK: ret void
30 changes: 30 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,33 @@ llvm.func @cancellation_point_wsloop(%lb : i32, %ub : i32, %step : i32) {
// CHECK: ret void
// CHECK: omp.loop_nest.region.cncl: ; preds = %[[VAL_100]]
// CHECK: br label %[[VAL_96]]


llvm.func @cancellation_point_taskgroup() {
omp.taskgroup {
omp.task {
omp.cancellation_point cancellation_construct_type(taskgroup)
omp.terminator
}
omp.terminator
}
llvm.return
}
// CHECK-LABEL: define internal void @cancellation_point_taskgroup..omp_par(
// CHECK: task.alloca:
// CHECK: br label %[[VAL_50:.*]]
// CHECK: task.body: ; preds = %[[VAL_51:.*]]
// CHECK: br label %[[VAL_52:.*]]
// CHECK: omp.task.region: ; preds = %[[VAL_50]]
// CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: %[[VAL_54:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[VAL_53]], i32 4)
// CHECK: %[[VAL_55:.*]] = icmp eq i32 %[[VAL_54]], 0
// CHECK: br i1 %[[VAL_55]], label %omp.task.region.split, label %omp.task.region.cncl
// CHECK: omp.task.region.cncl:
// CHECK: br label %omp.region.cont1
// CHECK: omp.region.cont1:
// CHECK: br label %task.exit.exitStub
// CHECK: omp.task.region.split:
// CHECK: br label %omp.region.cont1
// CHECK: task.exit.exitStub:
// CHECK: ret void
34 changes: 0 additions & 34 deletions mlir/test/Target/LLVMIR/openmp-todo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,40 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {

// -----

llvm.func @cancel_taskgroup() {
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
omp.taskgroup {
// expected-error@below {{LLVM Translation failed for operation: omp.task}}
omp.task {
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancel operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
omp.cancel cancellation_construct_type(taskgroup)
omp.terminator
}
omp.terminator
}
llvm.return
}

// -----

llvm.func @cancellation_point_taskgroup() {
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
omp.taskgroup {
// expected-error@below {{LLVM Translation failed for operation: omp.task}}
omp.task {
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancellation_point operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.cancellation_point}}
omp.cancellation_point cancellation_construct_type(taskgroup)
omp.terminator
}
omp.terminator
}
llvm.return
}

// -----

llvm.func @do_simd(%lb : i32, %ub : i32, %step : i32) {
omp.wsloop {
// expected-warning@below {{simd information on composite construct discarded}}
Expand Down
Loading