Skip to content

[mlir][OpenMP] convert wsloop cancellation to LLVMIR #137194

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
omp::ClauseCancellationConstructType cancelledDirective =
op.getCancelDirective();
if (cancelledDirective != omp::ClauseCancellationConstructType::Parallel &&
cancelledDirective != omp::ClauseCancellationConstructType::Sections)
if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
result = todo("cancel directive construct type not yet supported");
};
auto checkDepend = [&todo](auto op, LogicalResult &result) {
Expand Down Expand Up @@ -2358,6 +2357,30 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
? llvm::omp::WorksharingLoopType::DistributeForStaticLoop
: llvm::omp::WorksharingLoopType::ForStaticLoop;

SmallVector<llvm::BranchInst *> cancelTerminators;
// This callback is invoked only if there is cancellation inside of the wsloop
// body.
auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder;
llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);

// ip is currently in the block branched to if cancellation occured.
// We need to create a branch to terminate that block.
llvmBuilder.restoreIP(ip);

// We must still clean up the wsloop after cancelling it, so we need to
// branch to the block that finalizes the wsloop.
// That block has not been created yet so use this block as a dummy for now
// and fix this after creating the wsloop.
cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
return llvm::Error::success();
};
// We have to add the cleanup to the OpenMPIRBuilder before the body gets
// created in case the body contains omp.cancel (which will then expect to be
// able to find this cleanup callback).
ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for,
constructIsCancellable(wsloopOp)});

llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation);
Expand All @@ -2379,6 +2402,19 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(wsloopIP, opInst)))
return failure();

ompBuilder->popFinalizationCB();
if (!cancelTerminators.empty()) {
// If we cancelled the loop, we should branch to the finalization block of
// the wsloop (which is always immediately before the loop continuation
// block). Now the finalization has been created, we can fix the branch.
llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor();
for (llvm::BranchInst *cancelBranch : cancelTerminators) {
assert(cancelBranch->getNumSuccessors() == 1 &&
"cancel branch should have one target");
cancelBranch->setSuccessor(0, wsloopFini);
}
}

