Skip to content

Commit 7434b7a

Browse files
authored
[MLIR][OpenMP] Fix standalone distribute on the device (llvm#1342)
2 parents b154bfa + 468029e commit 7434b7a

File tree

7 files changed

+293
-87
lines changed

7 files changed

+293
-87
lines changed

mlir/include/mlir/Dialect/OpenMP/OpenMPEnums.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,24 @@ def ScheduleModifier : OpenMP_I32EnumAttr<
224224

225225
def ScheduleModifierAttr : OpenMP_EnumAttr<ScheduleModifier, "sched_mod">;
226226

227+
//===----------------------------------------------------------------------===//
228+
// target_region_flags enum.
229+
//===----------------------------------------------------------------------===//
230+
231+
def TargetRegionFlagsNone : I32BitEnumAttrCaseNone<"none">;
232+
def TargetRegionFlagsGeneric : I32BitEnumAttrCaseBit<"generic", 0>;
233+
def TargetRegionFlagsSpmd : I32BitEnumAttrCaseBit<"spmd", 1>;
234+
def TargetRegionFlagsTripCount : I32BitEnumAttrCaseBit<"trip_count", 2>;
235+
236+
def TargetRegionFlags : OpenMP_BitEnumAttr<
237+
"TargetRegionFlags",
238+
"target region property flags", [
239+
TargetRegionFlagsNone,
240+
TargetRegionFlagsGeneric,
241+
TargetRegionFlagsSpmd,
242+
TargetRegionFlagsTripCount
243+
]>;
244+
227245
//===----------------------------------------------------------------------===//
228246
// variable_capture_kind enum.
229247
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1312,7 +1312,7 @@ def TargetOp : OpenMP_Op<"target", traits = [
13121312
///
13131313
/// \param capturedOp result of a still valid (no modifications made to any
13141314
/// nested operations) previous call to `getInnermostCapturedOmpOp()`.
1315-
static llvm::omp::OMPTgtExecModeFlags
1315+
static ::mlir::omp::TargetRegionFlags
13161316
getKernelExecFlags(Operation *capturedOp);
13171317
}] # clausesExtraClassDeclaration;
13181318

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 131 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1915,8 +1915,8 @@ LogicalResult TargetOp::verifyRegions() {
19151915
return emitError("target containing multiple 'omp.teams' nested ops");
19161916

19171917
// Check that host_eval values are only used in legal ways.
1918-
llvm::omp::OMPTgtExecModeFlags execFlags =
1919-
getKernelExecFlags(getInnermostCapturedOmpOp());
1918+
Operation *capturedOp = getInnermostCapturedOmpOp();
1919+
TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
19201920
for (Value hostEvalArg :
19211921
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
19221922
for (Operation *user : hostEvalArg.getUsers()) {
@@ -1931,7 +1931,8 @@ LogicalResult TargetOp::verifyRegions() {
19311931
"and 'thread_limit' in 'omp.teams'";
19321932
}
19331933
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
1934-
if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
1934+
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
1935+
parallelOp->isAncestor(capturedOp) &&
19351936
hostEvalArg == parallelOp.getNumThreads())
19361937
continue;
19371938

@@ -1940,15 +1941,16 @@ LogicalResult TargetOp::verifyRegions() {
19401941
"'omp.parallel' when representing target SPMD";
19411942
}
19421943
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
1943-
if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
1944+
if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
1945+
loopNestOp.getOperation() == capturedOp &&
19441946
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
19451947
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
19461948
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
19471949
continue;
19481950

19491951
return emitOpError() << "host_eval argument only legal as loop bounds "
1950-
"and steps in 'omp.loop_nest' when "
1951-
"representing target SPMD or Generic-SPMD";
1952+
"and steps in 'omp.loop_nest' when trip count "
1953+
"must be evaluated in the host";
19521954
}
19531955

19541956
return emitOpError() << "host_eval argument illegal use in '"
@@ -1958,42 +1960,21 @@ LogicalResult TargetOp::verifyRegions() {
19581960
return success();
19591961
}
19601962

1961-
/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
1962-
/// effects, but don't include a memory write effect.
1963-
static bool siblingAllowedInCapture(Operation *op) {
1964-
if (!op)
1965-
return false;
1963+
static Operation *
1964+
findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
1965+
llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
1966+
assert(rootOp && "expected valid operation");
19661967

1967-
bool isOmpDialect =
1968-
op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
1969-
op->getDialect();
1970-
1971-
if (isOmpDialect)
1972-
return op->hasTrait<OpTrait::IsTerminator>();
1973-
1974-
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
1975-
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
1976-
memOp.getEffects(effects);
1977-
return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
1978-
return isa<MemoryEffects::Write>(effect.getEffect()) &&
1979-
isa<SideEffects::AutomaticAllocationScopeResource>(
1980-
effect.getResource());
1981-
});
1982-
}
1983-
return true;
1984-
}
1985-
1986-
Operation *TargetOp::getInnermostCapturedOmpOp() {
1987-
Dialect *ompDialect = (*this)->getDialect();
1968+
Dialect *ompDialect = rootOp->getDialect();
19881969
Operation *capturedOp = nullptr;
19891970
DominanceInfo domInfo;
19901971

19911972
// Process in pre-order to check operations from outermost to innermost,
19921973
// ensuring we only enter the region of an operation if it meets the criteria
19931974
// for being captured. We stop the exploration of nested operations as soon as
19941975
// we process a region holding no operations to be captured.
1995-
walk<WalkOrder::PreOrder>([&](Operation *op) {
1996-
if (op == *this)
1976+
rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
1977+
if (op == rootOp)
19971978
return WalkResult::advance();
19981979

19991980
// Ignore operations of other dialects or omp operations with no regions,
@@ -2008,22 +1989,24 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20081989
// (i.e. its block's successors can reach it) or if it's not guaranteed to
20091990
// be executed before all exits of the region (i.e. it doesn't dominate all
20101991
// blocks with no successors reachable from the entry block).
2011-
Region *parentRegion = op->getParentRegion();
2012-
Block *parentBlock = op->getBlock();
2013-
2014-
for (Block *successor : parentBlock->getSuccessors())
2015-
if (successor->isReachable(parentBlock))
2016-
return WalkResult::interrupt();
2017-
2018-
for (Block &block : *parentRegion)
2019-
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2020-
!domInfo.dominates(parentBlock, &block))
2021-
return WalkResult::interrupt();
1992+
if (checkSingleMandatoryExec) {
1993+
Region *parentRegion = op->getParentRegion();
1994+
Block *parentBlock = op->getBlock();
1995+
1996+
for (Block *successor : parentBlock->getSuccessors())
1997+
if (successor->isReachable(parentBlock))
1998+
return WalkResult::interrupt();
1999+
2000+
for (Block &block : *parentRegion)
2001+
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2002+
!domInfo.dominates(parentBlock, &block))
2003+
return WalkResult::interrupt();
2004+
}
20222005

20232006
// Don't capture this op if it has a not-allowed sibling, and stop recursing
20242007
// into nested operations.
20252008
for (Operation &sibling : op->getParentRegion()->getOps())
2026-
if (&sibling != op && !siblingAllowedInCapture(&sibling))
2009+
if (&sibling != op && !siblingAllowedFn(&sibling))
20272010
return WalkResult::interrupt();
20282011

20292012
// Don't continue capturing nested operations if we reach an omp.loop_nest.
@@ -2036,10 +2019,35 @@ Operation *TargetOp::getInnermostCapturedOmpOp() {
20362019
return capturedOp;
20372020
}
20382021

2039-
llvm::omp::OMPTgtExecModeFlags
2040-
TargetOp::getKernelExecFlags(Operation *capturedOp) {
2041-
using namespace llvm::omp;
2022+
Operation *TargetOp::getInnermostCapturedOmpOp() {
2023+
auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2024+
2025+
// Only allow OpenMP terminators and non-OpenMP ops that have known memory
2026+
// effects, but don't include a memory write effect.
2027+
return findCapturedOmpOp(
2028+
*this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2029+
if (!sibling)
2030+
return false;
2031+
2032+
if (ompDialect == sibling->getDialect())
2033+
return sibling->hasTrait<OpTrait::IsTerminator>();
2034+
2035+
if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2036+
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4>
2037+
effects;
2038+
memOp.getEffects(effects);
2039+
return !llvm::any_of(
2040+
effects, [&](MemoryEffects::EffectInstance &effect) {
2041+
return isa<MemoryEffects::Write>(effect.getEffect()) &&
2042+
isa<SideEffects::AutomaticAllocationScopeResource>(
2043+
effect.getResource());
2044+
});
2045+
}
2046+
return true;
2047+
});
2048+
}
20422049

