Skip to content

Commit aae08f4

Browse files
authored
[MLIR][OpenMP] Make omp.taskloop into a loop wrapper (#87253)
This patch updates the definition of `omp.taskloop` to enforce the restrictions of a wrapper operation.
1 parent 1ca6b44 commit aae08f4

File tree

5 files changed

+218
-159
lines changed

5 files changed

+218
-159
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,11 +284,10 @@ using TaskgroupClauseOps =
284284
detail::Clauses<AllocateClauseOps, TaskReductionClauseOps>;
285285

286286
using TaskloopClauseOps =
287-
detail::Clauses<AllocateClauseOps, CollapseClauseOps, FinalClauseOps,
288-
GrainsizeClauseOps, IfClauseOps, InReductionClauseOps,
289-
LoopRelatedOps, MergeableClauseOps, NogroupClauseOps,
290-
NumTasksClauseOps, PriorityClauseOps, PrivateClauseOps,
291-
ReductionClauseOps, UntiedClauseOps>;
287+
detail::Clauses<AllocateClauseOps, FinalClauseOps, GrainsizeClauseOps,
288+
IfClauseOps, InReductionClauseOps, MergeableClauseOps,
289+
NogroupClauseOps, NumTasksClauseOps, PriorityClauseOps,
290+
PrivateClauseOps, ReductionClauseOps, UntiedClauseOps>;
292291

293292
using TaskwaitClauseOps = detail::Clauses<DependClauseOps, NowaitClauseOps>;
294293

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,32 +1030,30 @@ def TaskOp : OpenMP_Op<"task", [AttrSizedOperandSegments,
10301030
}
10311031

10321032
def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
1033-
AutomaticAllocationScope, RecursiveMemoryEffects,
1034-
AllTypesMatch<["lowerBound", "upperBound", "step"]>,
1033+
AutomaticAllocationScope,
10351034
DeclareOpInterfaceMethods<LoopWrapperInterface>,
1036-
ReductionClauseInterface]> {
1035+
RecursiveMemoryEffects, ReductionClauseInterface,
1036+
SingleBlockImplicitTerminator<"TerminatorOp">]> {
10371037
let summary = "taskloop construct";
10381038
let description = [{
10391039
The taskloop construct specifies that the iterations of one or more
10401040
associated loops will be executed in parallel using explicit tasks. The
10411041
iterations are distributed across tasks generated by the construct and
10421042
scheduled to be executed.
10431043

1044-
The `lowerBound` and `upperBound` specify a half-open range: the range
1045-
includes the lower bound but does not include the upper bound. If the
1046-
`inclusive` attribute is specified then the upper bound is also included.
1047-
The `step` specifies the loop step.
1048-
1049-
The body region can contain any number of blocks.
1044+
The body region can contain a single block which must contain a single
1045+
operation and a terminator. The operation must be another compatible loop
1046+
wrapper or an `omp.loop_nest`.
10501047

10511048
```
1052-
omp.taskloop <clauses>
1053-
for (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
1054-
%a = load %arrA[%i1, %i2] : memref<?x?xf32>
1055-
%b = load %arrB[%i1, %i2] : memref<?x?xf32>
1056-
%sum = arith.addf %a, %b : f32
1057-
store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
1058-
omp.terminator
1049+
omp.taskloop <clauses> {
1050+
omp.loop_nest (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
1051+
%a = load %arrA[%i1, %i2] : memref<?x?xf32>
1052+
%b = load %arrB[%i1, %i2] : memref<?x?xf32>
1053+
%sum = arith.addf %a, %b : f32
1054+
store %sum, %arrC[%i1, %i2] : memref<?x?xf32>
1055+
omp.yield
1056+
}
10591057
}
10601058
```
10611059

@@ -1132,11 +1130,7 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
11321130
created.
11331131
}];
11341132

1135-
let arguments = (ins Variadic<IntLikeType>:$lowerBound,
1136-
Variadic<IntLikeType>:$upperBound,
1137-
Variadic<IntLikeType>:$step,
1138-
UnitAttr:$inclusive,
1139-
Optional<I1>:$if_expr,
1133+
let arguments = (ins Optional<I1>:$if_expr,
11401134
Optional<I1>:$final_expr,
11411135
UnitAttr:$untied,
11421136
UnitAttr:$mergeable,
@@ -1179,8 +1173,7 @@ def TaskloopOp : OpenMP_Op<"taskloop", [AttrSizedOperandSegments,
11791173
|`grain_size` `(` $grain_size `:` type($grain_size) `)`
11801174
|`num_tasks` `(` $num_tasks `:` type($num_tasks) `)`
11811175
|`nogroup` $nogroup
1182-
) `for` custom<LoopControl>($region, $lowerBound, $upperBound, $step,
1183-
type($step), $inclusive) attr-dict
1176+
) $region attr-dict
11841177
}];
11851178

