Skip to content

Commit 55d6643

Browse files
[mlir][openmp] - Add the depend clause to omp.target and related offloading directives (#81081)
This patch adds support for the depend clause in a number of OpenMP directives/constructs related to offloading. Specifically, it adds the handling of the depend clause when it is used with the following constructs - target - target enter data - target update data - target exit data
1 parent 8456e0c commit 55d6643

File tree

5 files changed

+151
-13
lines changed

5 files changed

+151
-13
lines changed

flang/lib/Lower/OpenMP.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2825,7 +2825,8 @@ genEnterExitUpdateDataOp(Fortran::lower::AbstractConverter &converter,
28252825
directive);
28262826

28272827
return firOpBuilder.create<OpTy>(currentLocation, ifClauseOperand,
2828-
deviceOperand, nowaitAttr, mapOperands);
2828+
deviceOperand, nullptr, mlir::ValueRange(),
2829+
nowaitAttr, mapOperands);
28292830
}
28302831

28312832
// This functions creates a block for the body of the targetOp's region. It adds
@@ -3090,7 +3091,7 @@ genTargetOp(Fortran::lower::AbstractConverter &converter,
30903091

30913092
auto targetOp = converter.getFirOpBuilder().create<mlir::omp::TargetOp>(
30923093
currentLocation, ifClauseOperand, deviceOperand, threadLimitOperand,
3093-
nowaitAttr, mapOperands);
3094+
nullptr, mlir::ValueRange(), nowaitAttr, mapOperands);
30943095

30953096
genBodyOfTargetOp(converter, semaCtx, eval, genNested, targetOp, mapSymTypes,
30963097
mapSymLocs, mapSymbols, currentLocation);

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -781,7 +781,7 @@ def ClauseTaskDependInOut : I32EnumAttrCase<"taskdependinout", 2>;
781781

782782
def ClauseTaskDepend : I32EnumAttr<
783783
"ClauseTaskDepend",
784-
"task depend clause",
784+
"depend clause in a target or task construct",
785785
[ClauseTaskDependIn, ClauseTaskDependOut, ClauseTaskDependInOut]> {
786786
let genSpecializedAttr = 0;
787787
let cppNamespace = "::mlir::omp";
@@ -1447,11 +1447,17 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
14471447

14481448
The $map_types specifies the types and modifiers for the map clause.
14491449

1450-
TODO: depend clause and map_type_modifier values iterator and mapper.
1450+
The `depends` and `depend_vars` arguments are variadic lists of values
1451+
that specify the dependencies of this particular target task in relation to
1452+
other tasks.
1453+
1454+
TODO: map_type_modifier values iterator and mapper.
14511455
}];
14521456

14531457
let arguments = (ins Optional<I1>:$if_expr,
14541458
Optional<AnyInteger>:$device,
1459+
OptionalAttr<TaskDependArrayAttr>:$depends,
1460+
Variadic<OpenMP_PointerLikeType>:$depend_vars,
14551461
UnitAttr:$nowait,
14561462
Variadic<AnyType>:$map_operands);
14571463

@@ -1460,6 +1466,7 @@ def Target_EnterDataOp: OpenMP_Op<"target_enter_data",
14601466
| `device` `(` $device `:` type($device) `)`
14611467
| `nowait` $nowait
14621468
| `map_entries` `(` $map_operands `:` type($map_operands) `)`
1469+
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
14631470
) attr-dict
14641471
}];
14651472

@@ -1494,11 +1501,17 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
14941501

14951502
The $map_types specifies the types and modifiers for the map clause.
14961503

1497-
TODO: depend clause and map_type_modifier values iterator and mapper.
1504+
The `depends` and `depend_vars` arguments are variadic lists of values
1505+
that specify the dependencies of this particular target task in relation to
1506+
other tasks.
1507+
1508+
TODO: map_type_modifier values iterator and mapper.
14981509
}];
14991510

15001511
let arguments = (ins Optional<I1>:$if_expr,
15011512
Optional<AnyInteger>:$device,
1513+
OptionalAttr<TaskDependArrayAttr>:$depends,
1514+
Variadic<OpenMP_PointerLikeType>:$depend_vars,
15021515
UnitAttr:$nowait,
15031516
Variadic<AnyType>:$map_operands);
15041517

