@@ -691,8 +691,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
   return parser.parseRegion(region, entryBlockArgs);
 }
 
-static ParseResult parseInReductionMapPrivateRegion(
+static ParseResult parseHostEvalInReductionMapPrivateRegion(
     OpAsmParser &parser, Region &region,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
+    SmallVectorImpl<Type> &hostEvalTypes,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
     SmallVectorImpl<Type> &inReductionTypes,
     DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -702,6 +704,7 @@ static ParseResult parseInReductionMapPrivateRegion(
     llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
     DenseI64ArrayAttr &privateMaps) {
   AllRegionParseArgs args;
+  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
@@ -931,13 +934,15 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
   p.printRegion(region, /*printEntryBlockArgs=*/false);
 }
 
-static void printInReductionMapPrivateRegion(
-    OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
+static void printHostEvalInReductionMapPrivateRegion(
+    OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
+    TypeRange hostEvalTypes, ValueRange inReductionVars,
     TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
     ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
     ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
     DenseI64ArrayAttr privateMaps) {
   AllRegionPrintArgs args;
+  args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
   args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
                                inReductionByref, inReductionSyms);
   args.mapArgs.emplace(mapVars, mapTypes);
@@ -1719,7 +1724,8 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
   // inReductionByref, inReductionSyms.
   TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
                   makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
-                  clauses.device, clauses.hasDeviceAddrVars, clauses.ifExpr,
+                  clauses.device, clauses.hasDeviceAddrVars,
+                  clauses.hostEvalVars, clauses.ifExpr,
                   /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
@@ -1742,6 +1748,159 @@ LogicalResult TargetOp::verify() {
   return verifyPrivateVarsMapping(*this);
 }
 
+LogicalResult TargetOp::verifyRegions() {
+  auto teamsOps = getOps<TeamsOp>();
+  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
+    return emitError("target containing multiple 'omp.teams' nested ops");
+
+  // Check that host_eval values are only used in legal ways.
+  bool isTargetSPMD = isTargetSPMDLoop();
+  for (Value hostEvalArg :
+       cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
+    for (Operation *user : hostEvalArg.getUsers()) {
+      if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
+        if (llvm::is_contained({teamsOp.getNumTeamsLower(),
+                                teamsOp.getNumTeamsUpper(),
+                                teamsOp.getThreadLimit()},
+                               hostEvalArg))
+          continue;
+
+        return emitOpError() << "host_eval argument only legal as 'num_teams' "
+                                "and 'thread_limit' in 'omp.teams'";
+      }
+      if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
+        if (isTargetSPMD && hostEvalArg == parallelOp.getNumThreads())
+          continue;
+
+        return emitOpError()
+               << "host_eval argument only legal as 'num_threads' in "
+                  "'omp.parallel' when representing target SPMD";
+      }
+      if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
+        if (isTargetSPMD &&
+            (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
+             llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
+             llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
+          continue;
+
+        return emitOpError()
+               << "host_eval argument only legal as loop bounds and steps in "
+                  "'omp.loop_nest' when representing target SPMD";
+      }
+
+      return emitOpError() << "host_eval argument illegal use in '"
+                           << user->getName() << "' operation";
+    }
+  }
+  return success();
+}
+
+/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
+/// effects, but don't include a memory write effect.
+static bool siblingAllowedInCapture(Operation *op) {
+  if (!op)
+    return false;
+
+  bool isOmpDialect =
+      op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
+      op->getDialect();
+
+  if (isOmpDialect)
+    return op->hasTrait<OpTrait::IsTerminator>();
+
+  if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
+    SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
+    memOp.getEffects(effects);
+    return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
+      return isa<MemoryEffects::Write>(effect.getEffect()) &&
+             isa<SideEffects::AutomaticAllocationScopeResource>(
+                 effect.getResource());
+    });
+  }
+  return true;
+}
1822
+
1823
+ Operation *TargetOp::getInnermostCapturedOmpOp () {
1824
+ Dialect *ompDialect = (*this )->getDialect ();
1825
+ Operation *capturedOp = nullptr ;
1826
+
1827
+ // Process in pre-order to check operations from outermost to innermost,
1828
+ // ensuring we only enter the region of an operation if it meets the criteria
1829
+ // for being captured. We stop the exploration of nested operations as soon as
1830
+ // we process a region holding no operations to be captured.
1831
+ walk<WalkOrder::PreOrder>([&](Operation *op) {
1832
+ if (op == *this )
1833
+ return WalkResult::advance ();
1834
+
1835
+ // Ignore operations of other dialects or omp operations with no regions,
1836
+ // because these will only be checked if they are siblings of an omp
1837
+ // operation that can potentially be captured.
1838
+ bool isOmpDialect = op->getDialect () == ompDialect;
1839
+ bool hasRegions = op->getNumRegions () > 0 ;
1840
+ if (!isOmpDialect || !hasRegions)
1841
+ return WalkResult::skip ();
1842
+
1843
+ // Don't capture this op if it has a not-allowed sibling, and stop recursing
1844
+ // into nested operations.
1845
+ for (Operation &sibling : op->getParentRegion ()->getOps ())
1846
+ if (&sibling != op && !siblingAllowedInCapture (&sibling))
1847
+ return WalkResult::interrupt ();
1848
+
1849
+ // Don't continue capturing nested operations if we reach an omp.loop_nest.
1850
+ // Otherwise, process the contents of this operation.
1851
+ capturedOp = op;
1852
+ return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt ()
1853
+ : WalkResult::advance ();
1854
+ });
1855
+
1856
+ return capturedOp;
1857
+ }
+
+bool TargetOp::isTargetSPMDLoop() {
+  // The expected MLIR representation for a target SPMD loop is:
+  // omp.target {
+  //   omp.teams {
+  //     omp.parallel {
+  //       omp.distribute {
+  //         omp.wsloop {
+  //           omp.loop_nest ... { ... }
+  //         } {omp.composite}
+  //       } {omp.composite}
+  //       omp.terminator
+  //     } {omp.composite}
+  //     omp.terminator
+  //   }
+  //   omp.terminator
+  // }
+
+  Operation *capturedOp = getInnermostCapturedOmpOp();
+  if (!isa_and_present<LoopNestOp>(capturedOp))
+    return false;
+
+  Operation *workshareOp = capturedOp->getParentOp();
+
+  // Accept an optional omp.simd loop wrapper as part of the SPMD pattern.
+  if (isa_and_present<SimdOp>(workshareOp))
+    workshareOp = workshareOp->getParentOp();
+
+  if (!isa_and_present<WsloopOp>(workshareOp))
+    return false;
+
+  Operation *distributeOp = workshareOp->getParentOp();
+  if (!isa_and_present<DistributeOp>(distributeOp))
+    return false;
+
+  Operation *parallelOp = distributeOp->getParentOp();
+  if (!isa_and_present<ParallelOp>(parallelOp))
+    return false;
+
+  Operation *teamsOp = parallelOp->getParentOp();
+  if (!isa_and_present<TeamsOp>(teamsOp))
+    return false;
+
+  return teamsOp->getParentOp() == (*this);
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
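For illustration, the IR shape these checks are designed to accept looks roughly like the sketch below. It mirrors the target SPMD skeleton from the isTargetSPMDLoop() comment, with host_eval block arguments used only in the positions verifyRegions() allows: 'num_teams'/'thread_limit' on omp.teams, 'num_threads' on omp.parallel, and the omp.loop_nest bounds and steps. The host_eval, num_teams and num_threads clause spellings here are assumptions modeled on the dialect's other clauses rather than text taken from this commit, so the exact assembly may differ:

omp.target host_eval(%nt -> %nt_arg, %lb -> %lb_arg, %ub -> %ub_arg, %step -> %step_arg : i32, i32, i32, i32) {
  omp.teams num_teams(to %nt_arg : i32) {
    omp.parallel num_threads(%nt_arg : i32) {
      omp.distribute {
        omp.wsloop {
          omp.loop_nest (%i) : i32 = (%lb_arg) to (%ub_arg) step (%step_arg) {
            omp.yield
          }
        } {omp.composite}
      } {omp.composite}
      omp.terminator
    } {omp.composite}
    omp.terminator
  }
  omp.terminator
}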
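Conversely, any other use of a host_eval block argument should be rejected by TargetOp::verifyRegions(). In the sketch below (same assumed syntax), the region holds no teams/distribute/wsloop/loop_nest structure, so isTargetSPMDLoop() returns false and the 'num_threads' use triggers the "host_eval argument only legal as 'num_threads' in 'omp.parallel' when representing target SPMD" diagnostic:

omp.target host_eval(%x -> %arg0 : i32) {
  // Not a target SPMD loop, so %arg0 may not feed 'num_threads' here.
  omp.parallel num_threads(%arg0 : i32) {
    omp.terminator
  }
  omp.terminator
}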