Skip to content

Commit 20f4ade

Browse files
authored
[MLIR][OpenMP] Move loop wrapper verification to the interface (NFC) (#110505)
This patch moves verification code for the `LoopWrapperInterface` to the interface itself, checking it automatically for each operation that has that interface.
1 parent 4ae0c50 commit 20f4ade

File tree

3 files changed

+76
-48
lines changed

3 files changed

+76
-48
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,15 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
106106
}]
107107
>
108108
];
109+
110+
let extraClassDeclaration = [{
111+
/// Interface verifier imlementation.
112+
llvm::LogicalResult verifyImpl();
113+
}];
114+
115+
let verify = [{
116+
return ::llvm::cast<::mlir::omp::LoopWrapperInterface>($_op).verifyImpl();
117+
}];
109118
}
110119

111120
def ComposableOpInterface : OpInterface<"ComposableOpInterface"> {

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 31 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,6 +1682,37 @@ LogicalResult SingleOp::verify() {
16821682
getCopyprivateSyms());
16831683
}
16841684

1685+
//===----------------------------------------------------------------------===//
1686+
// LoopWrapperInterface
1687+
//===----------------------------------------------------------------------===//
1688+
1689+
LogicalResult LoopWrapperInterface::verifyImpl() {
1690+
Operation *op = this->getOperation();
1691+
if (op->getNumRegions() != 1)
1692+
return emitOpError() << "loop wrapper contains multiple regions";
1693+
1694+
Region &region = op->getRegion(0);
1695+
if (!region.hasOneBlock())
1696+
return emitOpError() << "loop wrapper contains multiple blocks";
1697+
1698+
if (::llvm::range_size(region.getOps()) != 2)
1699+
return emitOpError()
1700+
<< "loop wrapper does not contain exactly two nested ops";
1701+
1702+
Operation &firstOp = *region.op_begin();
1703+
Operation &secondOp = *(std::next(region.op_begin()));
1704+
1705+
if (!secondOp.hasTrait<OpTrait::IsTerminator>())
1706+
return emitOpError()
1707+
<< "second nested op in loop wrapper is not a terminator";
1708+
1709+
if (!::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp))
1710+
return emitOpError() << "first nested op in loop wrapper is not "
1711+
"another loop wrapper or `omp.loop_nest`";
1712+
1713+
return success();
1714+
}
1715+
16851716
//===----------------------------------------------------------------------===//
16861717
// WsloopOp
16871718
//===----------------------------------------------------------------------===//
@@ -1714,32 +1745,6 @@ void printWsloop(OpAsmPrinter &p, Operation *op, Region &region,
17141745
p.printRegion(region, /*printEntryBlockArgs=*/false);
17151746
}
17161747

1717-
static LogicalResult verifyLoopWrapperInterface(Operation *op) {
1718-
if (op->getNumRegions() != 1)
1719-
return op->emitOpError() << "loop wrapper contains multiple regions";
1720-
1721-
Region &region = op->getRegion(0);
1722-
if (!region.hasOneBlock())
1723-
return op->emitOpError() << "loop wrapper contains multiple blocks";
1724-
1725-
if (::llvm::range_size(region.getOps()) != 2)
1726-
return op->emitOpError()
1727-
<< "loop wrapper does not contain exactly two nested ops";
1728-
1729-
Operation &firstOp = *region.op_begin();
1730-
Operation &secondOp = *(std::next(region.op_begin()));
1731-
1732-
if (!secondOp.hasTrait<OpTrait::IsTerminator>())
1733-
return op->emitOpError()
1734-
<< "second nested op in loop wrapper is not a terminator";
1735-
1736-
if (!::llvm::isa<LoopNestOp, LoopWrapperInterface>(firstOp))
1737-
return op->emitOpError() << "first nested op in loop wrapper is not "
1738-
"another loop wrapper or `omp.loop_nest`";
1739-
1740-
return success();
1741-
}
1742-
17431748
void WsloopOp::build(OpBuilder &builder, OperationState &state,
17441749
ArrayRef<NamedAttribute> attributes) {
17451750
build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
@@ -1770,9 +1775,6 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
17701775
}
17711776

