Skip to content

[MLIR][OpenMP] Improve loop wrapper op verifiers #134833

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,7 @@ def WorkshareLoopWrapperOp : OpenMP_Op<"workshare.loop_wrapper", traits = [
];
let assemblyFormat = "$region attr-dict";
let hasVerifier = 1;
let hasRegionVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
11 changes: 10 additions & 1 deletion mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,15 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
single region with a single block in which there's a single operation and a
terminator. That nested operation must be another loop wrapper or an
`omp.loop_nest`.

Operation-specific verifiers should make the following checks in their
verifier, additionally to what the interface itself checks:
- If `getNestedWrapper() != nullptr`, is the type of the nested wrapper
allowed in that context? This check might require looking at the parent as
well.
- If the operation is a `ComposableOpInterface`, check that it is
consistent with the potential existence of a `LoopWrapperInterface` parent
and whether `getNestedWrapper() != nullptr`.
}];

let cppNamespace = "::mlir::omp";
Expand Down Expand Up @@ -255,7 +264,7 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
];

let extraClassDeclaration = [{
/// Interface verifier imlementation.
/// Interface verifier implementation.
llvm::LogicalResult verifyImpl();
}];

Expand Down
38 changes: 27 additions & 11 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2222,21 +2222,27 @@ LogicalResult ParallelOp::verify() {
}