11861179
let extraClassDeclaration = [{

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1829,9 +1829,8 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
18291829
MLIRContext *ctx = builder.getContext();
18301830
// TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
18311831
TaskloopOp::build(
1832-
builder, state, clauses.loopLBVar, clauses.loopUBVar, clauses.loopStepVar,
1833-
clauses.loopInclusiveAttr, clauses.ifVar, clauses.finalVar,
1834-
clauses.untiedAttr, clauses.mergeableAttr, clauses.inReductionVars,
1832+
builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
1833+
clauses.mergeableAttr, clauses.inReductionVars,
18351834
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
18361835
makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
18371836
clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
@@ -1870,6 +1869,16 @@ LogicalResult TaskloopOp::verify() {
18701869
"the grainsize clause and num_tasks clause are mutually exclusive and "
18711870
"may not appear on the same taskloop directive");
18721871
}
1872+
1873+
if (!isWrapper())
1874+
return emitOpError() << "must be a loop wrapper";
1875+
1876+
if (LoopWrapperInterface nested = getNestedWrapper()) {
1877+
// Check for the allowed leaf constructs that may appear in a composite
1878+
// construct directly after TASKLOOP.
1879+
if (!isa<SimdLoopOp>(nested))
1880+
return emitError() << "only supported nested wrapper is 'omp.simdloop'";
1881+
}
18731882
return success();
18741883
}
18751884

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 62 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,10 +1580,11 @@ func.func @omp_cancellationpoint2() {
15801580
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
15811581
%testmemref = "test.memref"() : () -> (memref<i32>)
15821582
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
1583-
"omp.taskloop"(%lb, %ub, %ub, %lb, %step, %step, %testmemref) ({
1584-
^bb0(%arg3: i32, %arg4: i32):
1585-
"omp.terminator"() : () -> ()
1586-
}) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 0, 0, 1, 0, 0, 0>} : (i32, i32, i32, i32, i32, i32, memref<i32>) -> ()
1583+
"omp.taskloop"(%testmemref) ({
1584+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
1585+
omp.yield
1586+
}
1587+
}) {operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0, 0, 0>} : (memref<i32>) -> ()
15871588
return
15881589
}
15891590

@@ -1593,23 +1594,24 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
15931594
%testf32 = "test.f32"() : () -> (!llvm.ptr)
15941595
%testf32_2 = "test.f32"() : () -> (!llvm.ptr)
15951596
// expected-error @below {{expected as many reduction symbol references as reduction variables}}
1596-
"omp.taskloop"(%lb, %ub, %ub, %lb, %step, %step, %testf32, %testf32_2) ({
1597-
^bb0(%arg3: i32, %arg4: i32):
1598-
"omp.terminator"() : () -> ()
1599-
}) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 2, 0, 0, 0, 0, 0>, reductions = [@add_f32]} : (i32, i32, i32, i32, i32, i32, !llvm.ptr, !llvm.ptr) -> ()
1597+
"omp.taskloop"(%testf32, %testf32_2) ({
1598+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
1599+
omp.yield
1600+
}
1601+
}) {operandSegmentSizes = array<i32: 0, 0, 0, 2, 0, 0, 0, 0, 0>, reductions = [@add_f32]} : (!llvm.ptr, !llvm.ptr) -> ()
16001602
return
16011603
}
16021604

16031605
// -----
16041606

16051607
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
16061608
%testf32 = "test.f32"() : () -> (!llvm.ptr)
1607-
%testf32_2 = "test.f32"() : () -> (!llvm.ptr)
16081609
// expected-error @below {{expected as many reduction symbol references as reduction variables}}
1609-
"omp.taskloop"(%lb, %ub, %ub, %lb, %step, %step, %testf32) ({
1610-
^bb0(%arg3: i32, %arg4: i32):
1611-
"omp.terminator"() : () -> ()
1612-
}) {operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 0, 1, 0, 0, 0, 0, 0>, reductions = [@add_f32, @add_f32]} : (i32, i32, i32, i32, i32, i32, !llvm.ptr) -> ()
1610+
"omp.taskloop"(%testf32) ({
1611+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
1612+
omp.yield
1613+
}
1614+
}) {operandSegmentSizes = array<i32: 0, 0, 0, 1, 0, 0, 0, 0, 0>, reductions = [@add_f32, @add_f32]} : (!llvm.ptr) -> ()
16131615
return
16141616
}
16151617

