Skip to content

[mlir][llvm] Translation support for task detach #116601

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -1262,12 +1262,15 @@ class OpenMPIRBuilder {
/// cannot be resumed until execution of the structured
/// block that is associated with the generated task is
/// completed.
/// \param EventHandle If present, signifies the event handle as part of
/// the detach clause
/// \param Mergeable If the given task is `mergeable`
InsertPointOrErrorTy
createTask(const LocationDescription &Loc, InsertPointTy AllocaIP,
BodyGenCallbackTy BodyGenCB, bool Tied = true,
Value *Final = nullptr, Value *IfCondition = nullptr,
SmallVector<DependData> Dependencies = {}, bool Mergeable = false);
SmallVector<DependData> Dependencies = {}, bool Mergeable = false,
Value *EventHandle = nullptr);

/// Generator for the taskgroup construct
///
Expand Down
18 changes: 16 additions & 2 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1818,7 +1818,7 @@ static Value *emitTaskDependencies(
OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
const LocationDescription &Loc, InsertPointTy AllocaIP,
BodyGenCallbackTy BodyGenCB, bool Tied, Value *Final, Value *IfCondition,
SmallVector<DependData> Dependencies, bool Mergeable) {
SmallVector<DependData> Dependencies, bool Mergeable, Value *EventHandle) {

if (!updateToLocation(Loc))
return InsertPointTy();
Expand Down Expand Up @@ -1864,7 +1864,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
Builder, AllocaIP, ToBeDeleted, TaskAllocaIP, "global.tid", false));