LogicalResult ParallelOp::verifyRegions() {
auto distributeChildOps = getOps<DistributeOp>();
if (!distributeChildOps.empty()) {
auto distChildOps = getOps<DistributeOp>();
int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
if (numDistChildOps > 1)
return emitError()
<< "multiple 'omp.distribute' nested inside of 'omp.parallel'";

if (numDistChildOps == 1) {
if (!isComposite())
return emitError()
<< "'omp.composite' attribute missing from composite operation";

auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
Operation &distributeOp = **distributeChildOps.begin();
Operation &distributeOp = **distChildOps.begin();
for (Operation &childOp : getOps()) {
if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
continue;

if (!childOp.hasTrait<OpTrait::IsTerminator>())
return emitError() << "unexpected OpenMP operation inside of composite "
"'omp.parallel'";
"'omp.parallel': "
<< childOp.getName();
}
} else if (isComposite()) {
return emitError()
Expand Down Expand Up @@ -2388,9 +2394,15 @@ void WorkshareOp::build(OpBuilder &builder, OperationState &state,

LogicalResult WorkshareLoopWrapperOp::verify() {
if (!(*this)->getParentOfType<WorkshareOp>())
return emitError() << "must be nested in an omp.workshare";
if (getNestedWrapper())
return emitError() << "cannot be composite";
return emitOpError() << "must be nested in an omp.workshare";
return success();
}

LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
getNestedWrapper())
return emitOpError() << "expected to be a standalone loop wrapper";

return success();
}

Expand All @@ -2415,7 +2427,7 @@ LogicalResult LoopWrapperInterface::verifyImpl() {

Operation &firstOp = *region.op_begin();
if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
return emitOpError() << "op nested in loop wrapper is not another loop "
return emitOpError() << "nested in loop wrapper is not another loop "
"wrapper or `omp.loop_nest`";

return success();
Expand Down Expand Up @@ -2444,7 +2456,7 @@ LogicalResult LoopOp::verify() {
LogicalResult LoopOp::verifyRegions() {
if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
getNestedWrapper())
return emitError() << "`omp.loop` expected to be a standalone loop wrapper";
return emitOpError() << "expected to be a standalone loop wrapper";

return success();
}
Expand Down Expand Up @@ -2601,9 +2613,13 @@ LogicalResult DistributeOp::verifyRegions() {
// Check for the allowed leaf constructs that may appear in a composite
// construct directly after DISTRIBUTE.
if (isa<WsloopOp>(nested)) {
if (!llvm::dyn_cast_if_present<ParallelOp>((*this)->getParentOp()))
Operation *parentOp = (*this)->getParentOp();
if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
!cast<ComposableOpInterface>(parentOp).isComposite()) {
return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
"when 'omp.parallel' is the direct parent";
"when a composite 'omp.parallel' is the direct "
"parent";
}
} else if (!isa<SimdOp>(nested))
return emitError() << "only supported nested wrappers are 'omp.simd' and "
"'omp.wsloop'";
Expand Down
61 changes: 49 additions & 12 deletions mlir/test/Dialect/OpenMP/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2391,7 +2391,7 @@ func.func @omp_distribute_allocate(%data_var : memref<i32>, %lb : i32, %ub : i32
// -----

func.func @omp_distribute_nested_wrapper(%lb: index, %ub: index, %step: index) -> () {
// expected-error @below {{an 'omp.wsloop' nested wrapper is only allowed when 'omp.parallel' is the direct parent}}
// expected-error @below {{an 'omp.wsloop' nested wrapper is only allowed when a composite 'omp.parallel' is the direct parent}}
omp.distribute {
"omp.wsloop"() ({
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
Expand Down Expand Up @@ -2429,6 +2429,22 @@ func.func @omp_distribute_nested_wrapper3(%lb: index, %ub: index, %step: index)

// -----

func.func @omp_distribute_nested_wrapper4(%lb: index, %ub: index, %step: index) -> () {
omp.parallel {
// expected-error @below {{an 'omp.wsloop' nested wrapper is only allowed when a composite 'omp.parallel' is the direct parent}}
omp.distribute {
"omp.wsloop"() ({
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
"omp.yield"() : () -> ()
}
}) {omp.composite} : () -> ()
} {omp.composite}
omp.terminator
}
}

// -----

func.func @omp_distribute_order() -> () {
// expected-error @below {{invalid clause value: 'default'}}
omp.distribute order(default) {
Expand Down Expand Up @@ -2623,15 +2639,13 @@ func.func @masked_arg_count_mismatch(%arg0: i32, %arg1: i32) {

// -----
func.func @omp_parallel_missing_composite(%lb: index, %ub: index, %step: index) -> () {
// expected-error@+1 {{'omp.composite' attribute missing from composite operation}}
// expected-error @below {{'omp.composite' attribute missing from composite operation}}
omp.parallel {
omp.distribute {
omp.wsloop {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
} {omp.composite}
} {omp.composite}
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
}
omp.terminator
}
return
Expand All @@ -2653,7 +2667,7 @@ func.func @omp_parallel_invalid_composite(%lb: index, %ub: index, %step: index)

// -----
func.func @omp_parallel_invalid_composite2(%lb: index, %ub: index, %step: index) -> () {
// expected-error @below {{unexpected OpenMP operation inside of composite 'omp.parallel'}}
// expected-error @below {{unexpected OpenMP operation inside of composite 'omp.parallel': omp.barrier}}
omp.parallel {
omp.barrier
omp.distribute {
Expand All @@ -2668,6 +2682,29 @@ func.func @omp_parallel_invalid_composite2(%lb: index, %ub: index, %step: index)
return
}

// -----
func.func @omp_parallel_invalid_composite3(%lb: index, %ub: index, %step: index) -> () {
// expected-error @below {{multiple 'omp.distribute' nested inside of 'omp.parallel'}}
omp.parallel {
omp.distribute {
omp.wsloop {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
} {omp.composite}
} {omp.composite}
omp.distribute {
omp.wsloop {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
}
} {omp.composite}
} {omp.composite}
omp.terminator
} {omp.composite}
return
}

// -----
func.func @omp_wsloop_missing_composite(%lb: index, %ub: index, %step: index) -> () {
// expected-error @below {{'omp.composite' attribute missing from composite wrapper}}
Expand Down Expand Up @@ -2787,7 +2824,7 @@ func.func @omp_taskloop_invalid_composite(%lb: index, %ub: index, %step: index)

func.func @omp_loop_invalid_nesting(%lb : index, %ub : index, %step : index) {

// expected-error @below {{`omp.loop` expected to be a standalone loop wrapper}}
// expected-error @below {{'omp.loop' op expected to be a standalone loop wrapper}}
omp.loop {
omp.simd {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
Expand All @@ -2804,7 +2841,7 @@ func.func @omp_loop_invalid_nesting(%lb : index, %ub : index, %step : index) {
func.func @omp_loop_invalid_nesting2(%lb : index, %ub : index, %step : index) {

omp.simd {
// expected-error @below {{`omp.loop` expected to be a standalone loop wrapper}}
// expected-error @below {{'omp.loop' op expected to be a standalone loop wrapper}}
omp.loop {
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
omp.yield
Expand All @@ -2831,7 +2868,7 @@ func.func @omp_loop_invalid_binding(%lb : index, %ub : index, %step : index) {
// -----
func.func @nested_wrapper(%idx : index) {
omp.workshare {
// expected-error @below {{cannot be composite}}
// expected-error @below {{'omp.workshare.loop_wrapper' op expected to be a standalone loop wrapper}}
omp.workshare.loop_wrapper {
omp.simd {
omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
Expand Down