2050+
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
20432051
// A non-null captured op is only valid if it resides inside of a TargetOp
20442052
// and is the result of calling getInnermostCapturedOmpOp() on it.
20452053
TargetOp targetOp =
@@ -2048,60 +2056,102 @@ TargetOp::getKernelExecFlags(Operation *capturedOp) {
20482056
(targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
20492057
"unexpected captured op");
20502058

2051-
// Make sure this region is capturing a loop. Otherwise, it's a generic
2052-
// kernel.
2059+
// If it's not capturing a loop, it's a default target region.
20532060
if (!isa_and_present<LoopNestOp>(capturedOp))
2054-
return OMP_TGT_EXEC_MODE_GENERIC;
2061+
return TargetRegionFlags::generic;
20552062

2056-
SmallVector<LoopWrapperInterface> wrappers;
2057-
cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
2058-
assert(!wrappers.empty());
2063+
auto getInnermostWrapper = [](LoopNestOp loopOp, int &numWrappers) {
2064+
SmallVector<LoopWrapperInterface> wrappers;
2065+
loopOp.gatherWrappers(wrappers);
2066+
assert(!wrappers.empty());
20592067

2060-
// Ignore optional SIMD leaf construct.
2061-
auto *innermostWrapper = wrappers.begin();
2062-
if (isa<SimdOp>(innermostWrapper))
2063-
innermostWrapper = std::next(innermostWrapper);
2068+
// Ignore optional SIMD leaf construct.
2069+
auto *wrapper = wrappers.begin();
2070+
if (isa<SimdOp>(wrapper))
2071+
wrapper = std::next(wrapper);
20642072

2065-
long numWrappers = std::distance(innermostWrapper, wrappers.end());
2073+
numWrappers = static_cast<int>(std::distance(wrapper, wrappers.end()));
2074+
return wrapper;
2075+
};
20662076

2067-
// Detect Generic-SPMD: target-teams-distribute[-simd].
2068-
// Detect SPMD: target-teams-loop.
2069-
if (numWrappers == 1) {
2070-
if (!isa<DistributeOp, LoopOp>(innermostWrapper))
2071-
return OMP_TGT_EXEC_MODE_GENERIC;
2077+
int numWrappers;
2078+
LoopWrapperInterface *innermostWrapper =
2079+
getInnermostWrapper(cast<LoopNestOp>(capturedOp), numWrappers);
20722080

2073-
Operation *teamsOp = (*innermostWrapper)->getParentOp();
2074-
if (!isa_and_present<TeamsOp>(teamsOp))
2075-
return OMP_TGT_EXEC_MODE_GENERIC;
2081+
if (numWrappers != 1 && numWrappers != 2)
2082+
return TargetRegionFlags::generic;
20762083

2077-
if (teamsOp->getParentOp() == targetOp.getOperation())
2078-
return isa<DistributeOp>(innermostWrapper)
2079-
? OMP_TGT_EXEC_MODE_GENERIC_SPMD
2080-
: OMP_TGT_EXEC_MODE_SPMD;
2081-
}
2082-
2083-
// Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
2084+
// Detect target-teams-distribute-parallel-wsloop[-simd].
20842085
if (numWrappers == 2) {
20852086
if (!isa<WsloopOp>(innermostWrapper))
2086-
return OMP_TGT_EXEC_MODE_GENERIC;
2087+
return TargetRegionFlags::generic;
20872088

20882089
innermostWrapper = std::next(innermostWrapper);
20892090
if (!isa<DistributeOp>(innermostWrapper))
2090-
return OMP_TGT_EXEC_MODE_GENERIC;
2091+
return TargetRegionFlags::generic;
20912092

20922093
Operation *parallelOp = (*innermostWrapper)->getParentOp();
20932094
if (!isa_and_present<ParallelOp>(parallelOp))
2094-
return OMP_TGT_EXEC_MODE_GENERIC;
2095+
return TargetRegionFlags::generic;
20952096

20962097
Operation *teamsOp = parallelOp->getParentOp();
20972098
if (!isa_and_present<TeamsOp>(teamsOp))
2098-
return OMP_TGT_EXEC_MODE_GENERIC;
2099+
return TargetRegionFlags::generic;
20992100

21002101
if (teamsOp->getParentOp() == targetOp.getOperation())
2101-
return OMP_TGT_EXEC_MODE_SPMD;
2102+
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2103+
}
2104+
// Detect target-teams-distribute[-simd] and target-teams-loop.
2105+
else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2106+
Operation *teamsOp = (*innermostWrapper)->getParentOp();
2107+
if (!isa_and_present<TeamsOp>(teamsOp))
2108+
return TargetRegionFlags::generic;
2109+
2110+
if (teamsOp->getParentOp() != targetOp.getOperation())
2111+
return TargetRegionFlags::generic;
2112+
2113+
if (isa<LoopOp>(innermostWrapper))
2114+
return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2115+
2116+
// Find single immediately nested captured omp.parallel and add spmd flag
2117+
// (generic-spmd case).
2118+
//
2119+
// TODO: This shouldn't have to be done here, as it is too easy to break.
2120+
// The openmp-opt pass should be updated to be able to promote kernels like
2121+
// this from "Generic" to "Generic-SPMD". However, the use of the
2122+
// `kmpc_distribute_static_loop` family of functions produced by the
2123+
// OMPIRBuilder for these kernels prevents that from working.
2124+
Dialect *ompDialect = targetOp->getDialect();
2125+
Operation *nestedCapture = findCapturedOmpOp(
2126+
capturedOp, /*checkSingleMandatoryExec=*/false,
2127+
[&](Operation *sibling) {
2128+
return sibling && (ompDialect != sibling->getDialect() ||
2129+
sibling->hasTrait<OpTrait::IsTerminator>());
2130+
});
2131+
2132+
TargetRegionFlags result =
2133+
TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2134+
2135+
if (!nestedCapture)
2136+
return result;
2137+
2138+
while (nestedCapture->getParentOp() != capturedOp)
2139+
nestedCapture = nestedCapture->getParentOp();
2140+
2141+
return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2142+
: result;
2143+
}
2144+
// Detect target-parallel-wsloop[-simd].
2145+
else if (isa<WsloopOp>(innermostWrapper)) {
2146+
Operation *parallelOp = (*innermostWrapper)->getParentOp();
2147+
if (!isa_and_present<ParallelOp>(parallelOp))
2148+
return TargetRegionFlags::generic;
2149+
2150+
if (parallelOp->getParentOp() == targetOp.getOperation())
2151+
return TargetRegionFlags::spmd;
21022152
}
21032153

