@@ -1915,8 +1915,8 @@ LogicalResult TargetOp::verifyRegions() {
1915
1915
return emitError (" target containing multiple 'omp.teams' nested ops" );
1916
1916
1917
1917
// Check that host_eval values are only used in legal ways.
1918
- llvm::omp::OMPTgtExecModeFlags execFlags =
1919
- getKernelExecFlags (getInnermostCapturedOmpOp () );
1918
+ Operation *capturedOp = getInnermostCapturedOmpOp ();
1919
+ TargetRegionFlags execFlags = getKernelExecFlags (capturedOp );
1920
1920
for (Value hostEvalArg :
1921
1921
cast<BlockArgOpenMPOpInterface>(getOperation ()).getHostEvalBlockArgs ()) {
1922
1922
for (Operation *user : hostEvalArg.getUsers ()) {
@@ -1931,7 +1931,8 @@ LogicalResult TargetOp::verifyRegions() {
1931
1931
" and 'thread_limit' in 'omp.teams'" ;
1932
1932
}
1933
1933
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1934
- if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1934
+ if (bitEnumContainsAny (execFlags, TargetRegionFlags::spmd) &&
1935
+ parallelOp->isAncestor (capturedOp) &&
1935
1936
hostEvalArg == parallelOp.getNumThreads ())
1936
1937
continue ;
1937
1938
@@ -1940,15 +1941,16 @@ LogicalResult TargetOp::verifyRegions() {
1940
1941
" 'omp.parallel' when representing target SPMD" ;
1941
1942
}
1942
1943
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1943
- if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
1944
+ if (bitEnumContainsAny (execFlags, TargetRegionFlags::trip_count) &&
1945
+ loopNestOp.getOperation () == capturedOp &&
1944
1946
(llvm::is_contained (loopNestOp.getLoopLowerBounds (), hostEvalArg) ||
1945
1947
llvm::is_contained (loopNestOp.getLoopUpperBounds (), hostEvalArg) ||
1946
1948
llvm::is_contained (loopNestOp.getLoopSteps (), hostEvalArg)))
1947
1949
continue ;
1948
1950
1949
1951
return emitOpError () << " host_eval argument only legal as loop bounds "
1950
- " and steps in 'omp.loop_nest' when "
1951
- " representing target SPMD or Generic-SPMD " ;
1952
+ " and steps in 'omp.loop_nest' when trip count "
1953
+ " must be evaluated in the host " ;
1952
1954
}
1953
1955
1954
1956
return emitOpError () << " host_eval argument illegal use in '"
@@ -1958,42 +1960,21 @@ LogicalResult TargetOp::verifyRegions() {
1958
1960
return success ();
1959
1961
}
1960
1962
1961
- // / Only allow OpenMP terminators and non-OpenMP ops that have known memory
1962
- // / effects, but don't include a memory write effect.
1963
- static bool siblingAllowedInCapture (Operation *op) {
1964
- if (!op)
1965
- return false ;
1963
+ static Operation *
1964
+ findCapturedOmpOp (Operation *rootOp, bool checkSingleMandatoryExec,
1965
+ llvm::function_ref<bool (Operation *)> siblingAllowedFn) {
1966
+ assert (rootOp && " expected valid operation" );
1966
1967
1967
- bool isOmpDialect =
1968
- op->getContext ()->getLoadedDialect <omp::OpenMPDialect>() ==
1969
- op->getDialect ();
1970
-
1971
- if (isOmpDialect)
1972
- return op->hasTrait <OpTrait::IsTerminator>();
1973
-
1974
- if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1975
- SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 > effects;
1976
- memOp.getEffects (effects);
1977
- return !llvm::any_of (effects, [&](MemoryEffects::EffectInstance &effect) {
1978
- return isa<MemoryEffects::Write>(effect.getEffect ()) &&
1979
- isa<SideEffects::AutomaticAllocationScopeResource>(
1980
- effect.getResource ());
1981
- });
1982
- }
1983
- return true ;
1984
- }
1985
-
1986
- Operation *TargetOp::getInnermostCapturedOmpOp () {
1987
- Dialect *ompDialect = (*this )->getDialect ();
1968
+ Dialect *ompDialect = rootOp->getDialect ();
1988
1969
Operation *capturedOp = nullptr ;
1989
1970
DominanceInfo domInfo;
1990
1971
1991
1972
// Process in pre-order to check operations from outermost to innermost,
1992
1973
// ensuring we only enter the region of an operation if it meets the criteria
1993
1974
// for being captured. We stop the exploration of nested operations as soon as
1994
1975
// we process a region holding no operations to be captured.
1995
- walk<WalkOrder::PreOrder>([&](Operation *op) {
1996
- if (op == * this )
1976
+ rootOp-> walk <WalkOrder::PreOrder>([&](Operation *op) {
1977
+ if (op == rootOp )
1997
1978
return WalkResult::advance ();
1998
1979
1999
1980
// Ignore operations of other dialects or omp operations with no regions,
@@ -2008,22 +1989,24 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
2008
1989
// (i.e. its block's successors can reach it) or if it's not guaranteed to
2009
1990
// be executed before all exits of the region (i.e. it doesn't dominate all
2010
1991
// blocks with no successors reachable from the entry block).
2011
- Region *parentRegion = op->getParentRegion ();
2012
- Block *parentBlock = op->getBlock ();
2013
-
2014
- for (Block *successor : parentBlock->getSuccessors ())
2015
- if (successor->isReachable (parentBlock))
2016
- return WalkResult::interrupt ();
2017
-
2018
- for (Block &block : *parentRegion)
2019
- if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
2020
- !domInfo.dominates (parentBlock, &block))
2021
- return WalkResult::interrupt ();
1992
+ if (checkSingleMandatoryExec) {
1993
+ Region *parentRegion = op->getParentRegion ();
1994
+ Block *parentBlock = op->getBlock ();
1995
+
1996
+ for (Block *successor : parentBlock->getSuccessors ())
1997
+ if (successor->isReachable (parentBlock))
1998
+ return WalkResult::interrupt ();
1999
+
2000
+ for (Block &block : *parentRegion)
2001
+ if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
2002
+ !domInfo.dominates (parentBlock, &block))
2003
+ return WalkResult::interrupt ();
2004
+ }
2022
2005
2023
2006
// Don't capture this op if it has a not-allowed sibling, and stop recursing
2024
2007
// into nested operations.
2025
2008
for (Operation &sibling : op->getParentRegion ()->getOps ())
2026
- if (&sibling != op && !siblingAllowedInCapture (&sibling))
2009
+ if (&sibling != op && !siblingAllowedFn (&sibling))
2027
2010
return WalkResult::interrupt ();
2028
2011
2029
2012
// Don't continue capturing nested operations if we reach an omp.loop_nest.
@@ -2036,10 +2019,35 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
2036
2019
return capturedOp;
2037
2020
}
2038
2021
2039
- llvm::omp::OMPTgtExecModeFlags
2040
- TargetOp::getKernelExecFlags (Operation *capturedOp) {
2041
- using namespace llvm ::omp;
2022
+ Operation *TargetOp::getInnermostCapturedOmpOp () {
2023
+ auto *ompDialect = getContext ()->getLoadedDialect <omp::OpenMPDialect>();
2024
+
2025
+ // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2026
+ // effects, but don't include a memory write effect.
2027
+ return findCapturedOmpOp (
2028
+ *this , /* checkSingleMandatoryExec=*/ true , [&](Operation *sibling) {
2029
+ if (!sibling)
2030
+ return false ;
2031
+
2032
+ if (ompDialect == sibling->getDialect ())
2033
+ return sibling->hasTrait <OpTrait::IsTerminator>();
2034
+
2035
+ if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2036
+ SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4 >
2037
+ effects;
2038
+ memOp.getEffects (effects);
2039
+ return !llvm::any_of (
2040
+ effects, [&](MemoryEffects::EffectInstance &effect) {
2041
+ return isa<MemoryEffects::Write>(effect.getEffect ()) &&
2042
+ isa<SideEffects::AutomaticAllocationScopeResource>(
2043
+ effect.getResource ());
2044
+ });
2045
+ }
2046
+ return true ;
2047
+ });
2048
+ }
2042
2049
2050
+ TargetRegionFlags TargetOp::getKernelExecFlags (Operation *capturedOp) {
2043
2051
// A non-null captured op is only valid if it resides inside of a TargetOp
2044
2052
// and is the result of calling getInnermostCapturedOmpOp() on it.
2045
2053
TargetOp targetOp =
@@ -2048,60 +2056,102 @@ TargetOp::getKernelExecFlags(Operation *capturedOp) {
2048
2056
(targetOp && targetOp.getInnermostCapturedOmpOp () == capturedOp)) &&
2049
2057
" unexpected captured op" );
2050
2058
2051
- // Make sure this region is capturing a loop. Otherwise, it's a generic
2052
- // kernel.
2059
+ // If it's not capturing a loop, it's a default target region.
2053
2060
if (!isa_and_present<LoopNestOp>(capturedOp))
2054
- return OMP_TGT_EXEC_MODE_GENERIC ;
2061
+ return TargetRegionFlags::generic ;
2055
2062
2056
- SmallVector<LoopWrapperInterface> wrappers;
2057
- cast<LoopNestOp>(capturedOp).gatherWrappers (wrappers);
2058
- assert (!wrappers.empty ());
2063
+ auto getInnermostWrapper = [](LoopNestOp loopOp, int &numWrappers) {
2064
+ SmallVector<LoopWrapperInterface> wrappers;
2065
+ loopOp.gatherWrappers (wrappers);
2066
+ assert (!wrappers.empty ());
2059
2067
2060
- // Ignore optional SIMD leaf construct.
2061
- auto *innermostWrapper = wrappers.begin ();
2062
- if (isa<SimdOp>(innermostWrapper ))
2063
- innermostWrapper = std::next (innermostWrapper );
2068
+ // Ignore optional SIMD leaf construct.
2069
+ auto *wrapper = wrappers.begin ();
2070
+ if (isa<SimdOp>(wrapper ))
2071
+ wrapper = std::next (wrapper );
2064
2072
2065
- long numWrappers = std::distance (innermostWrapper, wrappers.end ());
2073
+ numWrappers = static_cast <int >(std::distance (wrapper, wrappers.end ()));
2074
+ return wrapper;
2075
+ };
2066
2076
2067
- // Detect Generic-SPMD: target-teams-distribute[-simd].
2068
- // Detect SPMD: target-teams-loop.
2069
- if (numWrappers == 1 ) {
2070
- if (!isa<DistributeOp, LoopOp>(innermostWrapper))
2071
- return OMP_TGT_EXEC_MODE_GENERIC;
2077
+ int numWrappers;
2078
+ LoopWrapperInterface *innermostWrapper =
2079
+ getInnermostWrapper (cast<LoopNestOp>(capturedOp), numWrappers);
2072
2080
2073
- Operation *teamsOp = (*innermostWrapper)->getParentOp ();
2074
- if (!isa_and_present<TeamsOp>(teamsOp))
2075
- return OMP_TGT_EXEC_MODE_GENERIC;
2081
+ if (numWrappers != 1 && numWrappers != 2 )
2082
+ return TargetRegionFlags::generic;
2076
2083
2077
- if (teamsOp->getParentOp () == targetOp.getOperation ())
2078
- return isa<DistributeOp>(innermostWrapper)
2079
- ? OMP_TGT_EXEC_MODE_GENERIC_SPMD
2080
- : OMP_TGT_EXEC_MODE_SPMD;
2081
- }
2082
-
2083
- // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
2084
+ // Detect target-teams-distribute-parallel-wsloop[-simd].
2084
2085
if (numWrappers == 2 ) {
2085
2086
if (!isa<WsloopOp>(innermostWrapper))
2086
- return OMP_TGT_EXEC_MODE_GENERIC ;
2087
+ return TargetRegionFlags::generic ;
2087
2088
2088
2089
innermostWrapper = std::next (innermostWrapper);
2089
2090
if (!isa<DistributeOp>(innermostWrapper))
2090
- return OMP_TGT_EXEC_MODE_GENERIC ;
2091
+ return TargetRegionFlags::generic ;
2091
2092
2092
2093
Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2093
2094
if (!isa_and_present<ParallelOp>(parallelOp))
2094
- return OMP_TGT_EXEC_MODE_GENERIC ;
2095
+ return TargetRegionFlags::generic ;
2095
2096
2096
2097
Operation *teamsOp = parallelOp->getParentOp ();
2097
2098
if (!isa_and_present<TeamsOp>(teamsOp))
2098
- return OMP_TGT_EXEC_MODE_GENERIC ;
2099
+ return TargetRegionFlags::generic ;
2099
2100
2100
2101
if (teamsOp->getParentOp () == targetOp.getOperation ())
2101
- return OMP_TGT_EXEC_MODE_SPMD;
2102
+ return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2103
+ }
2104
+ // Detect target-teams-distribute[-simd] and target-teams-loop.
2105
+ else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2106
+ Operation *teamsOp = (*innermostWrapper)->getParentOp ();
2107
+ if (!isa_and_present<TeamsOp>(teamsOp))
2108
+ return TargetRegionFlags::generic;
2109
+
2110
+ if (teamsOp->getParentOp () != targetOp.getOperation ())
2111
+ return TargetRegionFlags::generic;
2112
+
2113
+ if (isa<LoopOp>(innermostWrapper))
2114
+ return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2115
+
2116
+ // Find single immediately nested captured omp.parallel and add spmd flag
2117
+ // (generic-spmd case).
2118
+ //
2119
+ // TODO: This shouldn't have to be done here, as it is too easy to break.
2120
+ // The openmp-opt pass should be updated to be able to promote kernels like
2121
+ // this from "Generic" to "Generic-SPMD". However, the use of the
2122
+ // `kmpc_distribute_static_loop` family of functions produced by the
2123
+ // OMPIRBuilder for these kernels prevents that from working.
2124
+ Dialect *ompDialect = targetOp->getDialect ();
2125
+ Operation *nestedCapture = findCapturedOmpOp (
2126
+ capturedOp, /* checkSingleMandatoryExec=*/ false ,
2127
+ [&](Operation *sibling) {
2128
+ return sibling && (ompDialect != sibling->getDialect () ||
2129
+ sibling->hasTrait <OpTrait::IsTerminator>());
2130
+ });
2131
+
2132
+ TargetRegionFlags result =
2133
+ TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2134
+
2135
+ if (!nestedCapture)
2136
+ return result;
2137
+
2138
+ while (nestedCapture->getParentOp () != capturedOp)
2139
+ nestedCapture = nestedCapture->getParentOp ();
2140
+
2141
+ return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2142
+ : result;
2143
+ }
2144
+ // Detect target-parallel-wsloop[-simd].
2145
+ else if (isa<WsloopOp>(innermostWrapper)) {
2146
+ Operation *parallelOp = (*innermostWrapper)->getParentOp ();
2147
+ if (!isa_and_present<ParallelOp>(parallelOp))
2148
+ return TargetRegionFlags::generic;
2149
+
2150
+ if (parallelOp->getParentOp () == targetOp.getOperation ())
2151
+ return TargetRegionFlags::spmd;
2102
2152
}
2103
2153
2104
- return OMP_TGT_EXEC_MODE_GENERIC ;
2154
+ return TargetRegionFlags::generic ;
2105
2155
}
2106
2156
2107
2157
// ===----------------------------------------------------------------------===//
0 commit comments