Skip to content

Commit cbfb069

Browse files
committed
Address review comments
1 parent 37792b4 commit cbfb069

File tree

3 files changed

+42
-44
lines changed

3 files changed

+42
-44
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -80,34 +80,6 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
8080
let cppNamespace = "::mlir::omp";
8181

8282
let methods = [
83-
InterfaceMethod<
84-
/*description=*/[{
85-
Check whether the operation is a valid loop wrapper. That is, it has a
86-
single region with a single block in which there are two operations:
87-
another loop wrapper or `omp.loop_nest` operation and a terminator.
88-
}],
89-
/*retTy=*/"bool",
90-
/*methodName=*/"isValidWrapper",
91-
(ins ), [{}], [{
92-
if ($_op->getNumRegions() != 1)
93-
return false;
94-
95-
Region &r = $_op->getRegion(0);
96-
if (!r.hasOneBlock())
97-
return false;
98-
99-
if (::llvm::range_size(r.getOps()) != 2)
100-
return false;
101-
102-
Operation &firstOp = *r.op_begin();
103-
Operation &secondOp = *(std::next(r.op_begin()));
104-
105-
if (!secondOp.hasTrait<OpTrait::IsTerminator>())
106-
return false;
107-
108-
return ::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp);
109-
}]
110-
>,
11183
InterfaceMethod<
11284
/*description=*/[{
11385
If there is another loop wrapper immediately nested inside, return that
@@ -116,7 +88,6 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
11688
/*retTy=*/"::mlir::omp::LoopWrapperInterface",
11789
/*methodName=*/"getNestedWrapper",
11890
(ins), [{}], [{
119-
assert($_op.isValidWrapper() && "Unexpected non-wrapper op");
12091
Operation *nested = &*$_op->getRegion(0).op_begin();
12192
return ::llvm::dyn_cast<LoopWrapperInterface>(nested);
12293
}]
@@ -129,7 +100,6 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
129100
/*retTy=*/"::mlir::Operation *",
130101
/*methodName=*/"getWrappedLoop",
131102
(ins), [{}], [{
132-
assert($_op.isValidWrapper() && "Unexpected non-wrapper op");
133103
if (LoopWrapperInterface nested = $_op.getNestedWrapper())
134104
return nested.getWrappedLoop();
135105
return &*$_op->getRegion(0).op_begin();

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1720,6 +1720,32 @@ void printWsloop(OpAsmPrinter &p, Operation *op, Region &region,
17201720
p.printRegion(region, /*printEntryBlockArgs=*/false);
17211721
}
17221722

1723+
static LogicalResult verifyLoopWrapperInterface(Operation *op) {
1724+
if (op->getNumRegions() != 1)
1725+
return op->emitOpError() << "loop wrapper contains multiple regions";
1726+
1727+
Region &region = op->getRegion(0);
1728+
if (!region.hasOneBlock())
1729+
return op->emitOpError() << "loop wrapper contains multiple blocks";
1730+
1731+
if (::llvm::range_size(region.getOps()) != 2)
1732+
return op->emitOpError()
1733+
<< "loop wrapper does not contain exactly two nested ops";
1734+
1735+
Operation &firstOp = *region.op_begin();
1736+
Operation &secondOp = *(std::next(region.op_begin()));
1737+
1738+
if (!secondOp.hasTrait<OpTrait::IsTerminator>())
1739+
return op->emitOpError()
1740+
<< "second nested op in loop wrapper is not a terminator";
1741+
1742+
if (!::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp))
1743+
return op->emitOpError() << "first nested op in loop wrapper is not "
1744+
"another loop wrapper or `omp.loop_nest`";
1745+
1746+
return success();
1747+
}
1748+
17231749
void WsloopOp::build(OpBuilder &builder, OperationState &state,
17241750
ArrayRef<NamedAttribute> attributes) {
17251751
build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
@@ -1750,8 +1776,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
17501776
}
17511777

17521778
LogicalResult WsloopOp::verify() {
1753-
if (!isValidWrapper())
1754-
return emitOpError() << "must be a valid loop wrapper";
1779+
if (verifyLoopWrapperInterface(*this).failed())
1780+
return failure();
17551781

17561782
bool isCompositeChildLeaf =
17571783
llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
@@ -1809,8 +1835,8 @@ LogicalResult SimdOp::verify() {
18091835
if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
18101836
return failure();
18111837

1812-
if (!isValidWrapper())
1813-
return emitOpError() << "must be a valid loop wrapper";
1838+
if (verifyLoopWrapperInterface(*this).failed())
1839+
return failure();
18141840

18151841
if (getNestedWrapper())
18161842
return emitOpError() << "must wrap an 'omp.loop_nest' directly";
@@ -1851,8 +1877,8 @@ LogicalResult DistributeOp::verify() {
18511877
return emitError(
18521878
"expected equal sizes for allocate and allocator variables");
18531879

1854-
if (!isValidWrapper())
1855-
return emitOpError() << "must be a valid loop wrapper";
1880+
if (verifyLoopWrapperInterface(*this).failed())
1881+
return failure();
18561882

18571883
if (LoopWrapperInterface nested = getNestedWrapper()) {
18581884
if (!isComposite())
@@ -2059,8 +2085,8 @@ LogicalResult TaskloopOp::verify() {
20592085
"may not appear on the same taskloop directive");
20602086
}
20612087

2062-
if (!isValidWrapper())
2063-
return emitOpError() << "must be a valid loop wrapper";
2088+
if (verifyLoopWrapperInterface(*this).failed())
2089+
return failure();
20642090

20652091
if (LoopWrapperInterface nested = getNestedWrapper()) {
20662092
if (!isComposite())

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ func.func @iv_number_mismatch(%lb : index, %ub : index, %step : index) {
123123
// -----
124124

125125
func.func @no_wrapper(%lb : index, %ub : index, %step : index) {
126-
// expected-error @below {{op must be a valid loop wrapper}}
126+
// expected-error @below {{op loop wrapper does not contain exactly two nested ops}}
127127
omp.wsloop {
128128
%0 = arith.constant 0 : i32
129129
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
@@ -309,7 +309,7 @@ llvm.func @test_omp_wsloop_dynamic_wrong_modifier3(%lb : i64, %ub : i64, %step :
309309
// -----
310310

311311
func.func @omp_simd() -> () {
312-
// expected-error @below {{op must be a valid loop wrapper}}
312+
// expected-error @below {{op loop wrapper does not contain exactly two nested ops}}
313313
omp.simd {
314314
omp.terminator
315315
}
@@ -1963,7 +1963,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
19631963
// -----
19641964

19651965
func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
1966-
// expected-error @below {{op must be a valid loop wrapper}}
1966+
// expected-error @below {{op first nested op in loop wrapper is not another loop wrapper or `omp.loop_nest`}}
19671967
omp.taskloop {
19681968
%0 = arith.constant 0 : i32
19691969
omp.terminator
@@ -2171,11 +2171,13 @@ func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
21712171

21722172
// -----
21732173

2174-
func.func @omp_distribute_wrapper() -> () {
2175-
// expected-error @below {{op must be a valid loop wrapper}}
2174+
func.func @omp_distribute_wrapper(%lb: index, %ub: index, %step: index) -> () {
2175+
// expected-error @below {{op second nested op in loop wrapper is not a terminator}}
21762176
omp.distribute {
2177+
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
2178+
"omp.yield"() : () -> ()
2179+
}
21772180
%0 = arith.constant 0 : i32
2178-
"omp.terminator"() : () -> ()
21792181
}
21802182
}
21812183

0 commit comments

Comments
 (0)