2104-
return OMP_TGT_EXEC_MODE_GENERIC;
2154+
return TargetRegionFlags::generic;
21052155
}
21062156

21072157
//===----------------------------------------------------------------------===//

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5067,7 +5067,17 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
50675067
}
50685068

50695069
// Update kernel bounds structure for the `OpenMPIRBuilder` to use.
5070-
attrs.ExecFlags = targetOp.getKernelExecFlags(capturedOp);
5070+
omp::TargetRegionFlags kernelFlags = targetOp.getKernelExecFlags(capturedOp);
5071+
assert(
5072+
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic |
5073+
omp::TargetRegionFlags::spmd) &&
5074+
"invalid kernel flags");
5075+
attrs.ExecFlags =
5076+
omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::generic)
5077+
? omp::bitEnumContainsAny(kernelFlags, omp::TargetRegionFlags::spmd)
5078+
? llvm::omp::OMP_TGT_EXEC_MODE_GENERIC_SPMD
5079+
: llvm::omp::OMP_TGT_EXEC_MODE_GENERIC
5080+
: llvm::omp::OMP_TGT_EXEC_MODE_SPMD;
50715081
attrs.MinTeams = minTeamsVal;
50725082
attrs.MaxTeams.front() = maxTeamsVal;
50735083
attrs.MinThreads = 1;
@@ -5115,8 +5125,8 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
51155125
if (numThreads)
51165126
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
51175127

