@@ -1961,7 +1961,7 @@ LogicalResult TargetOp::verifyRegions() {
1961
1961
}
1962
1962
1963
1963
static Operation *
1964
- findCapturedOmpOp (Operation *rootOp,
1964
+ findCapturedOmpOp (Operation *rootOp, bool checkSingleMandatoryExec,
1965
1965
llvm::function_ref<bool (Operation *)> siblingAllowedFn) {
1966
1966
assert (rootOp && " expected valid operation" );
1967
1967
@@ -1989,17 +1989,19 @@ findCapturedOmpOp(Operation *rootOp,
1989
1989
// (i.e. its block's successors can reach it) or if it's not guaranteed to
1990
1990
// be executed before all exits of the region (i.e. it doesn't dominate all
1991
1991
// blocks with no successors reachable from the entry block).
1992
- Region *parentRegion = op->getParentRegion ();
1993
- Block *parentBlock = op->getBlock ();
1994
-
1995
- for (Block *successor : parentBlock->getSuccessors ())
1996
- if (successor->isReachable (parentBlock))
1997
- return WalkResult::interrupt ();
1998
-
1999
- for (Block &block : *parentRegion)
2000
- if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
2001
- !domInfo.dominates (parentBlock, &block))
2002
- return WalkResult::interrupt ();
1992
+ if (checkSingleMandatoryExec) {
1993
+ Region *parentRegion = op->getParentRegion ();
1994
+ Block *parentBlock = op->getBlock ();
1995
+
1996
+ for (Block *successor : parentBlock->getSuccessors ())
1997
+ if (successor->isReachable (parentBlock))
1998
+ return WalkResult::interrupt ();
1999
+
2000
+ for (Block &block : *parentRegion)
2001
+ if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
2002
+ !domInfo.dominates (parentBlock, &block))
2003
+ return WalkResult::interrupt ();
2004
+ }
2003
2005
2004
2006
// Don't capture this op if it has a not-allowed sibling, and stop recursing
2005
2007
// into nested operations.
@@ -2022,25 +2024,27 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
2022
2024
2023
2025
// Only allow OpenMP terminators and non-OpenMP ops that have known memory
2024
2026
// effects, but don't include a memory write effect.
2025
- return findCapturedOmpOp (*this , [&](Operation *sibling) {
2026
- if (!sibling)
2027
- return false ;
2028
-
2029
- if (ompDialect == sibling->getDialect ())
2030
- return sibling->hasTrait <OpTrait::IsTerminator>();
2031
-
2032
- if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2033
- SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2034
- effects;
2035
- memOp.getEffects (effects);
2036
- return !llvm::any_of (effects, [&](MemoryEffects::EffectInstance &effect) {
2037
- return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2038
- isa<SideEffects::AutomaticAllocationScopeResource>(
2039
- effect.getResource ());
2027
+ return findCapturedOmpOp (
2028
+ *this , /* checkSingleMandatoryExec=*/ true , [&](Operation *sibling) {
2029
+ if (!sibling)
2030
+ return false ;
2031
+
2032
+ if (ompDialect == sibling->getDialect ())
2033
+ return sibling->hasTrait <OpTrait::IsTerminator>();
2034
+
2035
+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2036
+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2037
+ effects;
2038
+ memOp.getEffects (effects);
2039
+ return !llvm::any_of (
2040
+ effects, [&](MemoryEffects::EffectInstance &effect) {
2041
+ return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2042
+ isa<SideEffects::AutomaticAllocationScopeResource>(
2043
+ effect.getResource ());
2044
+ });
2045
+ }
2046
+ return true ;
2040
2047
});
2041
- }
2042
- return true ;
2043
- });
2044
2048
}
2045
2049
2046
2050
TargetRegionFlags TargetOp::getKernelExecFlags (Operation *capturedOp) {
@@ -2118,8 +2122,9 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2118
2122
// `kmpc_distribute_static_loop` family of functions produced by the
2119
2123
// OMPIRBuilder for these kernels prevents that from working.
2120
2124
Dialect *ompDialect = targetOp->getDialect ();
2121
- Operation *nestedCapture =
2122
- findCapturedOmpOp (capturedOp, [&](Operation *sibling) {
2125
+ Operation *nestedCapture = findCapturedOmpOp (
2126
+ capturedOp, /* checkSingleMandatoryExec=*/ false ,
2127
+ [&](Operation *sibling) {
2123
2128
return sibling && (ompDialect != sibling->getDialect () ||
2124
2129
sibling->hasTrait <OpTrait::IsTerminator>());
2125
2130
});
0 commit comments