Skip to content

Commit 27ffa9f

Browse files
committed
More robust kernel type detection
1 parent 4d20a42 commit 27ffa9f

File tree

5 files changed

+117
-45
lines changed

5 files changed

+117
-45
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/SymbolTable.h"
2323
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2424
#include "mlir/Interfaces/SideEffectInterfaces.h"
25+
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
2526

2627
#define GET_TYPEDEF_CLASSES
2728
#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.h.inc"

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1271,9 +1271,9 @@ def TargetOp : OpenMP_Op<"target", traits = [
12711271
/// operations, the top level one will be the one captured.
12721272
Operation *getInnermostCapturedOmpOp();
12731273

1274-
/// Checks whether this target region represents the MLIR equivalent to a
1275-
/// 'target teams distribute parallel {do, for} [simd]' OpenMP construct.
1276-
bool isTargetSPMDLoop();
1274+
/// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
1275+
/// contents of the target region.
1276+
llvm::omp::OMPTgtExecModeFlags getKernelExecFlags();
12771277
}] # clausesExtraClassDeclaration;
12781278

12791279
let assemblyFormat = clausesAssemblyFormat # [{

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 70 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "llvm/ADT/StringRef.h"
3232
#include "llvm/ADT/TypeSwitch.h"
3333
#include "llvm/Frontend/OpenMP/OMPConstants.h"
34+
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
3435
#include <cstddef>
3536
#include <iterator>
3637
#include <optional>
@@ -1754,7 +1755,7 @@ LogicalResult TargetOp::verifyRegions() {
17541755
return emitError("target containing multiple 'omp.teams' nested ops");
17551756

17561757
// Check that host_eval values are only used in legal ways.
1757-
bool isTargetSPMD = isTargetSPMDLoop();
1758+
llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags();
17581759
for (Value hostEvalArg :
17591760
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
17601761
for (Operation *user : hostEvalArg.getUsers()) {
@@ -1769,23 +1770,24 @@ LogicalResult TargetOp::verifyRegions() {
17691770
"and 'thread_limit' in 'omp.teams'";
17701771
}
17711772
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1772-
if (isTargetSPMD && hostEvalArg == parallelOp.getNumThreads())
1773+
if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1774+
hostEvalArg == parallelOp.getNumThreads())
17731775
continue;
17741776

17751777
return emitOpError()
17761778
<< "host_eval argument only legal as 'num_threads' in "
17771779
"'omp.parallel' when representing target SPMD";
17781780
}
17791781
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1780-
if (isTargetSPMD &&
1782+
if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
17811783
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
17821784
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
17831785
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
17841786
continue;
17851787

1786-
return emitOpError()
1787-
<< "host_eval argument only legal as loop bounds and steps in "
1788-
"'omp.loop_nest' when representing target SPMD";
1788+
return emitOpError() << "host_eval argument only legal as loop bounds "
1789+
"and steps in 'omp.loop_nest' when "
1790+
"representing target SPMD or Generic-SPMD";
17891791
}
17901792

17911793
return emitOpError() << "host_eval argument illegal use in '"
@@ -1823,6 +1825,7 @@ static bool siblingAllowedInCapture(Operation *op) {
18231825
Operation *TargetOp::getInnermostCapturedOmpOp() {
18241826
Dialect *ompDialect = (*this)->getDialect();
18251827
Operation *capturedOp = nullptr;
1828+
DominanceInfo domInfo;
18261829

18271830
// Process in pre-order to check operations from outermost to innermost,
18281831
// ensuring we only enter the region of an operation if it meets the criteria
@@ -1840,6 +1843,22 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
18401843
if (!isOmpDialect || !hasRegions)
18411844
return WalkResult::skip();
18421845

1846+
// This operation cannot be captured if it can be executed more than once
1847+
// (i.e. its block's successors can reach it) or if it's not guaranteed to
1848+
// be executed before all exits of the region (i.e. it doesn't dominate all
1849+
// blocks with no successors reachable from the entry block).
1850+
Region *parentRegion = op->getParentRegion();
1851+
Block *parentBlock = op->getBlock();
1852+
1853+
for (Block *successor : parentBlock->getSuccessors())
1854+
if (successor->isReachable(parentBlock))
1855+
return WalkResult::interrupt();
1856+
1857+
for (Block &block : *parentRegion)
1858+
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
1859+
!domInfo.dominates(parentBlock, &block))
1860+
return WalkResult::interrupt();
1861+
18431862
// Don't capture this op if it has a not-allowed sibling, and stop recursing
18441863
// into nested operations.
18451864
for (Operation &sibling : op->getParentRegion()->getOps())
@@ -1856,49 +1875,61 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
18561875
return capturedOp;
18571876
}
18581877

1859-
bool TargetOp::isTargetSPMDLoop() {
1860-
// The expected MLIR representation for a target SPMD loop is:
1861-
// omp.target {
1862-
// omp.teams {
1863-
// omp.parallel {
1864-
// omp.distribute {
1865-
// omp.wsloop {
1866-
// omp.loop_nest ... { ... }
1867-
// } {omp.composite}
1868-
// } {omp.composite}
1869-
// omp.terminator
1870-
// } {omp.composite}
1871-
// omp.terminator
1872-
// }
1873-
// omp.terminator
1874-
// }
1878+
llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
1879+
using namespace llvm::omp;
18751880

1881+
// Make sure this region is capturing a loop. Otherwise, it's a generic
1882+
// kernel.
18761883
Operation *capturedOp = getInnermostCapturedOmpOp();
18771884
if (!isa_and_present<LoopNestOp>(capturedOp))
1878-
return false;
1885+
return OMP_TGT_EXEC_MODE_GENERIC;
18791886

1880-
Operation *workshareOp = capturedOp->getParentOp();
1887+
SmallVector<LoopWrapperInterface> wrappers;
1888+
cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
1889+
assert(!wrappers.empty());
18811890

1882-
// Accept an optional omp.simd loop wrapper as part of the SPMD pattern.
1883-
if (isa_and_present<SimdOp>(workshareOp))
1884-
workshareOp = workshareOp->getParentOp();
1891+
// Ignore optional SIMD leaf construct.
1892+
auto *innermostWrapper = wrappers.begin();
1893+
if (isa<SimdOp>(innermostWrapper))
1894+
innermostWrapper = std::next(innermostWrapper);
18851895

1886-
if (!isa_and_present<WsloopOp>(workshareOp))
1887-
return false;
1896+
long numWrappers = std::distance(innermostWrapper, wrappers.end());
18881897

1889-
Operation *distributeOp = workshareOp->getParentOp();
1890-
if (!isa_and_present<DistributeOp>(distributeOp))
1891-
return false;
1898+
// Detect Generic-SPMD: target-teams-distribute[-simd].
1899+
if (numWrappers == 1) {
1900+
if (!isa<DistributeOp>(innermostWrapper))
1901+
return OMP_TGT_EXEC_MODE_GENERIC;
18921902

1893-
Operation *parallelOp = distributeOp->getParentOp();
1894-
if (!isa_and_present<ParallelOp>(parallelOp))
1895-
return false;
1903+
Operation *teamsOp = (*innermostWrapper)->getParentOp();
1904+
if (!isa_and_present<TeamsOp>(teamsOp))
1905+
return OMP_TGT_EXEC_MODE_GENERIC;
18961906

1897-
Operation *teamsOp = parallelOp->getParentOp();
1898-
if (!isa_and_present<TeamsOp>(teamsOp))
1899-
return false;
1907+
if (teamsOp->getParentOp() == *this)
1908+
return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
1909+
}
1910+
1911+
// Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
1912+
if (numWrappers == 2) {
1913+
if (!isa<WsloopOp>(innermostWrapper))
1914+
return OMP_TGT_EXEC_MODE_GENERIC;
1915+
1916+
innermostWrapper = std::next(innermostWrapper);
1917+
if (!isa<DistributeOp>(innermostWrapper))
1918+
return OMP_TGT_EXEC_MODE_GENERIC;
1919+
1920+
Operation *parallelOp = (*innermostWrapper)->getParentOp();
1921+
if (!isa_and_present<ParallelOp>(parallelOp))
1922+
return OMP_TGT_EXEC_MODE_GENERIC;
1923+
1924+
Operation *teamsOp = parallelOp->getParentOp();
1925+
if (!isa_and_present<TeamsOp>(teamsOp))
1926+
return OMP_TGT_EXEC_MODE_GENERIC;
1927+
1928+
if (teamsOp->getParentOp() == *this)
1929+
return OMP_TGT_EXEC_MODE_SPMD;
1930+
}
19001931

1901-
return teamsOp->getParentOp() == (*this);
1932+
return OMP_TGT_EXEC_MODE_GENERIC;
19021933
}
19031934

19041935
//===----------------------------------------------------------------------===//

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2191,8 +2191,8 @@ func.func @omp_target_host_eval_parallel(%x : i32) {
21912191

21922192
// -----
21932193

2194-
func.func @omp_target_host_eval_loop(%x : i32) {
2195-
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD}}
2194+
func.func @omp_target_host_eval_loop1(%x : i32) {
2195+
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
21962196
omp.target host_eval(%x -> %arg0 : i32) {
21972197
omp.wsloop {
21982198
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
@@ -2206,6 +2206,30 @@ func.func @omp_target_host_eval_loop(%x : i32) {
22062206

22072207
// -----
22082208

2209+
func.func @omp_target_host_eval_loop2(%x : i32) {
2210+
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
2211+
omp.target host_eval(%x -> %arg0 : i32) {
2212+
omp.teams {
2213+
^bb0:
2214+
%0 = arith.constant 0 : i1
2215+
llvm.cond_br %0, ^bb1, ^bb2
2216+
^bb1:
2217+
omp.distribute {
2218+
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
2219+
omp.yield
2220+
}
2221+
}
2222+
llvm.br ^bb2
2223+
^bb2:
2224+
omp.terminator
2225+
}
2226+
omp.terminator
2227+
}
2228+
return
2229+
}
2230+
2231+
// -----
2232+
22092233
func.func @omp_target_depend(%data_var: memref<i32>) {
22102234
// expected-error @below {{op expected as many depend values as depend variables}}
22112235
"omp.target"(%data_var) ({

mlir/test/Dialect/OpenMP/ops.mlir

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2786,7 +2786,7 @@ func.func @omp_target_host_eval(%x : i32) {
27862786
}
27872787

27882788
// CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
2789-
// CHECK: omp.teams
2789+
// CHECK: omp.teams {
27902790
// CHECK: omp.parallel num_threads(%[[HOST_ARG]] : i32) {
27912791
// CHECK: omp.distribute {
27922792
// CHECK: omp.wsloop {
@@ -2807,6 +2807,22 @@ func.func @omp_target_host_eval(%x : i32) {
28072807
}
28082808
omp.terminator
28092809
}
2810+
2811+
// CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
2812+
// CHECK: omp.teams {
2813+
// CHECK: omp.distribute {
2814+
// CHECK: omp.loop_nest (%{{.*}}) : i32 = (%[[HOST_ARG]]) to (%[[HOST_ARG]]) step (%[[HOST_ARG]]) {
2815+
omp.target host_eval(%x -> %arg0 : i32) {
2816+
omp.teams {
2817+
omp.distribute {
2818+
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
2819+
omp.yield
2820+
}
2821+
}
2822+
omp.terminator
2823+
}
2824+
omp.terminator
2825+
}
28102826
return
28112827
}
28122828

0 commit comments

Comments
 (0)