5118-
if (targetOp.getKernelExecFlags(capturedOp) !=
5119-
llvm::omp::OMP_TGT_EXEC_MODE_GENERIC) {
5128+
if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
5129+
omp::TargetRegionFlags::trip_count)) {
51205130
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
51215131
attrs.LoopTripCount = nullptr;
51225132

mlir/test/Dialect/OpenMP/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2314,7 +2314,7 @@ func.func @omp_target_host_eval_parallel(%x : i32) {
23142314
// -----
23152315

23162316
func.func @omp_target_host_eval_loop1(%x : i32) {
2317-
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
2317+
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
23182318
omp.target host_eval(%x -> %arg0 : i32) {
23192319
omp.wsloop {
23202320
omp.loop_nest (%iv) : i32 = (%arg0) to (%arg0) step (%arg0) {
@@ -2329,7 +2329,7 @@ func.func @omp_target_host_eval_loop1(%x : i32) {
23292329
// -----
23302330

23312331
func.func @omp_target_host_eval_loop2(%x : i32) {
2332-
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when representing target SPMD or Generic-SPMD}}
2332+
// expected-error @below {{op host_eval argument only legal as loop bounds and steps in 'omp.loop_nest' when trip count must be evaluated in the host}}
23332333
omp.target host_eval(%x -> %arg0 : i32) {
23342334
omp.teams {
23352335
^bb0:

0 commit comments

Comments
 (0)