Skip to content

[mlir][openmp] - Add the depend clause to omp.target and related offloading directives #81081

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions flang/lib/Lower/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2765,7 +2765,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
directive);

return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
deviceOperand, nowaitAttr, mapOperands);
deviceOperand, nullptr, mlir::ValueRange(),
nowaitAttr, mapOperands);
}

// This functions creates a block for the body of the targetOp's region. It adds
Expand Down Expand Up @@ -3029,7 +3030,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,

auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
nowaitAttr, mapOperands);
nullptr, mlir::ValueRange(), nowaitAttr, mapOperands);

genBodyOfTargetOp(converter, eval, genNested, targetOp, mapSymTypes,
mapSymLocs, mapSymbols, currentLocation);
Expand Down
37 changes: 32 additions & 5 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>;

def ClauseTaskDepend : I32EnumAttr<
"ClauseTaskDepend",
"task depend clause",
"depend clause in a target or task construct",
[ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::omp";
Expand Down Expand Up @@ -1359,11 +1359,17 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",

The $map_types specifies the types and modifiers for the map clause.

TODO: depend clause and map_type_modifier values iterator and mapper.
The `depends` and `depend_vars` arguments are variadic lists of values
that specify the dependencies of this particular target task in relation to
other tasks.

TODO: map_type_modifier values iterator and mapper.
}];

let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
OptionalAttr<TaskDependArrayAttr>:$depends,
Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);

Expand All @@ -1372,6 +1378,7 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `map_entries` `(` $map_operands `:` type($map_operands) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) attr-dict
}];

Expand Down Expand Up @@ -1406,11 +1413,17 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",

The $map_types specifies the types and modifiers for the map clause.

TODO: depend clause and map_type_modifier values iterator and mapper.
The `depends` and `depend_vars` arguments are variadic lists of values
that specify the dependencies of this particular target task in relation to
other tasks.

TODO: map_type_modifier values iterator and mapper.
}];

let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
OptionalAttr<TaskDependArrayAttr>:$depends,
Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);

Expand All @@ -1419,6 +1432,7 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `map_entries` `(` $map_operands `:` type($map_operands) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) attr-dict
}];

Expand Down Expand Up @@ -1457,11 +1471,16 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
during verification to make sure the restrictions for target update are
respected.

TODO: depend clause
The `depends` and `depend_vars` arguments are variadic lists of values
that specify the dependencies of this particular target task in relation to
other tasks.

}];

let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
OptionalAttr<TaskDependArrayAttr>:$depends,
Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<OpenMP_PointerLikeType>:$map_operands);

Expand All @@ -1470,6 +1489,7 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
| `device` `(` $device `:` type($device) `)`
| `nowait` $nowait
| `motion_entries` `(` $map_operands `:` type($map_operands) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) attr-dict
}];

Expand Down Expand Up @@ -1499,13 +1519,19 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
The optional $nowait elliminates the implicit barrier so the parent task can make progress
even if the target task is not yet completed.

TODO: is_device_ptr, depend, defaultmap, in_reduction
The `depends` and `depend_vars` arguments are variadic lists of values
that specify the dependencies of this particular target task in relation to
other tasks.

TODO: is_device_ptr, defaultmap, in_reduction

}];

let arguments = (ins Optional<I1>:$if_expr,
Optional<AnyInteger>:$device,
Optional<AnyInteger>:$thread_limit,
OptionalAttr<TaskDependArrayAttr>:$depends,
Variadic<OpenMP_PointerLikeType>:$depend_vars,
UnitAttr:$nowait,
Variadic<AnyType>:$map_operands);

Expand All @@ -1517,6 +1543,7 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
| `nowait` $nowait
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
) $region attr-dict
}];

Expand Down
22 changes: 17 additions & 5 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ static LogicalResult verifyDependVarList(Operation *op,
return op->emitOpError() << "expected as many depend values"
" as depend variables";
} else {
if (depends)
if (depends && !depends->empty())
return op->emitOpError() << "unexpected depend values";
return success();
}
Expand Down Expand Up @@ -965,19 +965,31 @@ LogicalResult DataOp::verify() {
}

LogicalResult EnterDataOp::verify() {
return verifyMapClause(*this, getMapOperands());
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
return failed(verifyDependVars) ? verifyDependVars
: verifyMapClause(*this, getMapOperands());
}

LogicalResult ExitDataOp::verify() {
return verifyMapClause(*this, getMapOperands());
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
return failed(verifyDependVars) ? verifyDependVars
: verifyMapClause(*this, getMapOperands());
}