OI.PostOutlineCB = [this, Ident, Tied, Final, IfCondition, Dependencies,
Mergeable, TaskAllocaBB,
Mergeable, EventHandle, TaskAllocaBB,
ToBeDeleted](Function &OutlinedFn) mutable {
// Replace the Stale CI by appropriate RTL function call.
assert(OutlinedFn.getNumUses() == 1 &&
Expand Down Expand Up @@ -1935,6 +1935,20 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTask(
/*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
/*task_func=*/&OutlinedFn});

// Emit detach clause initialization.
// evt = (typeof(evt))__kmpc_task_allow_completion_event(loc, tid,
// task_descriptor);
if (EventHandle) {
Function *TaskDetachFn = getOrCreateRuntimeFunctionPtr(
OMPRTL___kmpc_task_allow_completion_event);
llvm::Value *EventVal =
Builder.CreateCall(TaskDetachFn, {Ident, ThreadID, TaskData});
llvm::Value *EventHandleAddr =
Builder.CreatePointerBitCastOrAddrSpaceCast(EventHandle,
Builder.getPtrTy(0));
EventVal = Builder.CreatePtrToInt(EventVal, Builder.getInt64Ty());
Builder.CreateStore(EventVal, EventHandleAddr);
}
// Copy the arguments for outlined function
if (HasShareds) {
Value *Shareds = StaleCI->getArgOperand(1);
Expand Down
28 changes: 28 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,34 @@ class OpenMP_ParallelizationLevelClauseSkip<

def OpenMP_ParallelizationLevelClause : OpenMP_ParallelizationLevelClauseSkip<>;

//===----------------------------------------------------------------------===//
// OpenMPV5.2: [12.5.2] `detach` clause
//===----------------------------------------------------------------------===//

class OpenMP_DetachClauseSkip<
bit traits = false, bit arguments = false, bit assemblyFormat = false,
bit description = false, bit extraClassDeclaration = false>
: OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {

let traits = [BlockArgOpenMPOpInterface];

let arguments = (ins Optional<OpenMP_PointerLikeType>:$event_handle);

let optAssemblyFormat = [{
`detach` `(` $event_handle `:` type($event_handle) `)`
}];

let description = [{
The detach clause specifies that the task generated by the construct on which it appears is a
detachable task. A new allow-completion event is created and connected to the completion of the
associated task region. The original event-handle is updated to represent that allow-completion
event before the task data environment is created.
}];
}

def OpenMP_DetachClause : OpenMP_DetachClauseSkip<>;

//===----------------------------------------------------------------------===//
// V5.2: [12.4] `priority` clause
//===----------------------------------------------------------------------===//
Expand Down
21 changes: 12 additions & 9 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -657,15 +657,18 @@ def DistributeOp : OpenMP_Op<"distribute", traits = [
// 2.10.1 task Construct
//===----------------------------------------------------------------------===//

def TaskOp : OpenMP_Op<"task", traits = [
AttrSizedOperandSegments, AutomaticAllocationScope,
OutlineableOpenMPOpInterface
], clauses = [
// TODO: Complete clause list (affinity, detach).
OpenMP_AllocateClause, OpenMP_DependClause, OpenMP_FinalClause,
OpenMP_IfClause, OpenMP_InReductionClause, OpenMP_MergeableClause,
OpenMP_PriorityClause, OpenMP_PrivateClause, OpenMP_UntiedClause
], singleRegion = true> {
def TaskOp
: OpenMP_Op<"task",
traits = [AttrSizedOperandSegments, AutomaticAllocationScope,
OutlineableOpenMPOpInterface],
clauses = [
// TODO: Complete clause list (affinity, detach).
OpenMP_AllocateClause, OpenMP_DependClause,
OpenMP_FinalClause, OpenMP_IfClause,
OpenMP_InReductionClause, OpenMP_MergeableClause,
OpenMP_PriorityClause, OpenMP_PrivateClause,
OpenMP_UntiedClause, OpenMP_DetachClause],
singleRegion = true> {
let summary = "task construct";
let description = [{
The task construct defines an explicit task.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2303,7 +2303,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
clauses.priority, /*private_vars=*/clauses.privateVars,
/*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
clauses.untied);
clauses.untied, clauses.eventHandle);
}

LogicalResult TaskOp::verify() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1703,7 +1703,8 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
ompLoc, allocaIP, bodyCB, !taskOp.getUntied(),
moduleTranslation.lookupValue(taskOp.getFinal()),
moduleTranslation.lookupValue(taskOp.getIfExpr()), dds,
taskOp.getMergeable());
taskOp.getMergeable(),
moduleTranslation.lookupValue(taskOp.getEventHandle()));

if (failed(handleError(afterIP, *taskOp)))
return failure();
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1634,7 +1634,7 @@ func.func @omp_single_copyprivate(%data_var : memref<i32>) -> () {
// -----

func.func @omp_task_depend(%data_var: memref<i32>) {
// expected-error @below {{op expected as many depend values as depend variables}}
// expected-error @below {{'omp.task' op operand count (1) does not match with the total size (0) specified in attribute 'operandSegmentSizes'}}
"omp.task"(%data_var) ({
"omp.terminator"() : () -> ()
}) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
Expand Down
10 changes: 7 additions & 3 deletions mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1975,8 +1975,8 @@ func.func @omp_single_copyprivate(%data_var: memref<i32>) {
}

// CHECK-LABEL: @omp_task
// CHECK-SAME: (%[[bool_var:.*]]: i1, %[[i64_var:.*]]: i64, %[[i32_var:.*]]: i32, %[[data_var:.*]]: memref<i32>)
func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memref<i32>) {
// CHECK-SAME: (%[[bool_var:.*]]: i1, %[[i64_var:.*]]: i64, %[[i32_var:.*]]: i32, %[[data_var:.*]]: memref<i32>, %[[event_handle:.*]]: !llvm.ptr)
func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memref<i32>, %event_handle : !llvm.ptr) {

// Checking simple task
// CHECK: omp.task {
Expand Down Expand Up @@ -2054,7 +2054,11 @@ func.func @omp_task(%bool_var: i1, %i64_var: i64, %i32_var: i32, %data_var: memr
// CHECK: omp.terminator
omp.terminator
}

// Checking detach clause
// CHECK: omp.task detach(%[[event_handle]] : !llvm.ptr)
omp.task detach(%event_handle : !llvm.ptr){
omp.terminator
}
// Checking multiple clauses
// CHECK: omp.task allocate(%[[data_var]] : memref<i32> -> %[[data_var]] : memref<i32>)
omp.task allocate(%data_var : memref<i32> -> %data_var : memref<i32>)
Expand Down
17 changes: 17 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2505,6 +2505,23 @@ llvm.mlir.global internal @_QFsubEx() : i32

// -----

// CHECK-LABEL: define void @omp_task_detach
// CHECK-SAME: (ptr %[[event_handle:.*]])
llvm.func @omp_task_detach(%event_handle : !llvm.ptr){
// CHECK: %[[omp_global_thread_num:.+]] = call i32 @__kmpc_global_thread_num({{.+}})
// CHECK: %[[task_data:.+]] = call ptr @__kmpc_omp_task_alloc
// CHECK: %[[return_val:.*]] = call ptr @__kmpc_task_allow_completion_event(ptr {{.*}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
// CHECK: %[[conv:.*]] = ptrtoint ptr %[[return_val]] to i64
// CHECK: store i64 %[[conv]], ptr %[[event_handle]], align 4
// CHECK: call i32 @__kmpc_omp_task(ptr @{{.+}}, i32 %[[omp_global_thread_num]], ptr %[[task_data]])
omp.task detach(%event_handle : !llvm.ptr){
omp.terminator
}
llvm.return
}

// -----

// CHECK-LABEL: define void @omp_task
// CHECK-SAME: (i32 %[[x:.+]], i32 %[[y:.+]], ptr %[[zaddr:.+]])
llvm.func @omp_task(%x: i32, %y: i32, %zaddr: !llvm.ptr) {
Expand Down
Loading