@@ -1507,6 +1520,7 @@ def Target_ExitDataOp: OpenMP_Op<"target_exit_data",
15071520
| `device` `(` $device `:` type($device) `)`
15081521
| `nowait` $nowait
15091522
| `map_entries` `(` $map_operands `:` type($map_operands) `)`
1523+
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
15101524
) attr-dict
15111525
}];
15121526

@@ -1545,11 +1559,16 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
15451559
during verification to make sure the restrictions for target update are
15461560
respected.
15471561

1548-
TODO: depend clause
1562+
The `depends` and `depend_vars` arguments are variadic lists of values
1563+
that specify the dependencies of this particular target task in relation to
1564+
other tasks.
1565+
15491566
}];
15501567

15511568
let arguments = (ins Optional<I1>:$if_expr,
15521569
Optional<AnyInteger>:$device,
1570+
OptionalAttr<TaskDependArrayAttr>:$depends,
1571+
Variadic<OpenMP_PointerLikeType>:$depend_vars,
15531572
UnitAttr:$nowait,
15541573
Variadic<OpenMP_PointerLikeType>:$map_operands);
15551574

@@ -1558,6 +1577,7 @@ def Target_UpdateDataOp: OpenMP_Op<"target_update_data",
15581577
| `device` `(` $device `:` type($device) `)`
15591578
| `nowait` $nowait
15601579
| `motion_entries` `(` $map_operands `:` type($map_operands) `)`
1580+
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
15611581
) attr-dict
15621582
}];
15631583

@@ -1587,13 +1607,19 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
15871607
The optional $nowait elliminates the implicit barrier so the parent task can make progress
15881608
even if the target task is not yet completed.
15891609

1590-
TODO: is_device_ptr, depend, defaultmap, in_reduction
1610+
The `depends` and `depend_vars` arguments are variadic lists of values
1611+
that specify the dependencies of this particular target task in relation to
1612+
other tasks.
1613+
1614+
TODO: is_device_ptr, defaultmap, in_reduction
15911615

15921616
}];
15931617

15941618
let arguments = (ins Optional<I1>:$if_expr,
15951619
Optional<AnyInteger>:$device,
15961620
Optional<AnyInteger>:$thread_limit,
1621+
OptionalAttr<TaskDependArrayAttr>:$depends,
1622+
Variadic<OpenMP_PointerLikeType>:$depend_vars,
15971623
UnitAttr:$nowait,
15981624
Variadic<AnyType>:$map_operands);
15991625

