31
31
#include " llvm/ADT/StringRef.h"
32
32
#include " llvm/ADT/TypeSwitch.h"
33
33
#include " llvm/Frontend/OpenMP/OMPConstants.h"
34
+ #include " llvm/Frontend/OpenMP/OMPDeviceConstants.h"
34
35
#include < cstddef>
35
36
#include < iterator>
36
37
#include < optional>
@@ -1754,7 +1755,7 @@ LogicalResult TargetOp::verifyRegions() {
1754
1755
return emitError (" target containing multiple 'omp.teams' nested ops" );
1755
1756
1756
1757
// Check that host_eval values are only used in legal ways.
1757
- bool isTargetSPMD = isTargetSPMDLoop ();
1758
+ llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags ();
1758
1759
for (Value hostEvalArg :
1759
1760
cast<BlockArgOpenMPOpInterface>(getOperation ()).getHostEvalBlockArgs ()) {
1760
1761
for (Operation *user : hostEvalArg.getUsers ()) {
@@ -1769,23 +1770,24 @@ LogicalResult TargetOp::verifyRegions() {
1769
1770
" and 'thread_limit' in 'omp.teams'" ;
1770
1771
}
1771
1772
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1772
- if (isTargetSPMD && hostEvalArg == parallelOp.getNumThreads ())
1773
+ if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1774
+ hostEvalArg == parallelOp.getNumThreads ())
1773
1775
continue ;
1774
1776
1775
1777
return emitOpError ()
1776
1778
<< " host_eval argument only legal as 'num_threads' in "
1777
1779
" 'omp.parallel' when representing target SPMD" ;
1778
1780
}
1779
1781
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1780
- if (isTargetSPMD &&
1782
+ if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
1781
1783
(llvm::is_contained (loopNestOp.getLoopLowerBounds (), hostEvalArg) ||
1782
1784
llvm::is_contained (loopNestOp.getLoopUpperBounds (), hostEvalArg) ||
1783
1785
llvm::is_contained (loopNestOp.getLoopSteps (), hostEvalArg)))
1784
1786
continue ;
1785
1787
1786
- return emitOpError ()
1787
- << " host_eval argument only legal as loop bounds and steps in "
1788
- " 'omp.loop_nest' when representing target SPMD" ;
1788
+ return emitOpError () << " host_eval argument only legal as loop bounds "
1789
+ " and steps in 'omp.loop_nest' when "
1790
+ " representing target SPMD or Generic- SPMD" ;
1789
1791
}
1790
1792
1791
1793
return emitOpError () << " host_eval argument illegal use in '"
@@ -1823,6 +1825,7 @@ static bool siblingAllowedInCapture(Operation *op) {
1823
1825
Operation *TargetOp::getInnermostCapturedOmpOp () {
1824
1826
Dialect *ompDialect = (*this )->getDialect ();
1825
1827
Operation *capturedOp = nullptr ;
1828
+ DominanceInfo domInfo;
1826
1829
1827
1830
// Process in pre-order to check operations from outermost to innermost,
1828
1831
// ensuring we only enter the region of an operation if it meets the criteria
@@ -1840,6 +1843,22 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
1840
1843
if (!isOmpDialect || !hasRegions)
1841
1844
return WalkResult::skip ();
1842
1845
1846
+ // This operation cannot be captured if it can be executed more than once
1847
+ // (i.e. its block's successors can reach it) or if it's not guaranteed to
1848
+ // be executed before all exits of the region (i.e. it doesn't dominate all
1849
+ // blocks with no successors reachable from the entry block).
1850
+ Region *parentRegion = op->getParentRegion ();
1851
+ Block *parentBlock = op->getBlock ();
1852
+
1853
+ for (Block *successor : parentBlock->getSuccessors ())
1854
+ if (successor->isReachable (parentBlock))
1855
+ return WalkResult::interrupt ();
1856
+
1857
+ for (Block &block : *parentRegion)
1858
+ if (domInfo.isReachableFromEntry (&block) && block.hasNoSuccessors () &&
1859
+ !domInfo.dominates (parentBlock, &block))
1860
+ return WalkResult::interrupt ();
1861
+
1843
1862
// Don't capture this op if it has a not-allowed sibling, and stop recursing
1844
1863
// into nested operations.
1845
1864
for (Operation &sibling : op->getParentRegion ()->getOps ())
@@ -1856,49 +1875,61 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
1856
1875
return capturedOp;
1857
1876
}
1858
1877
1859
- bool TargetOp::isTargetSPMDLoop () {
1860
- // The expected MLIR representation for a target SPMD loop is:
1861
- // omp.target {
1862
- // omp.teams {
1863
- // omp.parallel {
1864
- // omp.distribute {
1865
- // omp.wsloop {
1866
- // omp.loop_nest ... { ... }
1867
- // } {omp.composite}
1868
- // } {omp.composite}
1869
- // omp.terminator
1870
- // } {omp.composite}
1871
- // omp.terminator
1872
- // }
1873
- // omp.terminator
1874
- // }
1878
+ llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags () {
1879
+ using namespace llvm ::omp;
1875
1880
1881
+ // Make sure this region is capturing a loop. Otherwise, it's a generic
1882
+ // kernel.
1876
1883
Operation *capturedOp = getInnermostCapturedOmpOp ();
1877
1884
if (!isa_and_present<LoopNestOp>(capturedOp))
1878
- return false ;
1885
+ return OMP_TGT_EXEC_MODE_GENERIC ;
1879
1886
1880
- Operation *workshareOp = capturedOp->getParentOp ();
1887
+ SmallVector<LoopWrapperInterface> wrappers;
1888
+ cast<LoopNestOp>(capturedOp).gatherWrappers (wrappers);
1889
+ assert (!wrappers.empty ());
1881
1890
1882
- // Accept an optional omp.simd loop wrapper as part of the SPMD pattern.
1883
- if (isa_and_present<SimdOp>(workshareOp))
1884
- workshareOp = workshareOp->getParentOp ();
1891
+ // Ignore optional SIMD leaf construct.
1892
+ auto *innermostWrapper = wrappers.begin ();
1893
+ if (isa<SimdOp>(innermostWrapper))
1894
+ innermostWrapper = std::next (innermostWrapper);
1885
1895
1886
- if (!isa_and_present<WsloopOp>(workshareOp))
1887
- return false ;
1896
+ long numWrappers = std::distance (innermostWrapper, wrappers.end ());
1888
1897
1889
- Operation *distributeOp = workshareOp->getParentOp ();
1890
- if (!isa_and_present<DistributeOp>(distributeOp))
1891
- return false ;
1898
+ // Detect Generic-SPMD: target-teams-distribute[-simd].
1899
+ if (numWrappers == 1 ) {
1900
+ if (!isa<DistributeOp>(innermostWrapper))
1901
+ return OMP_TGT_EXEC_MODE_GENERIC;
1892
1902
1893
- Operation *parallelOp = distributeOp ->getParentOp ();
1894
- if (!isa_and_present<ParallelOp>(parallelOp ))
1895
- return false ;
1903
+ Operation *teamsOp = (*innermostWrapper) ->getParentOp ();
1904
+ if (!isa_and_present<TeamsOp>(teamsOp ))
1905
+ return OMP_TGT_EXEC_MODE_GENERIC ;
1896
1906
1897
- Operation *teamsOp = parallelOp->getParentOp ();
1898
- if (!isa_and_present<TeamsOp>(teamsOp))
1899
- return false ;
1907
+ if (teamsOp->getParentOp () == *this )
1908
+ return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
1909
+ }
1910
+
1911
+ // Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
1912
+ if (numWrappers == 2 ) {
1913
+ if (!isa<WsloopOp>(innermostWrapper))
1914
+ return OMP_TGT_EXEC_MODE_GENERIC;
1915
+
1916
+ innermostWrapper = std::next (innermostWrapper);
1917
+ if (!isa<DistributeOp>(innermostWrapper))
1918
+ return OMP_TGT_EXEC_MODE_GENERIC;
1919
+
1920
+ Operation *parallelOp = (*innermostWrapper)->getParentOp ();
1921
+ if (!isa_and_present<ParallelOp>(parallelOp))
1922
+ return OMP_TGT_EXEC_MODE_GENERIC;
1923
+
1924
+ Operation *teamsOp = parallelOp->getParentOp ();
1925
+ if (!isa_and_present<TeamsOp>(teamsOp))
1926
+ return OMP_TGT_EXEC_MODE_GENERIC;
1927
+
1928
+ if (teamsOp->getParentOp () == *this )
1929
+ return OMP_TGT_EXEC_MODE_SPMD;
1930
+ }
1900
1931
1901
- return teamsOp-> getParentOp () == (* this ) ;
1932
+ return OMP_TGT_EXEC_MODE_GENERIC ;
1902
1933
}
1903
1934
1904
1935
// ===----------------------------------------------------------------------===//
0 commit comments