@@ -1619,23 +1621,24 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
16191621
%testf32 = "test.f32"() : () -> (!llvm.ptr)
16201622
%testf32_2 = "test.f32"() : () -> (!llvm.ptr)
16211623
// expected-error @below {{expected as many reduction symbol references as reduction variables}}
1622-
"omp.taskloop"(%lb, %ub, %ub, %lb, %step, %step, %testf32, %testf32_2) ({
1623-
^bb0(%arg3: i32, %arg4: i32):
1624-
"omp.terminator"() : () -> ()
1625-
}) {in_reductions = [@add_f32], operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 2, 0, 0, 0, 0, 0, 0>} : (i32, i32, i32, i32, i32, i32, !llvm.ptr, !llvm.ptr) -> ()
1624+
"omp.taskloop"(%testf32, %testf32_2) ({
1625+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
1626+
omp.yield
1627+
}
1628+
}) {in_reductions = [@add_f32], operandSegmentSizes = array<i32: 0, 0, 2, 0, 0, 0, 0, 0, 0>} : (!llvm.ptr, !llvm.ptr) -> ()
16261629
return
16271630
}
16281631

16291632
// -----
16301633

16311634
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
16321635
%testf32 = "test.f32"() : () -> (!llvm.ptr)
1633-
%testf32_2 = "test.f32"() : () -> (!llvm.ptr)
16341636
// expected-error @below {{expected as many reduction symbol references as reduction variables}}
1635-
"omp.taskloop"(%lb, %ub, %ub, %lb, %step, %step, %testf32_2) ({
1636-
^bb0(%arg3: i32, %arg4: i32):
1637-
"omp.terminator"() : () -> ()
1638-
}) {in_reductions = [@add_f32, @add_f32], operandSegmentSizes = array<i32: 2, 2, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0>} : (i32, i32, i32, i32, i32, i32, !llvm.ptr) -> ()
1637+
"omp.taskloop"(%testf32) ({
1638+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
1639+
omp.yield
1640+
}
1641+
}) {in_reductions = [@add_f32, @add_f32], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0>} : (!llvm.ptr) -> ()
16391642
return
16401643
}
16411644

@@ -1657,9 +1660,10 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
16571660
%testf32 = "test.f32"() : () -> (!llvm.ptr)
16581661
%testf32_2 = "test.f32"() : () -> (!llvm.ptr)
16591662
// expected-error @below {{if a reduction clause is present on the taskloop directive, the nogroup clause must not be specified}}
1660-
omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) nogroup
1661-
for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
1662-
omp.terminator
1663+
omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr, @add_f32 -> %testf32_2 : !llvm.ptr) nogroup {
1664+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
1665+
omp.yield
1666+
}
16631667
}
16641668
return
16651669
}
@@ -1681,9 +1685,10 @@ combiner {
16811685
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
16821686
%testf32 = "test.f32"() : () -> (!llvm.ptr)
16831687
// expected-error @below {{the same list item cannot appear in both a reduction and an in_reduction clause}}
1684-
omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr) in_reduction(@add_f32 -> %testf32 : !llvm.ptr)
1685-
for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
1686-
omp.terminator
1688+
omp.taskloop reduction(@add_f32 -> %testf32 : !llvm.ptr) in_reduction(@add_f32 -> %testf32 : !llvm.ptr) {
1689+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
1690+
omp.yield
1691+
}
16871692
}
16881693
return
16891694
}
@@ -1693,15 +1698,42 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
16931698
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
16941699
%testi64 = "test.i64"() : () -> (i64)
16951700
// expected-error @below {{the grainsize clause and num_tasks clause are mutually exclusive and may not appear on the same taskloop directive}}
1696-
omp.taskloop grain_size(%testi64: i64) num_tasks(%testi64: i64)
1697-
for (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
1701+
omp.taskloop grain_size(%testi64: i64) num_tasks(%testi64: i64) {
1702+
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
1703+
omp.yield
1704+
}
1705+
}
1706+
return
1707+
}
1708+
1709+
// -----
1710+
1711+
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
1712+
// expected-error @below {{op must be a loop wrapper}}
1713+
omp.taskloop {
1714+
%0 = arith.constant 0 : i32
16981715
omp.terminator
16991716
}
17001717
return
17011718
}
17021719

17031720
// -----
17041721

1722+
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
1723+
// expected-error @below {{only supported nested wrapper is 'omp.simdloop'}}
1724+
omp.taskloop {
1725+
omp.distribute {
1726+
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
1727+
omp.yield
1728+
}
1729+
omp.terminator
1730+
}
1731+
}
1732+
return
1733+
}
1734+
1735+
// -----
1736+
17051737
func.func @omp_threadprivate() {
17061738
%1 = llvm.mlir.addressof @_QFsubEx : !llvm.ptr
17071739
// expected-error @below {{op failed to verify that all of {sym_addr, tls_addr} have same type}}

0 commit comments

Comments
 (0)