Skip to content

Commit 66d0f84

Browse files
committed
[mlir][OpenMP] allow cancellation to not be directly nested
omp.cancel and omp.cancellationpoint contain an attribute describing the type of parent construct which should be cancelled. e.g. ``` !$omp cancel do ``` Must be inside of a wsloop. Previously the verifer required the immediate parent to be this operation. This is not quite right because something like the following is valid: ``` !$omp parallel do do i = 1, N if (cond) then !$omp cancel do endif enddo ``` This patch relaxes the verifier to only require that some parent operation matches (not necessarily the immediate parent).
1 parent 581f8bc commit 66d0f84

File tree

2 files changed

+92
-22
lines changed

2 files changed

+92
-22
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3103,22 +3103,15 @@ void CancelOp::build(OpBuilder &builder, OperationState &state,
31033103

31043104
LogicalResult CancelOp::verify() {
31053105
ClauseCancellationConstructType cct = getCancelDirective();
3106-
Operation *parentOp = (*this)->getParentOp();
3107-
3108-
if (!parentOp) {
3109-
return emitOpError() << "must be used within a region supporting "
3110-
"cancel directive";
3111-
}
3106+
Operation *thisOp = (*this).getOperation();
31123107

31133108
if ((cct == ClauseCancellationConstructType::Parallel) &&
3114-
!isa<ParallelOp>(parentOp)) {
3109+
!thisOp->getParentOfType<ParallelOp>()) {
31153110
return emitOpError() << "cancel parallel must appear "
31163111
<< "inside a parallel region";
31173112
}
31183113
if (cct == ClauseCancellationConstructType::Loop) {
3119-
auto loopOp = dyn_cast<LoopNestOp>(parentOp);
3120-
auto wsloopOp = llvm::dyn_cast_if_present<WsloopOp>(
3121-
loopOp ? loopOp->getParentOp() : nullptr);
3114+
auto wsloopOp = thisOp->getParentOfType<WsloopOp>();
31223115

31233116
if (!wsloopOp) {
31243117
return emitOpError()
@@ -3134,12 +3127,12 @@ LogicalResult CancelOp::verify() {
31343127
}
31353128

31363129
} else if (cct == ClauseCancellationConstructType::Sections) {
3137-
if (!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
3130+
auto sectionsOp = thisOp->getParentOfType<SectionsOp>();
3131+
if (!sectionsOp) {
31383132
return emitOpError() << "cancel sections must appear "
31393133
<< "inside a sections region";
31403134
}
3141-
if (isa_and_nonnull<SectionsOp>(parentOp->getParentOp()) &&
3142-
cast<SectionsOp>(parentOp->getParentOp()).getNowaitAttr()) {
3135+
if (sectionsOp.getNowait()) {
31433136
return emitError() << "A sections construct that is canceled "
31443137
<< "must not have a nowait clause";
31453138
}
@@ -3159,25 +3152,20 @@ void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
31593152

31603153
LogicalResult CancellationPointOp::verify() {
31613154
ClauseCancellationConstructType cct = getCancelDirective();
3162-
Operation *parentOp = (*this)->getParentOp();
3163-
3164-
if (!parentOp) {
3165-
return emitOpError() << "must be used within a region supporting "
3166-
"cancellation point directive";
3167-
}
3155+
Operation *thisOp = (*this).getOperation();
31683156

31693157
if ((cct == ClauseCancellationConstructType::Parallel) &&
3170-
!(isa<ParallelOp>(parentOp))) {
3158+
!thisOp->getParentOfType<ParallelOp>()) {
31713159
return emitOpError() << "cancellation point parallel must appear "
31723160
<< "inside a parallel region";
31733161
}
31743162
if ((cct == ClauseCancellationConstructType::Loop) &&
3175-
(!isa<LoopNestOp>(parentOp) || !isa<WsloopOp>(parentOp->getParentOp()))) {
3163+
!thisOp->getParentOfType<WsloopOp>()) {
31763164
return emitOpError() << "cancellation point loop must appear "
31773165
<< "inside a worksharing-loop region";
31783166
}
31793167
if ((cct == ClauseCancellationConstructType::Sections) &&
3180-
!(isa<SectionsOp>(parentOp) || isa<SectionOp>(parentOp))) {
3168+
!thisOp->getParentOfType<SectionsOp>()) {
31813169
return emitOpError() << "cancellation point sections must appear "
31823170
<< "inside a sections region";
31833171
}

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2201,6 +2201,48 @@ func.func @omp_cancel_sections() -> () {
22012201
return
22022202
}
22032203

2204+
func.func @omp_cancel_parallel_nested(%if_cond : i1) -> () {
2205+
omp.parallel {
2206+
scf.if %if_cond {
2207+
// CHECK: omp.cancel cancellation_construct_type(parallel)
2208+
omp.cancel cancellation_construct_type(parallel)
2209+
}
2210+
// CHECK: omp.terminator
2211+
omp.terminator
2212+
}
2213+
return
2214+
}
2215+
2216+
func.func @omp_cancel_wsloop_nested(%lb : index, %ub : index, %step : index,
2217+
%if_cond : i1) {
2218+
omp.wsloop {
2219+
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
2220+
scf.if %if_cond {
2221+
// CHECK: omp.cancel cancellation_construct_type(loop)
2222+
omp.cancel cancellation_construct_type(loop)
2223+
}
2224+
// CHECK: omp.yield
2225+
omp.yield
2226+
}
2227+
}
2228+
return
2229+
}
2230+
2231+
func.func @omp_cancel_sections_nested(%if_cond : i1) -> () {
2232+
omp.sections {
2233+
omp.section {
2234+
scf.if %if_cond {
2235+
// CHECK: omp.cancel cancellation_construct_type(sections)
2236+
omp.cancel cancellation_construct_type(sections)
2237+
}
2238+
omp.terminator
2239+
}
2240+
// CHECK: omp.terminator
2241+
omp.terminator
2242+
}
2243+
return
2244+
}
2245+
22042246
func.func @omp_cancellationpoint_parallel() -> () {
22052247
omp.parallel {
22062248
// CHECK: omp.cancellation_point cancellation_construct_type(parallel)
@@ -2241,6 +2283,46 @@ func.func @omp_cancellationpoint_sections() -> () {
22412283
return
22422284
}
22432285

2286+
func.func @omp_cancellationpoint_parallel_nested(%if_cond : i1) -> () {
2287+
omp.parallel {
2288+
scf.if %if_cond {
2289+
// CHECK: omp.cancellation_point cancellation_construct_type(parallel)
2290+
omp.cancellation_point cancellation_construct_type(parallel)
2291+
}
2292+
omp.terminator
2293+
}
2294+
return
2295+
}
2296+
2297+
func.func @omp_cancellationpoint_wsloop_nested(%lb : index, %ub : index, %step : index, %if_cond : i1) {
2298+
omp.wsloop {
2299+
omp.loop_nest (%iv) : index = (%lb) to (%ub) step (%step) {
2300+
scf.if %if_cond {
2301+
// CHECK: omp.cancellation_point cancellation_construct_type(loop)
2302+
omp.cancellation_point cancellation_construct_type(loop)
2303+
}
2304+
// CHECK: omp.yield
2305+
omp.yield
2306+
}
2307+
}
2308+
return
2309+
}
2310+
2311+
func.func @omp_cancellationpoint_sections_nested(%if_cond : i1) -> () {
2312+
omp.sections {
2313+
omp.section {
2314+
scf.if %if_cond {
2315+
// CHECK: omp.cancellation_point cancellation_construct_type(sections)
2316+
omp.cancellation_point cancellation_construct_type(sections)
2317+
}
2318+
omp.terminator
2319+
}
2320+
// CHECK: omp.terminator
2321+
omp.terminator
2322+
}
2323+
return
2324+
}
2325+
22442326
// CHECK-LABEL: @omp_taskgroup_no_tasks
22452327
func.func @omp_taskgroup_no_tasks() -> () {
22462328

0 commit comments

Comments
 (0)