@@ -1605,6 +1631,7 @@ def TargetOp : OpenMP_Op<"target",[IsolatedFromAbove, MapClauseOwningOpInterface
16051631
| `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
16061632
| `nowait` $nowait
16071633
| `map_entries` `(` custom<MapEntries>($map_operands, type($map_operands)) `)`
1634+
| `depend` `(` custom<DependVarList>($depend_vars, type($depend_vars), $depends) `)`
16081635
) $region attr-dict
16091636
}];
16101637

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ static LogicalResult verifyDependVarList(Operation *op,
628628
return op->emitOpError() << "expected as many depend values"
629629
" as depend variables";
630630
} else {
631-
if (depends)
631+
if (depends && !depends->empty())
632632
return op->emitOpError() << "unexpected depend values";
633633
return success();
634634
}
@@ -1032,19 +1032,31 @@ LogicalResult DataOp::verify() {
10321032
}
10331033

10341034
LogicalResult EnterDataOp::verify() {
1035-
return verifyMapClause(*this, getMapOperands());
1035+
LogicalResult verifyDependVars =
1036+
verifyDependVarList(*this, getDepends(), getDependVars());
1037+
return failed(verifyDependVars) ? verifyDependVars
1038+
: verifyMapClause(*this, getMapOperands());
10361039
}
10371040

10381041
LogicalResult ExitDataOp::verify() {
1039-
return verifyMapClause(*this, getMapOperands());
1042+
LogicalResult verifyDependVars =
1043+
verifyDependVarList(*this, getDepends(), getDependVars());
1044+
return failed(verifyDependVars) ? verifyDependVars
1045+
: verifyMapClause(*this, getMapOperands());
10401046
}
10411047

10421048
LogicalResult UpdateDataOp::verify() {
1043-
return verifyMapClause(*this, getMapOperands());
1049+
LogicalResult verifyDependVars =
1050+
verifyDependVarList(*this, getDepends(), getDependVars());
1051+
return failed(verifyDependVars) ? verifyDependVars
1052+
: verifyMapClause(*this, getMapOperands());
10441053
}
10451054

10461055
LogicalResult TargetOp::verify() {
1047-
return verifyMapClause(*this, getMapOperands());
1056+
LogicalResult verifyDependVars =
1057+
verifyDependVarList(*this, getDepends(), getDependVars());
1058+
return failed(verifyDependVars) ? verifyDependVars
1059+
: verifyMapClause(*this, getMapOperands());
10481060
}
10491061

10501062
//===----------------------------------------------------------------------===//

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,6 +1651,15 @@ func.func @omp_target_enter_data(%map1: memref<?xi32>) {
16511651

16521652
// -----
16531653

1654+
func.func @omp_target_enter_data_depend(%a: memref<?xi32>) {
1655+
%0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
1656+
// expected-error @below {{op expected as many depend values as depend variables}}
1657+
omp.target_enter_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
1658+
return
1659+
}
1660+
1661+
// -----
1662+
16541663
func.func @omp_target_exit_data(%map1: memref<?xi32>) {
16551664
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32> {name = ""}
16561665
// expected-error @below {{from, release and delete map types are permitted}}
@@ -1660,6 +1669,15 @@ func.func @omp_target_exit_data(%map1: memref<?xi32>) {
16601669

16611670
// -----
16621671

1672+
func.func @omp_target_exit_data_depend(%a: memref<?xi32>) {
1673+
%0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
1674+
// expected-error @below {{op expected as many depend values as depend variables}}
1675+
omp.target_exit_data map_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
1676+
return
1677+
}
1678+
1679+
// -----
1680+
16631681
func.func @omp_target_update_invalid_motion_type(%map1 : memref<?xi32>) {
16641682
%mapv = omp.map_info var_ptr(%map1 : memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32> {name = ""}
16651683

@@ -1732,6 +1750,25 @@ llvm.mlir.global internal @_QFsubEx() : i32
17321750

17331751
// -----
17341752

1753+
func.func @omp_target_update_data_depend(%a: memref<?xi32>) {
1754+
%0 = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
1755+
// expected-error @below {{op expected as many depend values as depend variables}}
1756+
omp.target_update_data motion_entries(%0: memref<?xi32> ) {operandSegmentSizes = array<i32: 0, 0, 1, 0>}
1757+
return
1758+
}
1759+
1760+
// -----
1761+
1762+
func.func @omp_target_depend(%data_var: memref<i32>) {
1763+
// expected-error @below {{op expected as many depend values as depend variables}}
1764+
"omp.target"(%data_var) ({
1765+
"omp.terminator"() : () -> ()
1766+
}) {depends = [], operandSegmentSizes = array<i32: 0, 0, 0, 1, 0>} : (memref<i32>) -> ()
1767+
"func.return"() : () -> ()
1768+
}
1769+
1770+
// -----
1771+
17351772
func.func @omp_distribute(%data_var : memref<i32>) -> () {
17361773
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
17371774
"omp.distribute"(%data_var) <{operandSegmentSizes = array<i32: 0, 1, 0>}> ({

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -517,7 +517,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %map1:
517517
"omp.target"(%if_cond, %device, %num_threads) ({
518518
// CHECK: omp.terminator
519519
omp.terminator
520-
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0>} : ( i1, si32, i32 ) -> ()
520+
}) {nowait, operandSegmentSizes = array<i32: 1,1,1,0,0>} : ( i1, si32, i32 ) -> ()
521521

522522
// Test with optional map clause.
523523
// CHECK: %[[MAP_A:.*]] = omp.map_info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(tofrom) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1717,6 +1717,18 @@ func.func @omp_task_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
17171717
return
17181718
}
17191719

1720+
1721+
// CHECK-LABEL: @omp_target_depend
1722+
// CHECK-SAME: (%arg0: memref<i32>, %arg1: memref<i32>) {
1723+
func.func @omp_target_depend(%arg0: memref<i32>, %arg1: memref<i32>) {
1724+
// CHECK: omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
1725+
omp.target depend(taskdependin -> %arg0 : memref<i32>, taskdependin -> %arg1 : memref<i32>, taskdependinout -> %arg0 : memref<i32>) {
1726+
// CHECK: omp.terminator
1727+
omp.terminator
1728+
} {operandSegmentSizes = array<i32: 0,0,0,3,0>}
1729+
return
1730+
}
1731+
17201732
func.func @omp_threadprivate() {
17211733
%0 = arith.constant 1 : i32
17221734
%1 = arith.constant 2 : i32
@@ -2145,3 +2157,52 @@ func.func @omp_targets_is_allocatable(%arg0: !llvm.ptr, %arg1: !llvm.ptr) -> ()
21452157
}
21462158
return
21472159
}
2160+
2161+
// CHECK-LABEL: func @omp_target_enter_update_exit_data_depend
2162+
// CHECK-SAME:([[ARG0:%.*]]: memref<?xi32>, [[ARG1:%.*]]: memref<?xi32>, [[ARG2:%.*]]: memref<?xi32>) {
2163+
func.func @omp_target_enter_update_exit_data_depend(%a: memref<?xi32>, %b: memref<?xi32>, %c: memref<?xi32>) {
2164+
// CHECK-NEXT: [[MAP0:%.*]] = omp.map_info
2165+
// CHECK-NEXT: [[MAP1:%.*]] = omp.map_info
2166+
// CHECK-NEXT: [[MAP2:%.*]] = omp.map_info
2167+
%map_a = omp.map_info var_ptr(%a: memref<?xi32>, tensor<?xi32>) map_clauses(to) capture(ByRef) -> memref<?xi32>
2168+
%map_b = omp.map_info var_ptr(%b: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
2169+
%map_c = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(exit_release_or_enter_alloc) capture(ByRef) -> memref<?xi32>
2170+
2171+
// Do some work on the host that writes to 'a'
2172+
omp.task depend(taskdependout -> %a : memref<?xi32>) {
2173+
"test.foo"(%a) : (memref<?xi32>) -> ()
2174+
omp.terminator
2175+
}
2176+
2177+
// Then map that over to the target
2178+
// CHECK: omp.target_enter_data nowait map_entries([[MAP0]], [[MAP2]] : memref<?xi32>, memref<?xi32>) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
2179+
omp.target_enter_data nowait map_entries(%map_a, %map_c: memref<?xi32>, memref<?xi32>) depend(taskdependin -> %a: memref<?xi32>)
2180+
2181+
// Compute 'b' on the target and copy it back
2182+
// CHECK: omp.target map_entries([[MAP1]] -> {{%.*}} : memref<?xi32>) {
2183+
omp.target map_entries(%map_b -> %arg0 : memref<?xi32>) {
2184+
^bb0(%arg0: memref<?xi32>) :
2185+
"test.foo"(%arg0) : (memref<?xi32>) -> ()
2186+
omp.terminator
2187+
}
2188+
2189+
// Update 'a' on the host using 'b'
2190+
omp.task depend(taskdependout -> %a: memref<?xi32>){
2191+
"test.bar"(%a, %b) : (memref<?xi32>, memref<?xi32>) -> ()
2192+
}
2193+
2194+
// Copy the updated 'a' onto the target
2195+
// CHECK: omp.target_update_data nowait motion_entries([[MAP0]] : memref<?xi32>) depend(taskdependin -> [[ARG0]] : memref<?xi32>)
2196+
omp.target_update_data motion_entries(%map_a : memref<?xi32>) depend(taskdependin -> %a : memref<?xi32>) nowait
2197+
2198+
// Compute 'c' on the target and copy it back
2199+
%map_c_from = omp.map_info var_ptr(%c: memref<?xi32>, tensor<?xi32>) map_clauses(from) capture(ByRef) -> memref<?xi32>
2200+
omp.target map_entries(%map_a -> %arg0, %map_c_from -> %arg1 : memref<?xi32>, memref<?xi32>) depend(taskdependout -> %c : memref<?xi32>) {
2201+
^bb0(%arg0 : memref<?xi32>, %arg1 : memref<?xi32>) :
2202+
"test.foobar"() : ()->()
2203+
omp.terminator
2204+
}
2205+
// CHECK: omp.target_exit_data map_entries([[MAP2]] : memref<?xi32>) depend(taskdependin -> [[ARG2]] : memref<?xi32>)
2206+
omp.target_exit_data map_entries(%map_c : memref<?xi32>) depend(taskdependin -> %c : memref<?xi32>)
2207+
return
2208+
}

0 commit comments

Comments
 (0)