@@ -1954,7 +1954,7 @@ LogicalResult TargetOp::verifyRegions() {
1954
1954
}
1955
1955
1956
1956
static Operation *
1957
- findCapturedOmpOp (Operation *rootOp,
1957
+ findCapturedOmpOp (Operation *rootOp, bool checkSingleMandatoryExec,
1958
1958
llvm::function_ref<bool (Operation *)> siblingAllowedFn) {
1959
1959
assert (rootOp && " expected valid operation" );
1960
1960
@@ -1982,17 +1982,19 @@ findCapturedOmpOp(Operation *rootOp,
1982
1982
// (i.e. its block's successors can reach it) or if it's not guaranteed to
1983
1983
// be executed before all exits of the region (i.e. it doesn't dominate all
1984
1984
// blocks with no successors reachable from the entry block).
1985
- Region *parentRegion = op->getParentRegion ();
1986
- Block *parentBlock = op->getBlock ();
1987
-
1988
- for (Block *successor : parentBlock->getSuccessors ())
1989
- if (successor->isReachable (parentBlock))
1990
- return WalkResult::interrupt ();
1991
-
1992
- for (Block &block : *parentRegion)
1993
- if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
1994
- !domInfo.dominates (parentBlock, &block))
1995
- return WalkResult::interrupt ();
1985
+ if (checkSingleMandatoryExec) {
1986
+ Region *parentRegion = op->getParentRegion ();
1987
+ Block *parentBlock = op->getBlock ();
1988
+
1989
+ for (Block *successor : parentBlock->getSuccessors ())
1990
+ if (successor->isReachable (parentBlock))
1991
+ return WalkResult::interrupt ();
1992
+
1993
+ for (Block &block : *parentRegion)
1994
+ if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
1995
+ !domInfo.dominates (parentBlock, &block))
1996
+ return WalkResult::interrupt ();
1997
+ }
1996
1998
1997
1999
// Don't capture this op if it has a not-allowed sibling, and stop recursing
1998
2000
// into nested operations.
@@ -2015,25 +2017,27 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
2015
2017
2016
2018
// Only allow OpenMP terminators and non-OpenMP ops that have known memory
2017
2019
// effects, but don't include a memory write effect.
2018
- return findCapturedOmpOp (*this , [&](Operation *sibling) {
2019
- if (!sibling)
2020
- return false ;
2021
-
2022
- if (ompDialect == sibling->getDialect ())
2023
- return sibling->hasTrait <OpTrait::IsTerminator>();
2024
-
2025
- if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2026
- SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2027
- effects;
2028
- memOp.getEffects (effects);
2029
- return !llvm::any_of (effects, [&](MemoryEffects::EffectInstance &effect) {
2030
- return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2031
- isa<SideEffects::AutomaticAllocationScopeResource>(
2032
- effect.getResource ());
2020
+ return findCapturedOmpOp (
2021
+ *this , /* checkSingleMandatoryExec=*/ true , [&](Operation *sibling) {
2022
+ if (!sibling)
2023
+ return false ;
2024
+
2025
+ if (ompDialect == sibling->getDialect ())
2026
+ return sibling->hasTrait <OpTrait::IsTerminator>();
2027
+
2028
+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2029
+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2030
+ effects;
2031
+ memOp.getEffects (effects);
2032
+ return !llvm::any_of (
2033
+ effects, [&](MemoryEffects::EffectInstance &effect) {
2034
+ return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2035
+ isa<SideEffects::AutomaticAllocationScopeResource>(
2036
+ effect.getResource ());
2037
+ });
2038
+ }
2039
+ return true ;
2033
2040
});
2034
- }
2035
- return true ;
2036
- });
2037
2041
}
2038
2042
2039
2043
TargetRegionFlags TargetOp::getKernelExecFlags (Operation *capturedOp) {
@@ -2108,8 +2112,9 @@ TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2108
2112
// `kmpc_distribute_static_loop` family of functions produced by the
2109
2113
// OMPIRBuilder for these kernels prevents that from working.
2110
2114
Dialect *ompDialect = targetOp->getDialect ();
2111
- Operation *nestedCapture =
2112
- findCapturedOmpOp (capturedOp, [&](Operation *sibling) {
2115
+ Operation *nestedCapture = findCapturedOmpOp (
2116
+ capturedOp, /* checkSingleMandatoryExec=*/ false ,
2117
+ [&](Operation *sibling) {
2113
2118
return sibling && (ompDialect != sibling->getDialect () ||
2114
2119
sibling->hasTrait <OpTrait::IsTerminator>());
2115
2120
});
0 commit comments