Skip to content

Commit 468029e

Browse files
committed
Fix split-outer2 test
1 parent 9fb0bde commit 468029e

File tree

1 file changed

+37
-32
lines changed

1 file changed

+37
-32
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 37 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,7 +1961,7 @@ LogicalResult TargetOp::verifyRegions() {
19611961
}
19621962

19631963
static Operation *
1964-
findCapturedOmpOp(Operation *rootOp,
1964+
findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
19651965
llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
19661966
assert(rootOp && "expected valid operation");
19671967

@@ -1989,17 +1989,19 @@ findCapturedOmpOp(Operation *rootOp,
19891989
// (i.e. its block's successors can reach it) or if it's not guaranteed to
19901990
// be executed before all exits of the region (i.e. it doesn't dominate all
19911991
// blocks with no successors reachable from the entry block).
1992-
Region *parentRegion = op->getParentRegion();
1993-
Block *parentBlock = op->getBlock();
1994-
1995-
for (Block *successor : parentBlock->getSuccessors())
1996-
if (successor->isReachable(parentBlock))
1997-
return WalkResult::interrupt();
1998-
1999-
for (Block &block : *parentRegion)
2000-
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2001-
!domInfo.dominates(parentBlock, &block))
2002-
return WalkResult::interrupt();
1992+
if (checkSingleMandatoryExec) {
1993+
Region *parentRegion = op->getParentRegion();
1994+
Block *parentBlock = op->getBlock();
1995+
1996+
for (Block *successor : parentBlock->getSuccessors())
1997+
if (successor->isReachable(parentBlock))
1998+
return WalkResult::interrupt();
1999+
2000+
for (Block &block : *parentRegion)
2001+
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2002+
!domInfo.dominates(parentBlock, &block))
2003+
return WalkResult::interrupt();
2004+
}
20032005

20042006
// Don't capture this op if it has a not-allowed sibling, and stop recursing
20052007
// into nested operations.
@@ -2022,25 +2024,27 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20222024

20232025
// Only allow OpenMP terminators and non-OpenMP ops that have known memory
20242026
// effects, but don't include a memory write effect.
2025-
return findCapturedOmpOp(*this, [&](Operation *sibling) {
2026-
if (!sibling)
2027-
return false;
2028-
2029-
if (ompDialect == sibling->getDialect())
2030-
return sibling->hasTrait<OpTrait::IsTerminator>();
2031-
2032-
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2033-
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
2034-
effects;
2035-
memOp.getEffects(effects);
2036-
return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
2037-
return isa<MemoryEffects::Write>(effect.getEffect()) &&
2038-
isa<SideEffects::AutomaticAllocationScopeResource>(
2039-
effect.getResource());
2027+
return findCapturedOmpOp(
2028+
*this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2029+
if (!sibling)
2030+
return false;
2031+
2032+
if (ompDialect == sibling->getDialect())
2033+
return sibling->hasTrait<OpTrait::IsTerminator>();
2034+
2035+
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2036+
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
2037+
effects;
2038+
memOp.getEffects(effects);
2039+
return !llvm::any_of(
2040+
effects, [&](MemoryEffects::EffectInstance &effect) {
2041+
return isa<MemoryEffects::Write>(effect.getEffect()) &&
2042+
isa<SideEffects::AutomaticAllocationScopeResource>(
2043+
effect.getResource());
2044+
});
2045+
}
2046+
return true;
20402047
});
2041-
}
2042-
return true;
2043-
});
20442048
}
20452049

20462050
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
@@ -2118,8 +2122,9 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
21182122
// `kmpc_distribute_static_loop` family of functions produced by the
21192123
// OMPIRBuilder for these kernels prevents that from working.
21202124
Dialect *ompDialect = targetOp->getDialect();
2121-
Operation *nestedCapture =
2122-
findCapturedOmpOp(capturedOp, [&](Operation *sibling) {
2125+
Operation *nestedCapture = findCapturedOmpOp(
2126+
capturedOp, /*checkSingleMandatoryExec=*/false,
2127+
[&](Operation *sibling) {
21232128
return sibling && (ompDialect != sibling->getDialect() ||
21242129
sibling->hasTrait<OpTrait::IsTerminator>());
21252130
});

0 commit comments

Comments
 (0)