Skip to content

Commit 2784060

Browse files
authored
[MLIR][Flang][OpenMP] Remove omp.parallel from loop wrapper ops (#105833)
This patch updates the `omp.parallel` operation according to the results of the discussion in [this RFC](https://discourse.llvm.org/t/rfc-disambiguation-between-loop-and-block-associated-omp-parallelop/79972). It is removed from the set of loop wrapper operations, changing the expected MLIR representation for composite `distribute parallel do/for` into the following: ```mlir omp.parallel { ... omp.distribute { omp.wsloop { omp.loop_nest ... { ... } omp.terminator } omp.terminator } ... omp.terminator } ``` MLIR verifiers for operations impacted by this representation change are updated, as well as related tests. The `LoopWrapperInterface` is also updated, since it's no longer representing an optional "role" of an operation but a mandatory set of restrictions instead.
1 parent 3ef37e2 commit 2784060

File tree

6 files changed

+149
-181
lines changed

6 files changed

+149
-181
lines changed

flang/lib/Lower/OpenMP/DataSharingProcessor.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,7 @@ void DataSharingProcessor::insertBarrier() {
231231
void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
232232
mlir::omp::LoopNestOp loopOp;
233233
if (auto wrapper = mlir::dyn_cast<mlir::omp::LoopWrapperInterface>(op))
234-
loopOp = wrapper.isWrapper()
235-
? mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop())
236-
: nullptr;
234+
loopOp = mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop());
237235

238236
bool cmpCreated = false;
239237
mlir::OpBuilder::InsertionGuard guard(firOpBuilder);

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove, RecipeInterface]>
129129
def ParallelOp : OpenMP_Op<"parallel", traits = [
130130
AttrSizedOperandSegments, AutomaticAllocationScope,
131131
DeclareOpInterfaceMethods<ComposableOpInterface>,
132-
DeclareOpInterfaceMethods<LoopWrapperInterface>,
133132
DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
134133
RecursiveMemoryEffects
135134
], clauses = [

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 6 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -71,69 +71,35 @@ def ReductionClauseInterface : OpInterface<"ReductionClauseInterface"> {
7171

7272
def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
7373
let description = [{
74-
OpenMP operations that can wrap a single loop nest. When taking a wrapper
75-
role, these operations must only contain a single region with a single block
76-
in which there's a single operation and a terminator. That nested operation
77-
must be another loop wrapper or an `omp.loop_nest`.
74+
OpenMP operations that wrap a single loop nest. They must only contain a
75+
single region with a single block in which there's a single operation and a
76+
terminator. That nested operation must be another loop wrapper or an
77+
`omp.loop_nest`.
7878
}];
7979

8080
let cppNamespace = "::mlir::omp";
8181

8282
let methods = [
83-
InterfaceMethod<
84-
/*description=*/[{
85-
Tell whether the operation could be taking the role of a loop wrapper.
86-
That is, it has a single region with a single block in which there are
87-
two operations: another wrapper (also taking a loop wrapper role) or
88-
`omp.loop_nest` operation and a terminator.
89-
}],
90-
/*retTy=*/"bool",
91-
/*methodName=*/"isWrapper",
92-
(ins ), [{}], [{
93-
if ($_op->getNumRegions() != 1)
94-
return false;
95-
96-
Region &r = $_op->getRegion(0);
97-
if (!r.hasOneBlock())
98-
return false;
99-
100-
if (::llvm::range_size(r.getOps()) != 2)
101-
return false;
102-
103-
Operation &firstOp = *r.op_begin();
104-
Operation &secondOp = *(std::next(r.op_begin()));
105-
106-
if (!secondOp.hasTrait<OpTrait::IsTerminator>())
107-
return false;
108-
109-
if (auto wrapper = ::llvm::dyn_cast<LoopWrapperInterface>(firstOp))
110-
return wrapper.isWrapper();
111-
112-
return ::llvm::isa<LoopNestOp>(firstOp);
113-
}]
114-
>,
11583
InterfaceMethod<
11684
/*description=*/[{
11785
If there is another loop wrapper immediately nested inside, return that
118-
operation. Assumes this operation is taking a loop wrapper role.
86+
operation. Assumes this operation is a valid loop wrapper.
11987
}],
12088
/*retTy=*/"::mlir::omp::LoopWrapperInterface",
12189
/*methodName=*/"getNestedWrapper",
12290
(ins), [{}], [{
123-
assert($_op.isWrapper() && "Unexpected non-wrapper op");
12491
Operation *nested = &*$_op->getRegion(0).op_begin();
12592
return ::llvm::dyn_cast<LoopWrapperInterface>(nested);
12693
}]
12794
>,
12895
InterfaceMethod<
12996
/*description=*/[{
13097
Return the loop nest nested directly or indirectly inside of this loop
131-
wrapper. Assumes this operation is taking a loop wrapper role.
98+
wrapper. Assumes this operation is a valid loop wrapper.
13299
}],
133100
/*retTy=*/"::mlir::Operation *",
134101
/*methodName=*/"getWrappedLoop",
135102
(ins), [{}], [{
136-
assert($_op.isWrapper() && "Unexpected non-wrapper op");
137103
if (LoopWrapperInterface nested = $_op.getNestedWrapper())
138104
return nested.getWrappedLoop();
139105
return &*$_op->getRegion(0).op_begin();

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 59 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1541,26 +1541,25 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
15411541
}
15421542

15431543
LogicalResult ParallelOp::verify() {
1544-
// Check that it is a valid loop wrapper if it's taking that role.
1545-
if (isa<DistributeOp>((*this)->getParentOp())) {
1546-
if (!isWrapper())
1547-
return emitOpError() << "must take a loop wrapper role if nested inside "
1548-
"of 'omp.distribute'";
1544+
auto distributeChildOps = getOps<DistributeOp>();
1545+
if (!distributeChildOps.empty()) {
15491546
if (!isComposite())
15501547
return emitError()
1551-
<< "'omp.composite' attribute missing from composite wrapper";
1548+
<< "'omp.composite' attribute missing from composite operation";
15521549

1553-
if (LoopWrapperInterface nested = getNestedWrapper()) {
1554-
// Check for the allowed leaf constructs that may appear in a composite
1555-
// construct directly after PARALLEL.
1556-
if (!isa<WsloopOp>(nested))
1557-
return emitError() << "only supported nested wrapper is 'omp.wsloop'";
1558-
} else {
1559-
return emitOpError() << "must not wrap an 'omp.loop_nest' directly";
1550+
auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
1551+
Operation &distributeOp = **distributeChildOps.begin();
1552+
for (Operation &childOp : getOps()) {
1553+
if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
1554+
continue;
1555+
1556+
if (!childOp.hasTrait<OpTrait::IsTerminator>())
1557+
return emitError() << "unexpected OpenMP operation inside of composite "
1558+
"'omp.parallel'";
15601559
}
15611560
} else if (isComposite()) {
15621561
return emitError()
1563-
<< "'omp.composite' attribute present in non-composite wrapper";
1562+
<< "'omp.composite' attribute present in non-composite operation";
15641563
}
15651564

15661565
if (getAllocateVars().size() != getAllocatorVars().size())
@@ -1721,6 +1720,32 @@ void printWsloop(OpAsmPrinter &p, Operation *op, Region &region,
17211720
p.printRegion(region, /*printEntryBlockArgs=*/false);
17221721
}
17231722

1723+
static LogicalResult verifyLoopWrapperInterface(Operation *op) {
1724+
if (op->getNumRegions() != 1)
1725+
return op->emitOpError() << "loop wrapper contains multiple regions";
1726+
1727+
Region &region = op->getRegion(0);
1728+
if (!region.hasOneBlock())
1729+
return op->emitOpError() << "loop wrapper contains multiple blocks";
1730+
1731+
if (::llvm::range_size(region.getOps()) != 2)
1732+
return op->emitOpError()
1733+
<< "loop wrapper does not contain exactly two nested ops";
1734+
1735+
Operation &firstOp = *region.op_begin();
1736+
Operation &secondOp = *(std::next(region.op_begin()));
1737+
1738+
if (!secondOp.hasTrait<OpTrait::IsTerminator>())
1739+
return op->emitOpError()
1740+
<< "second nested op in loop wrapper is not a terminator";
1741+
1742+
if (!::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp))
1743+
return op->emitOpError() << "first nested op in loop wrapper is not "
1744+
"another loop wrapper or `omp.loop_nest`";
1745+
1746+
return success();
1747+
}
1748+
17241749
void WsloopOp::build(OpBuilder &builder, OperationState &state,
17251750
ArrayRef<NamedAttribute> attributes) {
17261751
build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
@@ -1751,15 +1776,12 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
17511776
}
17521777

17531778
LogicalResult WsloopOp::verify() {
1754-
if (!isWrapper())
1755-
return emitOpError() << "must be a loop wrapper";
1779+
if (verifyLoopWrapperInterface(*this).failed())
1780+
return failure();
17561781

1757-
auto wrapper =
1758-
llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
17591782
bool isCompositeChildLeaf =
1760-
wrapper && wrapper.isWrapper() &&
1761-
(!llvm::isa<ParallelOp>(wrapper) ||
1762-
llvm::isa_and_present<DistributeOp>(wrapper->getParentOp()));
1783+
llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
1784+
17631785
if (LoopWrapperInterface nested = getNestedWrapper()) {
17641786
if (!isComposite())
17651787
return emitError()
@@ -1813,18 +1835,14 @@ LogicalResult SimdOp::verify() {
18131835
if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
18141836
return failure();
18151837

1816-
if (!isWrapper())
1817-
return emitOpError() << "must be a loop wrapper";
1838+
if (verifyLoopWrapperInterface(*this).failed())
1839+
return failure();
18181840

18191841
if (getNestedWrapper())
18201842
return emitOpError() << "must wrap an 'omp.loop_nest' directly";
18211843

1822-
auto wrapper =
1823-
llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
18241844
bool isCompositeChildLeaf =
1825-
wrapper && wrapper.isWrapper() &&
1826-
(!llvm::isa<ParallelOp>(wrapper) ||
1827-
llvm::isa_and_present<DistributeOp>(wrapper->getParentOp()));
1845+
llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
18281846

18291847
if (!isComposite() && isCompositeChildLeaf)
18301848
return emitError()
@@ -1859,18 +1877,22 @@ LogicalResult DistributeOp::verify() {
18591877
return emitError(
18601878
"expected equal sizes for allocate and allocator variables");
18611879

1862-
if (!isWrapper())
1863-
return emitOpError() << "must be a loop wrapper";
1880+
if (verifyLoopWrapperInterface(*this).failed())
1881+
return failure();
18641882

18651883
if (LoopWrapperInterface nested = getNestedWrapper()) {
18661884
if (!isComposite())
18671885
return emitError()
18681886
<< "'omp.composite' attribute missing from composite wrapper";
18691887
// Check for the allowed leaf constructs that may appear in a composite
18701888
// construct directly after DISTRIBUTE.
1871-
if (!isa<ParallelOp, SimdOp>(nested))
1872-
return emitError() << "only supported nested wrappers are 'omp.parallel' "
1873-
"and 'omp.simd'";
1889+
if (isa<WsloopOp>(nested)) {
1890+
if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
1891+
return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
1892+
"when 'omp.parallel' is the direct parent";
1893+
} else if (!isa<SimdOp>(nested))
1894+
return emitError() << "only supported nested wrappers are 'omp.simd' and "
1895+
"'omp.wsloop'";
18741896
} else if (isComposite()) {
18751897
return emitError()
18761898
<< "'omp.composite' attribute present in non-composite wrapper";
@@ -2063,8 +2085,8 @@ LogicalResult TaskloopOp::verify() {
20632085
"may not appear on the same taskloop directive");
20642086
}
20652087

2066-
if (!isWrapper())
2067-
return emitOpError() << "must be a loop wrapper";
2088+
if (verifyLoopWrapperInterface(*this).failed())
2089+
return failure();
20682090

20692091
if (LoopWrapperInterface nested = getNestedWrapper()) {
20702092
if (!isComposite())
@@ -2161,11 +2183,8 @@ LogicalResult LoopNestOp::verify() {
21612183
<< "range argument type does not match corresponding IV type";
21622184
}
21632185

2164-
auto wrapper =
2165-
llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2166-
2167-
if (!wrapper || !wrapper.isWrapper())
2168-
return emitOpError() << "expects parent op to be a valid loop wrapper";
2186+
if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
2187+
return emitOpError() << "expects parent op to be a loop wrapper";
21692188

21702189
return success();
21712190
}
@@ -2175,8 +2194,6 @@ void LoopNestOp::gatherWrappers(
21752194
Operation *parent = (*this)->getParentOp();
21762195
while (auto wrapper =
21772196
llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
2178-
if (!wrapper.isWrapper())
2179-
break;
21802197
wrappers.push_back(wrapper);
21812198
parent = parent->getParentOp();
21822199
}

0 commit comments

Comments
 (0)