LogicalResult UpdateDataOp::verify() {
return verifyMapClause(*this, getMapOperands());
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
return failed(verifyDependVars) ? verifyDependVars
: verifyMapClause(*this, getMapOperands());
}

LogicalResult TargetOp::verify() {
return verifyMapClause(*this, getMapOperands());
LogicalResult verifyDependVars =
verifyDependVarList(*this, getDepends(), getDependVars());
return failed(verifyDependVars) ? verifyDependVars
: verifyMapClause(*this, getMapOperands());
}

//===----------------------------------------------------------------------===//
Expand Down
37 changes: 37 additions & 0 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1651,6 +1651,15 @@ func.func @omp_target_enter_data(%map1: memref<?xi32>) {

// -----

func.func @omp_target_enter_data_depend(%a: memref<?xi32>) {
%0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
// expected-error @below {{op expected as many depend values as depend variables}}
omp.target_enter_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
return
}

// -----

func.func @omp_target_exit_data(%map1: memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
// expected-error @below {{from, release and delete map types are permitted}}
Expand All @@ -1660,6 +1669,15 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {

// -----

func.func @omp_target_exit_data_depend(%a: memref<?xi32>) {
%0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
// expected-error @below {{op expected as many depend values as depend variables}}
omp.target_exit_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
return
}

// -----

func.func @omp_target_update_invalid_motion_type(%map1 : memref<?xi32>) {
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}

Expand Down Expand Up @@ -1732,6 +1750,25 @@ llvm.mlir.global internal @_QFsubEx() : i32

// -----

func.func @omp_target_update_data_depend(%a: memref<?xi32>) {
%0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
// expected-error @below {{op expected as many depend values as depend variables}}
omp.target_update_data motion_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
return
}

// -----

func.func @omp_target_depend(%data_var: memref<i32>) {
// expected-error @below {{op expected as many depend values as depend variables}}
"omp.target"(%data_var) ({
"omp.terminator"() : () -> ()
}) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0>} : (memref<i32>) -> ()
"func.return"() : () -> ()
}

// -----

func.func @omp_distribute(%data_var : memref<i32>) -> () {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.distribute"(%data_var) <{operandSegmentSizes = array<i32: 0, 1, 0>}> ({
Expand Down
63 changes: 62 additions & 1 deletion mlir/test/Dialect/OpenMP/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1:
"omp.target"(%if_cond, %device, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0>} : ( i1, si32, i32 ) -> ()
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0>} : ( i1, si32, i32 ) -> ()

// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
Expand Down Expand Up @@ -1710,6 +1710,18 @@ func.func @omp_task_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
return
}


// CHECK-LABEL: @omp_target_depend
// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
// CHECK: omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
// CHECK: omp.terminator
omp.terminator
} {operandSegmentSizes = array<i32: 0,0,0,3,0>}
return
}

func.func @omp_threadprivate() {
%0 = arith.constant 1 : i32
%1 = arith.constant 2 : i32
Expand Down Expand Up @@ -2138,3 +2150,52 @@ func.func @omp_targets_is_allocatable(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
}
return
}

// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
// CHECK-NEXT: [[MAP0:%.*]] = omp.map_info
// CHECK-NEXT: [[MAP1:%.*]] = omp.map_info
// CHECK-NEXT: [[MAP2:%.*]] = omp.map_info
%map_a = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
%map_b = omp.map_info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
%map_c = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>

// Do some work on the host that writes to 'a'
omp.task depend(taskdependout -> %a : memref<?xi32>) {
"test.foo"(%a) : (memref<?xi32>) -> ()
omp.terminator
}

// Then map that over to the target
// CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin -> %a: memref<?xi32>)

// Compute 'b' on the target and copy it back
// CHECK: omp.target map_entries([[MAP1]] -> {{%.*}} : memref<?xi32>) {
omp.target map_entries(%map_b -> %arg0 : memref<?xi32>) {
^bb0(%arg0: memref<?xi32>) :
"test.foo"(%arg0) : (memref<?xi32>) -> ()
omp.terminator
}

// Update 'a' on the host using 'b'
omp.task depend(taskdependout -> %a: memref<?xi32>){
"test.bar"(%a, %b) : (memref<?xi32>, memref<?xi32>) -> ()
}

// Copy the updated 'a' onto the target
// CHECK: omp.target_update_data nowait motion_entries([[MAP0]] : memref<?xi32>) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
omp.target_update_data motion_entries(%map_a : memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait

// Compute 'c' on the target and copy it back
%map_c_from = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
"test.foobar"() : ()->()
omp.terminator
}
// CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>) depend(taskdependin -> [[ARG2]] : memref<?xi32>)
omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
return
}