Skip to content

Commit 5e5b8c4

Browse files
authored
[MLIR][OpenMP] Verify loop wrapper properties of omp.parallel (#88722)
This patch extends verification of the `omp.parallel` operation to check it is correctly defined when taking a loop wrapper role. In OpenMP, a PARALLEL construct can be either a (potenially combined) block construct or a loop construct, when appearing as part of a composite construct. This is currently the case for the DISTRIBUTE PARALLEL DO/FOR and DISTRIBUTE PARALLEL DO/FOR SIMD exclusively. When used to represent the PARALLEL leaf of a composite construct, it must follow the rules of a wrapper loop operation in MLIR, and this is what this patch ensures. No additional restrictions are introduced for PARALLEL block constructs.
1 parent 9dbf3e2 commit 5e5b8c4

File tree

3 files changed

+87
-1
lines changed

3 files changed

+87
-1
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1344,6 +1344,22 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
13441344
}
13451345

13461346
LogicalResult ParallelOp::verify() {
1347+
// Check that it is a valid loop wrapper if it's taking that role.
1348+
if (isa<DistributeOp>((*this)->getParentOp())) {
1349+
if (!isWrapper())
1350+
return emitOpError() << "must take a loop wrapper role if nested inside "
1351+
"of 'omp.distribute'";
1352+
1353+
if (LoopWrapperInterface nested = getNestedWrapper()) {
1354+
// Check for the allowed leaf constructs that may appear in a composite
1355+
// construct directly after PARALLEL.
1356+
if (!isa<WsloopOp>(nested))
1357+
return emitError() << "only supported nested wrapper is 'omp.wsloop'";
1358+
} else {
1359+
return emitOpError() << "must not wrap an 'omp.loop_nest' directly";
1360+
}
1361+
}
1362+
13471363
if (getAllocateVars().size() != getAllocatorsVars().size())
13481364
return emitError(
13491365
"expected equal sizes for allocate and allocator variables");

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,58 @@ func.func @unknown_clause() {
1010

1111
// -----
1212

13+
func.func @not_wrapper() {
14+
omp.distribute {
15+
// expected-error@+1 {{op must take a loop wrapper role if nested inside of 'omp.distribute'}}
16+
omp.parallel {
17+
%0 = arith.constant 0 : i32
18+
omp.terminator
19+
}
20+
omp.terminator
21+
}
22+
23+
return
24+
}
25+
26+
// -----
27+
28+
func.func @invalid_nested_wrapper(%lb : index, %ub : index, %step : index) {
29+
omp.distribute {
30+
// expected-error@+1 {{only supported nested wrapper is 'omp.wsloop'}}
31+
omp.parallel {
32+
omp.simd {
33+
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
34+
omp.yield
35+
}
36+
omp.terminator
37+
}
38+
omp.terminator
39+
}
40+
omp.terminator
41+
}
42+
43+
return
44+
}
45+
46+
// -----
47+
48+
func.func @no_nested_wrapper(%lb : index, %ub : index, %step : index) {
49+
omp.distribute {
50+
// expected-error@+1 {{op must not wrap an 'omp.loop_nest' directly}}
51+
omp.parallel {
52+
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
53+
omp.yield
54+
}
55+
omp.terminator
56+
}
57+
omp.terminator
58+
}
59+
60+
return
61+
}
62+
63+
// -----
64+
1365
func.func @if_once(%n : i1) {
1466
// expected-error@+1 {{`if` clause can appear at most once in the expansion of the oilist directive}}
1567
omp.parallel if(%n : i1) if(%n : i1) {

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ func.func @omp_terminator() -> () {
5151
omp.terminator
5252
}
5353

54-
func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32) -> () {
54+
func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i32, %idx : index) -> () {
5555
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32) allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
5656
"omp.parallel" (%if_cond, %num_threads, %data_var, %data_var) ({
5757

@@ -85,6 +85,24 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
8585
omp.terminator
8686
}) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (memref<i32>, memref<i32>) -> ()
8787

88+
// CHECK: omp.distribute
89+
omp.distribute {
90+
// CHECK-NEXT: omp.parallel
91+
omp.parallel {
92+
// CHECK-NEXT: omp.wsloop
93+
// TODO Remove induction variables from omp.wsloop.
94+
omp.wsloop for (%iv) : index = (%idx) to (%idx) step (%idx) {
95+
// CHECK-NEXT: omp.loop_nest
96+
omp.loop_nest (%iv2) : index = (%idx) to (%idx) step (%idx) {
97+
omp.yield
98+
}
99+
omp.terminator
100+
}
101+
omp.terminator
102+
}
103+
omp.terminator
104+
}
105+
88106
return
89107
}
90108

0 commit comments

Comments
 (0)