17721777
LogicalResult WsloopOp::verify() {
1773-
if (verifyLoopWrapperInterface(*this).failed())
1774-
return failure();
1775-
17761778
bool isCompositeChildLeaf =
17771779
llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
17781780

@@ -1829,9 +1831,6 @@ LogicalResult SimdOp::verify() {
18291831
if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
18301832
return failure();
18311833

1832-
if (verifyLoopWrapperInterface(*this).failed())
1833-
return failure();
1834-
18351834
if (getNestedWrapper())
18361835
return emitOpError() << "must wrap an 'omp.loop_nest' directly";
18371836

@@ -1871,9 +1870,6 @@ LogicalResult DistributeOp::verify() {
18711870
return emitError(
18721871
"expected equal sizes for allocate and allocator variables");
18731872

1874-
if (verifyLoopWrapperInterface(*this).failed())
1875-
return failure();
1876-
18771873
if (LoopWrapperInterface nested = getNestedWrapper()) {
18781874
if (!isComposite())
18791875
return emitError()
@@ -2079,9 +2075,6 @@ LogicalResult TaskloopOp::verify() {
20792075
"may not appear on the same taskloop directive");
20802076
}
20812077

2082-
if (verifyLoopWrapperInterface(*this).failed())
2083-
return failure();
2084-
20852078
if (LoopWrapperInterface nested = getNestedWrapper()) {
20862079
if (!isComposite())
20872080
return emitError()

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,7 @@ func.func @omp_simd_aligned_mismatch(%arg0 : index, %arg1 : index,
355355
omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
356356
omp.yield
357357
}
358+
omp.terminator
358359
}) {alignments = [128],
359360
operandSegmentSizes = array<i32: 2, 0, 0, 0, 0, 0, 0>} : (memref<i32>, memref<i32>) -> ()
360361
return
@@ -370,6 +371,7 @@ func.func @omp_simd_aligned_negative(%arg0 : index, %arg1 : index,
370371
omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
371372
omp.yield
372373
}
374+
omp.terminator
373375
}) {alignments = [-1, 128], operandSegmentSizes = array<i32: 2, 0, 0, 0, 0, 0, 0>} : (memref<i32>, memref<i32>) -> ()
374376
return
375377
}
@@ -384,6 +386,7 @@ func.func @omp_simd_unexpected_alignment(%arg0 : index, %arg1 : index,
384386
omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
385387
omp.yield
386388
}
389+
omp.terminator
387390
}) {alignments = [1, 128]} : () -> ()
388391
return
389392
}
@@ -398,6 +401,7 @@ func.func @omp_simd_aligned_float(%arg0 : index, %arg1 : index,
398401
omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
399402
omp.yield
400403
}
404+
omp.terminator
401405
}) {alignments = [1.5, 128], operandSegmentSizes = array<i32: 2, 0, 0, 0, 0, 0, 0>} : (memref<i32>, memref<i32>) -> ()
402406
return
403407
}
@@ -412,6 +416,7 @@ func.func @omp_simd_aligned_the_same_var(%arg0 : index, %arg1 : index,
412416
omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
413417
omp.yield
414418
}
419+
omp.terminator
415420
}) {alignments = [1, 128], operandSegmentSizes = array<i32: 2, 0, 0, 0, 0, 0, 0>} : (memref<i32>, memref<i32>) -> ()
416421
return
417422
}
@@ -426,6 +431,7 @@ func.func @omp_simd_nontemporal_the_same_var(%arg0 : index, %arg1 : index,
426431
omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
427432
omp.yield
428433
}
434+
omp.terminator
429435
}) {operandSegmentSizes = array<i32: 0, 0, 0, 0, 2, 0, 0>} : (memref<i32>, memref<i32>) -> ()
430436
return
431437
}
@@ -438,6 +444,7 @@ func.func @omp_simd_order_value(%lb : index, %ub : index, %step : index) {
438444
omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
439445
omp.yield
440446
}
447+
omp.terminator
441448
}
442449
return
443450
}
@@ -450,6 +457,7 @@ func.func @omp_simd_reproducible_order(%lb : index, %ub : index, %step : index)
450457
omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
451458
omp.yield
452459
}
460+
omp.terminator
453461
}
454462
return
455463
}
@@ -460,6 +468,7 @@ func.func @omp_simd_unconstrained_order(%lb : index, %ub : index, %step : index)
460468
omp.loop_nest (%iv) : index = (%arg0) to (%arg1) step (%arg2) {
461469
omp.yield
462470
}
471+
omp.terminator
463472
}
464473
return
465474
}
@@ -470,6 +479,7 @@ func.func @omp_simd_pretty_simdlen(%lb : index, %ub : index, %step : index) -> (
470479
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
471480
omp.yield
472481
}
482+
omp.terminator
473483
}
474484
return
475485
}
@@ -482,6 +492,7 @@ func.func @omp_simd_pretty_safelen(%lb : index, %ub : index, %step : index) -> (
482492
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
483493
omp.yield
484494
}
495+
omp.terminator
485496
}
486497
return
487498
}
@@ -494,6 +505,7 @@ func.func @omp_simd_pretty_simdlen_safelen(%lb : index, %ub : index, %step : ind
494505
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
495506
omp.yield
496507
}
508+
omp.terminator
497509
}
498510
return
499511
}
@@ -1838,6 +1850,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
18381850
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
18391851
omp.yield
18401852
}
1853+
omp.terminator
18411854
}) {operandSegmentSizes = array<i32: 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
18421855
return
18431856
}
@@ -1852,6 +1865,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
18521865
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
18531866
omp.yield
18541867
}
1868+
omp.terminator
18551869
}) {operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 0, 0, 0, 0, 2>, reduction_syms = [@add_f32]} : (!llvm.ptr, !llvm.ptr) -> ()
18561870
return
18571871
}
@@ -1865,6 +1879,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
18651879
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
18661880
omp.yield
18671881
}
1882+
omp.terminator
18681883
}) {operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 0, 0, 0, 0, 1>, reduction_syms = [@add_f32, @add_f32]} : (!llvm.ptr) -> ()
18691884
return
18701885
}
@@ -1879,6 +1894,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
18791894
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
18801895
omp.yield
18811896
}
1897+
omp.terminator
18821898
}) {in_reduction_syms = [@add_f32], operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>} : (!llvm.ptr, !llvm.ptr) -> ()
18831899
return
18841900
}
@@ -1892,6 +1908,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
18921908
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
18931909
omp.yield
18941910
}
1911+
omp.terminator
18951912
}) {in_reduction_syms = [@add_f32, @add_f32], operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>} : (!llvm.ptr) -> ()
18961913
return
18971914
}
@@ -1918,6 +1935,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
19181935
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
19191936
omp.yield
19201937
}
1938+
omp.terminator
19211939
}
19221940
return
19231941
}
@@ -1943,6 +1961,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
19431961
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
19441962
omp.yield
19451963
}
1964+
omp.terminator
19461965
}
19471966
return
19481967
}
@@ -1956,6 +1975,7 @@ func.func @taskloop(%lb: i32, %ub: i32, %step: i32) {
19561975
omp.loop_nest (%i, %j) : i32 = (%lb, %ub) to (%ub, %lb) step (%step, %step) {
19571976
omp.yield
19581977
}
1978+
omp.terminator
19591979
}
19601980
return
19611981
}
@@ -2153,31 +2173,37 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
21532173

