[MLIR][OpenMP] Add host_eval clause to omp.target #116049

Merged 1 commit on Jan 14, 2025
58 changes: 57 additions & 1 deletion mlir/docs/Dialects/OpenMPDialect/_index.md
@@ -298,7 +298,8 @@ introduction of private copies of the same underlying variable defined outside
the MLIR operation the clause is attached to. Currently, clauses with this
property can be classified into three main categories:
- Map-like clauses: `host_eval` (compiler internal, not defined by the OpenMP
specification), `map`, `use_device_addr` and `use_device_ptr`.
specification: [see more](#host-evaluated-clauses-in-target-regions)), `map`,
`use_device_addr` and `use_device_ptr`.
- Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
- Privatization clauses: `private`.

@@ -523,3 +524,58 @@ omp.parallel ... {
omp.terminator
} {omp.composite}
```

## Host-Evaluated Clauses in Target Regions

The `omp.target` operation, which represents the OpenMP `target` construct, is
marked with the `IsolatedFromAbove` trait. This means that, inside of its
region, no MLIR values defined outside of the op itself can be used. This is
consistent with the OpenMP specification of the `target` construct, which
mandates that all host device values used inside of the `target` region must
either be privatized (data-sharing) or mapped (data-mapping).

Normally, clauses applied to a construct are evaluated before entering that
construct. Further, in some cases, the OpenMP specification stipulates that
clauses be evaluated _on the host device_ on entry to a parent `target`
construct. In particular, the `num_teams` and `thread_limit` clauses of the
`teams` construct must be evaluated on the host device if it's nested inside or
combined with a `target` construct.

Additionally, the runtime library targeted by the MLIR to LLVM IR translation of
the OpenMP dialect supports the optimized launch of SPMD kernels (i.e.
`target teams distribute parallel {do,for}` in OpenMP), which requires
specifying in advance what the total trip count of the loop is. Consequently, it
is also beneficial to evaluate the trip count on the host device prior to the
kernel launch.
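The host-side trip count evaluation can be sketched as follows. This is a
hedged illustration only: `computeTripCount` is a hypothetical helper, not part
of the actual OpenMP runtime API, and it assumes a normalized loop of the form
`for (i = lb; i < ub; i += step)`.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical sketch of the trip count a host would compute before
// launching an SPMD kernel, for a loop `for (i = lb; i < ub; i += step)`
// (or `i > ub` when `step` is negative).
uint64_t computeTripCount(int64_t lb, int64_t ub, int64_t step) {
  if (step == 0)
    return 0; // Degenerate loop; a real runtime would reject this earlier.
  if (step > 0)
    return lb < ub ? static_cast<uint64_t>((ub - lb + step - 1) / step) : 0;
  return lb > ub ? static_cast<uint64_t>((lb - ub - step - 1) / -step) : 0;
}
```

Evaluating this expression on the host lets the runtime size the kernel launch
(number of teams and threads) before any device code runs.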

These host-evaluated values in MLIR would need to be placed outside of the
`omp.target` region and also attached to the corresponding nested operations,
which is not possible because of the `IsolatedFromAbove` trait. The solution
implemented to address this problem has been to introduce the `host_eval`
argument to the `omp.target` operation. It works similarly to a `map` clause,
but its only intended use is to forward host-evaluated values to their
corresponding operation inside of the region. Any other use of a `host_eval`
block argument results in a verifier error.

```mlir
// Initialize %0, %1, %2, %3...
omp.target host_eval(%0 -> %nt, %1 -> %lb, %2 -> %ub, %3 -> %step : i32, i32, i32, i32) {
omp.teams num_teams(to %nt : i32) {
omp.parallel {
omp.distribute {
omp.wsloop {
omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) {
// ...
omp.yield
}
omp.terminator
} {omp.composite}
omp.terminator
} {omp.composite}
omp.terminator
} {omp.composite}
omp.terminator
}
omp.terminator
}
```
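The kernel-type inference that drives these rules can be approximated by the
standalone sketch below. It is a simplified model, not the actual MLIR
implementation: the real logic inspects the loop wrapper operations around the
captured `omp.loop_nest` and also verifies the enclosing
`omp.teams`/`omp.parallel`/`omp.target` nesting, while this version models the
wrapper stack with plain name strings.

```cpp
#include <string>
#include <vector>

enum class KernelType { Generic, SPMD, GenericSPMD };

// Simplified model: `wrappers` lists the loop wrapper ops surrounding the
// captured omp.loop_nest, innermost first (e.g. {"omp.wsloop",
// "omp.distribute"}). The real check additionally validates parent ops.
KernelType classifyKernel(std::vector<std::string> wrappers) {
  // An optional `simd` leaf construct is ignored.
  if (!wrappers.empty() && wrappers.front() == "omp.simd")
    wrappers.erase(wrappers.begin());

  // target teams distribute [simd] -> Generic-SPMD.
  if (wrappers == std::vector<std::string>{"omp.distribute"})
    return KernelType::GenericSPMD;

  // target teams distribute parallel {do,for} [simd] -> SPMD.
  if (wrappers == std::vector<std::string>{"omp.wsloop", "omp.distribute"})
    return KernelType::SPMD;

  // Anything else executes as a Generic kernel.
  return KernelType::Generic;
}
```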
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPDialect.h
@@ -22,6 +22,7 @@
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.h.inc"
33 changes: 25 additions & 8 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -1224,10 +1224,10 @@ def TargetOp : OpenMP_Op<"target", traits = [
], clauses = [
// TODO: Complete clause list (defaultmap, uses_allocators).
OpenMP_AllocateClause, OpenMP_BareClause, OpenMP_DependClause,
OpenMP_DeviceClause, OpenMP_HasDeviceAddrClause, OpenMP_IfClause,
OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
OpenMP_DeviceClause, OpenMP_HasDeviceAddrClause, OpenMP_HostEvalClause,
OpenMP_IfClause, OpenMP_InReductionClause, OpenMP_IsDevicePtrClause,
OpenMP_MapClauseSkip<assemblyFormat = true>, OpenMP_NowaitClause,
OpenMP_PrivateClause, OpenMP_ThreadLimitClause,
OpenMP_PrivateClause, OpenMP_ThreadLimitClause
], singleRegion = true> {
let summary = "target construct";
let description = [{
@@ -1269,17 +1269,34 @@ def TargetOp : OpenMP_Op<"target", traits = [

return getMapVars()[mapInfoOpIdx];
}

/// Returns the innermost OpenMP dialect operation captured by this target
/// construct. For an operation to be detected as captured, it must be
    /// inside a (possibly multi-level) nest of OpenMP dialect operations'
    /// regions where none of these levels contains operations considered
    /// not-allowed for these purposes (i.e. only terminator operations are
    /// allowed from the OpenMP dialect, and other dialects' operations are
/// allowed as long as they don't have a memory write effect).
///
/// If there are omp.loop_nest operations in the sequence of nested
/// operations, the top level one will be the one captured.
Operation *getInnermostCapturedOmpOp();

/// Infers the kernel type (Generic, SPMD or Generic-SPMD) based on the
/// contents of the target region.
llvm::omp::OMPTgtExecModeFlags getKernelExecFlags();
}] # clausesExtraClassDeclaration;

let assemblyFormat = clausesAssemblyFormat # [{
custom<InReductionMapPrivateRegion>(
$region, $in_reduction_vars, type($in_reduction_vars),
$in_reduction_byref, $in_reduction_syms, $map_vars, type($map_vars),
$private_vars, type($private_vars), $private_syms, $private_maps)
attr-dict
custom<HostEvalInReductionMapPrivateRegion>(
$region, $host_eval_vars, type($host_eval_vars), $in_reduction_vars,
type($in_reduction_vars), $in_reduction_byref, $in_reduction_syms,
$map_vars, type($map_vars), $private_vars, type($private_vars),
$private_syms, $private_maps) attr-dict
}];

let hasVerifier = 1;
let hasRegionVerifier = 1;
}


206 changes: 198 additions & 8 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -31,6 +31,7 @@
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
#include <cstddef>
#include <iterator>
#include <optional>
@@ -691,8 +692,10 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
return parser.parseRegion(region, entryBlockArgs);
}

static ParseResult parseInReductionMapPrivateRegion(
static ParseResult parseHostEvalInReductionMapPrivateRegion(
OpAsmParser &parser, Region &region,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &hostEvalVars,
SmallVectorImpl<Type> &hostEvalTypes,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &inReductionVars,
SmallVectorImpl<Type> &inReductionTypes,
DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
@@ -702,6 +705,7 @@ static ParseResult parseInReductionMapPrivateRegion(
llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
DenseI64ArrayAttr &privateMaps) {
AllRegionParseArgs args;
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
@@ -931,13 +935,15 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
p.printRegion(region, /*printEntryBlockArgs=*/false);
}

static void printInReductionMapPrivateRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
static void printHostEvalInReductionMapPrivateRegion(
OpAsmPrinter &p, Operation *op, Region &region, ValueRange hostEvalVars,
TypeRange hostEvalTypes, ValueRange inReductionVars,
TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes,
ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms,
DenseI64ArrayAttr privateMaps) {
AllRegionPrintArgs args;
args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
inReductionByref, inReductionSyms);
args.mapArgs.emplace(mapVars, mapTypes);
@@ -1720,11 +1726,12 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
clauses.ifExpr, /*in_reduction_vars=*/{},
/*in_reduction_byref=*/nullptr, /*in_reduction_syms=*/nullptr,
clauses.isDevicePtrVars, clauses.mapVars, clauses.nowait,
clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
clauses.threadLimit, /*private_maps=*/nullptr);
clauses.hostEvalVars, clauses.ifExpr,
/*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
clauses.mapVars, clauses.nowait, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.threadLimit,
/*private_maps=*/nullptr);
}

LogicalResult TargetOp::verify() {
@@ -1742,6 +1749,189 @@ LogicalResult TargetOp::verify() {
return verifyPrivateVarsMapping(*this);
}

LogicalResult TargetOp::verifyRegions() {
auto teamsOps = getOps<TeamsOp>();
if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
return emitError("target containing multiple 'omp.teams' nested ops");

// Check that host_eval values are only used in legal ways.
llvm::omp::OMPTgtExecModeFlags execFlags = getKernelExecFlags();
for (Value hostEvalArg :
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
for (Operation *user : hostEvalArg.getUsers()) {
if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
if (llvm::is_contained({teamsOp.getNumTeamsLower(),
teamsOp.getNumTeamsUpper(),
teamsOp.getThreadLimit()},
hostEvalArg))
continue;

return emitOpError() << "host_eval argument only legal as 'num_teams' "
"and 'thread_limit' in 'omp.teams'";
}
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
if (execFlags == llvm::omp::OMP_TGT_EXEC_MODE_SPMD &&
hostEvalArg == parallelOp.getNumThreads())
continue;

return emitOpError()
<< "host_eval argument only legal as 'num_threads' in "
"'omp.parallel' when representing target SPMD";
}
if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
if (execFlags != llvm::omp::OMP_TGT_EXEC_MODE_GENERIC &&
(llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
continue;

return emitOpError() << "host_eval argument only legal as loop bounds "
"and steps in 'omp.loop_nest' when "
"representing target SPMD or Generic-SPMD";
}

return emitOpError() << "host_eval argument illegal use in '"
<< user->getName() << "' operation";
}
}
return success();
}

/// Only allow OpenMP terminators and non-OpenMP ops that have known memory
/// effects, but don't include a memory write effect.
static bool siblingAllowedInCapture(Operation *op) {
if (!op)
return false;

bool isOmpDialect =
op->getContext()->getLoadedDialect<omp::OpenMPDialect>() ==
op->getDialect();

if (isOmpDialect)
return op->hasTrait<OpTrait::IsTerminator>();

if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>, 4> effects;
memOp.getEffects(effects);
return !llvm::any_of(effects, [&](MemoryEffects::EffectInstance &effect) {
return isa<MemoryEffects::Write>(effect.getEffect()) &&
isa<SideEffects::AutomaticAllocationScopeResource>(
effect.getResource());
});
}
return true;
}

Operation *TargetOp::getInnermostCapturedOmpOp() {
Dialect *ompDialect = (*this)->getDialect();
Operation *capturedOp = nullptr;
DominanceInfo domInfo;

// Process in pre-order to check operations from outermost to innermost,
// ensuring we only enter the region of an operation if it meets the criteria
// for being captured. We stop the exploration of nested operations as soon as
// we process a region holding no operations to be captured.
walk<WalkOrder::PreOrder>([&](Operation *op) {
if (op == *this)
return WalkResult::advance();

// Ignore operations of other dialects or omp operations with no regions,
// because these will only be checked if they are siblings of an omp
// operation that can potentially be captured.
bool isOmpDialect = op->getDialect() == ompDialect;
bool hasRegions = op->getNumRegions() > 0;
if (!isOmpDialect || !hasRegions)
return WalkResult::skip();

// This operation cannot be captured if it can be executed more than once
// (i.e. its block's successors can reach it) or if it's not guaranteed to
// be executed before all exits of the region (i.e. it doesn't dominate all
// blocks with no successors reachable from the entry block).
Region *parentRegion = op->getParentRegion();
Block *parentBlock = op->getBlock();

for (Block *successor : parentBlock->getSuccessors())
if (successor->isReachable(parentBlock))
return WalkResult::interrupt();

for (Block &block : *parentRegion)
if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
!domInfo.dominates(parentBlock, &block))
return WalkResult::interrupt();

// Don't capture this op if it has a not-allowed sibling, and stop recursing
// into nested operations.
for (Operation &sibling : op->getParentRegion()->getOps())
if (&sibling != op && !siblingAllowedInCapture(&sibling))
return WalkResult::interrupt();

// Don't continue capturing nested operations if we reach an omp.loop_nest.
// Otherwise, process the contents of this operation.
capturedOp = op;
return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
: WalkResult::advance();
});

return capturedOp;
}

llvm::omp::OMPTgtExecModeFlags TargetOp::getKernelExecFlags() {
using namespace llvm::omp;

// Make sure this region is capturing a loop. Otherwise, it's a generic
// kernel.
Operation *capturedOp = getInnermostCapturedOmpOp();
if (!isa_and_present<LoopNestOp>(capturedOp))
return OMP_TGT_EXEC_MODE_GENERIC;

SmallVector<LoopWrapperInterface> wrappers;
cast<LoopNestOp>(capturedOp).gatherWrappers(wrappers);
assert(!wrappers.empty());

// Ignore optional SIMD leaf construct.
auto *innermostWrapper = wrappers.begin();
if (isa<SimdOp>(innermostWrapper))
innermostWrapper = std::next(innermostWrapper);

long numWrappers = std::distance(innermostWrapper, wrappers.end());

// Detect Generic-SPMD: target-teams-distribute[-simd].
if (numWrappers == 1) {
if (!isa<DistributeOp>(innermostWrapper))
return OMP_TGT_EXEC_MODE_GENERIC;

Operation *teamsOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return OMP_TGT_EXEC_MODE_GENERIC;

if (teamsOp->getParentOp() == *this)
return OMP_TGT_EXEC_MODE_GENERIC_SPMD;
}

// Detect SPMD: target-teams-distribute-parallel-wsloop[-simd].
if (numWrappers == 2) {
if (!isa<WsloopOp>(innermostWrapper))
return OMP_TGT_EXEC_MODE_GENERIC;

innermostWrapper = std::next(innermostWrapper);
if (!isa<DistributeOp>(innermostWrapper))
return OMP_TGT_EXEC_MODE_GENERIC;

Operation *parallelOp = (*innermostWrapper)->getParentOp();
if (!isa_and_present<ParallelOp>(parallelOp))
return OMP_TGT_EXEC_MODE_GENERIC;

Operation *teamsOp = parallelOp->getParentOp();
if (!isa_and_present<TeamsOp>(teamsOp))
return OMP_TGT_EXEC_MODE_GENERIC;

if (teamsOp->getParentOp() == *this)
return OMP_TGT_EXEC_MODE_SPMD;
}

return OMP_TGT_EXEC_MODE_GENERIC;
}

//===----------------------------------------------------------------------===//
// ParallelOp
//===----------------------------------------------------------------------===//