Skip to content

Commit dd6ec89

Browse files
committed
Require openmp constructs are statically nested
1 parent a59c703 commit dd6ec89

File tree

2 files changed

+47
-8
lines changed

2 files changed

+47
-8
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3101,17 +3101,32 @@ void CancelOp::build(OpBuilder &builder, OperationState &state,
31013101
CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
31023102
}
31033103

3104+
static Operation *getParentInSameDialect(Operation *thisOp) {
3105+
mlir::Operation *parent = thisOp->getParentOp();
3106+
while (parent) {
3107+
if (parent->getDialect() == thisOp->getDialect())
3108+
return parent;
3109+
parent = parent->getParentOp();
3110+
}
3111+
return nullptr;
3112+
}
3113+
31043114
LogicalResult CancelOp::verify() {
31053115
ClauseCancellationConstructType cct = getCancelDirective();
3106-
Operation *thisOp = (*this).getOperation();
3116+
// The next OpenMP operation in the chain of parents
3117+
Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3118+
if (!structuralParent)
3119+
return emitOpError() << "Orphaned cancel construct";
31073120

31083121
if ((cct == ClauseCancellationConstructType::Parallel) &&
3109-
!thisOp->getParentOfType<ParallelOp>()) {
3122+
!mlir::isa<ParallelOp>(structuralParent)) {
31103123
return emitOpError() << "cancel parallel must appear "
31113124
<< "inside a parallel region";
31123125
}
31133126
if (cct == ClauseCancellationConstructType::Loop) {
3114-
auto wsloopOp = thisOp->getParentOfType<WsloopOp>();
3127+
// structural parent will be omp.loop_nest, directly nested inside
3128+
// omp.wsloop
3129+
auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
31153130

31163131
if (!wsloopOp) {
31173132
return emitOpError()
@@ -3127,7 +3142,10 @@ LogicalResult CancelOp::verify() {
31273142
}
31283143

31293144
} else if (cct == ClauseCancellationConstructType::Sections) {
3130-
auto sectionsOp = thisOp->getParentOfType<SectionsOp>();
3145+
// structural parent will be an omp.section, directly nested inside
3146+
// omp.sections
3147+
auto sectionsOp =
3148+
mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
31313149
if (!sectionsOp) {
31323150
return emitOpError() << "cancel sections must appear "
31333151
<< "inside a sections region";
@@ -3152,20 +3170,25 @@ void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
31523170

31533171
LogicalResult CancellationPointOp::verify() {
31543172
ClauseCancellationConstructType cct = getCancelDirective();
3155-
Operation *thisOp = (*this).getOperation();
3173+
// The next OpenMP operation in the chain of parents
3174+
Operation *structuralParent = getParentInSameDialect((*this).getOperation());
3175+
if (!structuralParent)
3176+
return emitOpError() << "Orphaned cancellation point";
31563177

31573178
if ((cct == ClauseCancellationConstructType::Parallel) &&
3158-
!thisOp->getParentOfType<ParallelOp>()) {
3179+
!mlir::isa<ParallelOp>(structuralParent)) {
31593180
return emitOpError() << "cancellation point parallel must appear "
31603181
<< "inside a parallel region";
31613182
}
3183+
// Strucutal parent here will be an omp.loop_nest. Get the parent of that to
3184+
// find the wsloop
31623185
if ((cct == ClauseCancellationConstructType::Loop) &&
3163-
!thisOp->getParentOfType<WsloopOp>()) {
3186+
!mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
31643187
return emitOpError() << "cancellation point loop must appear "
31653188
<< "inside a worksharing-loop region";
31663189
}
31673190
if ((cct == ClauseCancellationConstructType::Sections) &&
3168-
!thisOp->getParentOfType<SectionsOp>()) {
3191+
!mlir::isa<omp::SectionOp>(structuralParent)) {
31693192
return emitOpError() << "cancellation point sections must appear "
31703193
<< "inside a sections region";
31713194
}

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,6 +1710,14 @@ func.func @omp_task(%mem: memref<1xf32>) {
17101710

17111711
// -----
17121712

1713+
func.func @omp_cancel() {
1714+
// expected-error @below {{Orphaned cancel construct}}
1715+
omp.cancel cancellation_construct_type(parallel)
1716+
return
1717+
}
1718+
1719+
// -----
1720+
17131721
func.func @omp_cancel() {
17141722
omp.sections {
17151723
// expected-error @below {{cancel parallel must appear inside a parallel region}}
@@ -1789,6 +1797,14 @@ func.func @omp_cancel5() -> () {
17891797

17901798
// -----
17911799

1800+
func.func @omp_cancellationpoint() {
1801+
// expected-error @below {{Orphaned cancellation point}}
1802+
omp.cancellation_point cancellation_construct_type(parallel)
1803+
return
1804+
}
1805+
1806+
// -----
1807+
17921808
func.func @omp_cancellationpoint() {
17931809
omp.sections {
17941810
// expected-error @below {{cancellation point parallel must appear inside a parallel region}}

0 commit comments

Comments
 (0)