21542174
// -----
21552175

2156-
func.func @omp_distribute_schedule(%chunk_size : i32) -> () {
2176+
func.func @omp_distribute_schedule(%chunk_size : i32, %lb : i32, %ub : i32, %step : i32) -> () {
21572177
// expected-error @below {{op chunk size set without dist_schedule_static being present}}
21582178
"omp.distribute"(%chunk_size) <{operandSegmentSizes = array<i32: 0, 0, 1, 0>}> ({
2159-
"omp.terminator"() : () -> ()
2160-
}) : (i32) -> ()
2179+
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
2180+
"omp.yield"() : () -> ()
2181+
}
2182+
"omp.terminator"() : () -> ()
2183+
}) : (i32) -> ()
21612184
}
21622185

21632186
// -----
21642187

2165-
func.func @omp_distribute_allocate(%data_var : memref<i32>) -> () {
2188+
func.func @omp_distribute_allocate(%data_var : memref<i32>, %lb : i32, %ub : i32, %step : i32) -> () {
21662189
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
21672190
"omp.distribute"(%data_var) <{operandSegmentSizes = array<i32: 1, 0, 0, 0>}> ({
2168-
"omp.terminator"() : () -> ()
2169-
}) : (memref<i32>) -> ()
2191+
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
2192+
"omp.yield"() : () -> ()
2193+
}
2194+
"omp.terminator"() : () -> ()
2195+
}) : (memref<i32>) -> ()
21702196
}
21712197

21722198
// -----
21732199

21742200
func.func @omp_distribute_wrapper(%lb: index, %ub: index, %step: index) -> () {
21752201
// expected-error @below {{op second nested op in loop wrapper is not a terminator}}
21762202
omp.distribute {
2177-
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
2178-
"omp.yield"() : () -> ()
2179-
}
2180-
%0 = arith.constant 0 : i32
2203+
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
2204+
"omp.yield"() : () -> ()
2205+
}
2206+
%0 = arith.constant 0 : i32
21812207
}
21822208
}
21832209

0 commit comments

Comments
 (0)