// Process the reductions if required.
if (failed(createReductionsAndCleanup(
wsloopOp, builder, moduleTranslation, allocaIP, reductionDecls,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

llvm.func @cancel_distribute_parallel_do(%lb : i32, %ub : i32, %step : i32) {
omp.teams {
omp.parallel {
omp.distribute {
omp.wsloop {
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
omp.cancel cancellation_construct_type(loop)
omp.yield
}
} {omp.composite}
} {omp.composite}
omp.terminator
} {omp.composite}
omp.terminator
}
llvm.return
}
// CHECK-LABEL: define internal void @cancel_distribute_parallel_do..omp_par
// [...]
// CHECK: omp_loop.cond:
// CHECK: %[[VAL_102:.*]] = icmp ult i32 %{{.*}}, %{{.*}}
// CHECK: br i1 %[[VAL_102]], label %omp_loop.body, label %omp_loop.exit
// CHECK: omp_loop.exit:
// CHECK: call void @__kmpc_for_static_fini(
// CHECK: %[[VAL_106:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_106]])
// CHECK: br label %omp_loop.after
// CHECK: omp_loop.after:
// CHECK: br label %omp.region.cont6
// CHECK: omp.region.cont6:
// CHECK: br label %omp.region.cont4
// CHECK: omp.region.cont4:
// CHECK: br label %distribute.exit.exitStub
// CHECK: omp_loop.body:
// CHECK: %[[VAL_111:.*]] = add i32 %{{.*}}, %{{.*}}
// CHECK: %[[VAL_112:.*]] = mul i32 %[[VAL_111]], %{{.*}}
// CHECK: %[[VAL_113:.*]] = add i32 %[[VAL_112]], %{{.*}}
// CHECK: br label %omp.loop_nest.region
// CHECK: omp.loop_nest.region:
// CHECK: %[[VAL_115:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: %[[VAL_116:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_115]], i32 2)
// CHECK: %[[VAL_117:.*]] = icmp eq i32 %[[VAL_116]], 0
// CHECK: br i1 %[[VAL_117]], label %omp.loop_nest.region.split, label %omp.loop_nest.region.cncl
// CHECK: omp.loop_nest.region.cncl:
// CHECK: br label %omp_loop.exit
// CHECK: omp.loop_nest.region.split:
// CHECK: br label %omp.region.cont7
// CHECK: omp.region.cont7:
// CHECK: br label %omp_loop.inc
// CHECK: omp_loop.inc:
// CHECK: %[[VAL_100:.*]] = add nuw i32 %{{.*}}, 1
// CHECK: br label %omp_loop.header
// CHECK: distribute.exit.exitStub:
// CHECK: ret void

87 changes: 87 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-cancel.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,90 @@ llvm.func @cancel_sections_if(%cond : i1) {
// CHECK: ret void
// CHECK: .cncl: ; preds = %[[VAL_27]]
// CHECK: br label %[[VAL_19]]

llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
omp.wsloop {
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
omp.cancel cancellation_construct_type(loop) if(%cond)
omp.yield
}
}
llvm.return
}
// CHECK-LABEL: define void @cancel_wsloop_if
// CHECK: %[[VAL_0:.*]] = alloca i32, align 4
// CHECK: %[[VAL_1:.*]] = alloca i32, align 4
// CHECK: %[[VAL_2:.*]] = alloca i32, align 4
// CHECK: %[[VAL_3:.*]] = alloca i32, align 4
// CHECK: br label %[[VAL_4:.*]]
// CHECK: omp.region.after_alloca: ; preds = %[[VAL_5:.*]]
// CHECK: br label %[[VAL_6:.*]]
// CHECK: entry: ; preds = %[[VAL_4]]
// CHECK: br label %[[VAL_7:.*]]
// CHECK: omp.wsloop.region: ; preds = %[[VAL_6]]
// CHECK: %[[VAL_8:.*]] = icmp slt i32 %[[VAL_9:.*]], 0
// CHECK: %[[VAL_10:.*]] = sub i32 0, %[[VAL_9]]
// CHECK: %[[VAL_11:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_10]], i32 %[[VAL_9]]
// CHECK: %[[VAL_12:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_13:.*]], i32 %[[VAL_14:.*]]
// CHECK: %[[VAL_15:.*]] = select i1 %[[VAL_8]], i32 %[[VAL_14]], i32 %[[VAL_13]]
// CHECK: %[[VAL_16:.*]] = sub nsw i32 %[[VAL_15]], %[[VAL_12]]
// CHECK: %[[VAL_17:.*]] = icmp sle i32 %[[VAL_15]], %[[VAL_12]]
// CHECK: %[[VAL_18:.*]] = sub i32 %[[VAL_16]], 1
// CHECK: %[[VAL_19:.*]] = udiv i32 %[[VAL_18]], %[[VAL_11]]
// CHECK: %[[VAL_20:.*]] = add i32 %[[VAL_19]], 1
// CHECK: %[[VAL_21:.*]] = icmp ule i32 %[[VAL_16]], %[[VAL_11]]
// CHECK: %[[VAL_22:.*]] = select i1 %[[VAL_21]], i32 1, i32 %[[VAL_20]]
// CHECK: %[[VAL_23:.*]] = select i1 %[[VAL_17]], i32 0, i32 %[[VAL_22]]
// CHECK: br label %[[VAL_24:.*]]
// CHECK: omp_loop.preheader: ; preds = %[[VAL_7]]
// CHECK: store i32 0, ptr %[[VAL_1]], align 4
// CHECK: %[[VAL_25:.*]] = sub i32 %[[VAL_23]], 1
// CHECK: store i32 %[[VAL_25]], ptr %[[VAL_2]], align 4
// CHECK: store i32 1, ptr %[[VAL_3]], align 4
// CHECK: %[[VAL_26:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_26]], i32 34, ptr %[[VAL_0]], ptr %[[VAL_1]], ptr %[[VAL_2]], ptr %[[VAL_3]], i32 1, i32 0)
// CHECK: %[[VAL_27:.*]] = load i32, ptr %[[VAL_1]], align 4
// CHECK: %[[VAL_28:.*]] = load i32, ptr %[[VAL_2]], align 4
// CHECK: %[[VAL_29:.*]] = sub i32 %[[VAL_28]], %[[VAL_27]]
// CHECK: %[[VAL_30:.*]] = add i32 %[[VAL_29]], 1
// CHECK: br label %[[VAL_31:.*]]
// CHECK: omp_loop.header: ; preds = %[[VAL_32:.*]], %[[VAL_24]]
// CHECK: %[[VAL_33:.*]] = phi i32 [ 0, %[[VAL_24]] ], [ %[[VAL_34:.*]], %[[VAL_32]] ]
// CHECK: br label %[[VAL_35:.*]]
// CHECK: omp_loop.cond: ; preds = %[[VAL_31]]
// CHECK: %[[VAL_36:.*]] = icmp ult i32 %[[VAL_33]], %[[VAL_30]]
// CHECK: br i1 %[[VAL_36]], label %[[VAL_37:.*]], label %[[VAL_38:.*]]
// CHECK: omp_loop.body: ; preds = %[[VAL_35]]
// CHECK: %[[VAL_39:.*]] = add i32 %[[VAL_33]], %[[VAL_27]]
// CHECK: %[[VAL_40:.*]] = mul i32 %[[VAL_39]], %[[VAL_9]]
// CHECK: %[[VAL_41:.*]] = add i32 %[[VAL_40]], %[[VAL_14]]
// CHECK: br label %[[VAL_42:.*]]
// CHECK: omp.loop_nest.region: ; preds = %[[VAL_37]]
// CHECK: br i1 %[[VAL_43:.*]], label %[[VAL_44:.*]], label %[[VAL_45:.*]]
// CHECK: 25: ; preds = %[[VAL_42]]
// CHECK: %[[VAL_46:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: %[[VAL_47:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_46]], i32 2)
// CHECK: %[[VAL_48:.*]] = icmp eq i32 %[[VAL_47]], 0
// CHECK: br i1 %[[VAL_48]], label %[[VAL_49:.*]], label %[[VAL_50:.*]]
// CHECK: .split: ; preds = %[[VAL_44]]
// CHECK: br label %[[VAL_51:.*]]
// CHECK: 28: ; preds = %[[VAL_42]]
// CHECK: br label %[[VAL_51]]
// CHECK: 29: ; preds = %[[VAL_45]], %[[VAL_49]]
// CHECK: br label %[[VAL_52:.*]]
// CHECK: omp.region.cont1: ; preds = %[[VAL_51]]
// CHECK: br label %[[VAL_32]]
// CHECK: omp_loop.inc: ; preds = %[[VAL_52]]
// CHECK: %[[VAL_34]] = add nuw i32 %[[VAL_33]], 1
// CHECK: br label %[[VAL_31]]
// CHECK: omp_loop.exit: ; preds = %[[VAL_50]], %[[VAL_35]]
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_26]])
// CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_53]])
// CHECK: br label %[[VAL_54:.*]]
// CHECK: omp_loop.after: ; preds = %[[VAL_38]]
// CHECK: br label %[[VAL_55:.*]]
// CHECK: omp.region.cont: ; preds = %[[VAL_54]]
// CHECK: ret void
// CHECK: .cncl: ; preds = %[[VAL_44]]
// CHECK: br label %[[VAL_38]]
16 changes: 0 additions & 16 deletions mlir/test/Target/LLVMIR/openmp-todo.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {

// -----

llvm.func @cancel_wsloop(%lb : i32, %ub : i32, %step: i32) {
// expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
omp.wsloop {
// expected-error@below {{LLVM Translation failed for operation: omp.loop_nest}}
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
// expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancel operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
omp.cancel cancellation_construct_type(loop)
omp.yield
}
}
llvm.return
}

// -----

llvm.func @cancel_taskgroup() {
// expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
omp.taskgroup {
Expand Down