Skip to content

Commit dc1ac96

Browse files
committed
Accept omp.parallel inside of conditional or loop in Generic-SPMD
1 parent 08b4e24 commit dc1ac96

File tree

1 file changed

+37
-32
lines changed

1 file changed

+37
-32
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,7 +1954,7 @@ LogicalResult TargetOp::verifyRegions() {
19541954
}
19551955

19561956
static Operation *
1957-
findCapturedOmpOp(Operation *rootOp,
1957+
findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
19581958
llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
19591959
assert(rootOp && "expected valid operation");
19601960

@@ -1982,17 +1982,19 @@ findCapturedOmpOp(Operation *rootOp,
19821982
// (i.e. its block's successors can reach it) or if it's not guaranteed to
19831983
// be executed before all exits of the region (i.e. it doesn't dominate all
19841984
// blocks with no successors reachable from the entry block).
1985-
Region *parentRegion = op->getParentRegion();
1986-
Block *parentBlock = op->getBlock();
1987-
1988-
for (Block *successor : parentBlock->getSuccessors())
1989-
if (successor->isReachable(parentBlock))
1990-
return WalkResult::interrupt();
1991-
1992-
for (Block &block : *parentRegion)
1993-
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
1994-
!domInfo.dominates(parentBlock, &block))
1995-
return WalkResult::interrupt();
1985+
if (checkSingleMandatoryExec) {
1986+
Region *parentRegion = op->getParentRegion();
1987+
Block *parentBlock = op->getBlock();
1988+
1989+
for (Block *successor : parentBlock->getSuccessors())
1990+
if (successor->isReachable(parentBlock))
1991+
return WalkResult::interrupt();
1992+
1993+
for (Block &block : *parentRegion)
1994+
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
1995+
!domInfo.dominates(parentBlock, &block))
1996+
return WalkResult::interrupt();
1997+
}
19961998

19971999
// Don't capture this op if it has a not-allowed sibling, and stop recursing
19982000
// into nested operations.
@@ -2015,25 +2017,27 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20152017

20162018
// Only allow OpenMP terminators and non-OpenMP ops that have known memory
20172019
// effects, but don't include a memory write effect.
2018-
return findCapturedOmpOp(*this, [&](Operation *sibling) {
2019-
if (!sibling)
2020-
return false;
2021-
2022-
if (ompDialect == sibling->getDialect())
2023-
return sibling->hasTrait<OpTrait::IsTerminator>();
2024-
2025-
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2026-
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
2027-
effects;
2028-
memOp.getEffects(effects);
2029-
return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
2030-
return isa<MemoryEffects::Write>(effect.getEffect()) &&
2031-
isa<SideEffects::AutomaticAllocationScopeResource>(
2032-
effect.getResource());
2020+
return findCapturedOmpOp(
2021+
*this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2022+
if (!sibling)
2023+
return false;
2024+
2025+
if (ompDialect == sibling->getDialect())
2026+
return sibling->hasTrait<OpTrait::IsTerminator>();
2027+
2028+
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2029+
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
2030+
effects;
2031+
memOp.getEffects(effects);
2032+
return !llvm::any_of(
2033+
effects, [&](MemoryEffects::EffectInstance &effect) {
2034+
return isa<MemoryEffects::Write>(effect.getEffect()) &&
2035+
isa<SideEffects::AutomaticAllocationScopeResource>(
2036+
effect.getResource());
2037+
});
2038+
}
2039+
return true;
20332040
});
2034-
}
2035-
return true;
2036-
});
20372041
}
20382042

20392043
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
@@ -2108,8 +2112,9 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21082112
// `kmpc_distribute_static_loop` family of functions produced by the
21092113
// OMPIRBuilder for these kernels prevents that from working.
21102114
Dialect *ompDialect = targetOp->getDialect();
2111-
Operation *nestedCapture =
2112-
findCapturedOmpOp(capturedOp, [&](Operation *sibling) {
2115+
Operation *nestedCapture = findCapturedOmpOp(
2116+
capturedOp, /*checkSingleMandatoryExec=*/false,
2117+
[&](Operation *sibling) {
21132118
return sibling && (ompDialect != sibling->getDialect() ||
21142119
sibling->hasTrait<OpTrait::IsTerminator>());
21152120
});

0 commit comments

Comments
 (0)