Skip to content

[MLIR][OpenMP] Add the host_eval clause #116048

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/docs/Dialects/OpenMPDialect/_index.md
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ arguments for the region of that MLIR operation. This enables, for example, the
introduction of private copies of the same underlying variable defined outside
the MLIR operation the clause is attached to. Currently, clauses with this
property can be classified into three main categories:
- Map-like clauses: `map`, `use_device_addr` and `use_device_ptr`.
- Map-like clauses: `host_eval` (compiler internal, not defined by the OpenMP
specification), `map`, `use_device_addr` and `use_device_ptr`.
- Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
- Privatization clauses: `private`.

Expand Down
38 changes: 38 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,44 @@ class OpenMP_HintClauseSkip<

def OpenMP_HintClause : OpenMP_HintClauseSkip<>;

//===----------------------------------------------------------------------===//
// Not in the spec: Clause-like structure to hold host-evaluated values.
//===----------------------------------------------------------------------===//

class OpenMP_HostEvalClauseSkip<
bit traits = false, bit arguments = false, bit assemblyFormat = false,
bit description = false, bit extraClassDeclaration = false
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let traits = [
BlockArgOpenMPOpInterface, IsolatedFromAbove
];

let arguments = (ins
Variadic<AnyType>:$host_eval_vars
);

let extraClassDeclaration = [{
unsigned numHostEvalBlockArgs() {
return getHostEvalVars().size();
}
}];

let description = [{
The optional `host_eval_vars` holds values defined outside of the region of
the `IsolatedFromAbove` operation for which a corresponding entry block
argument is defined. The only legal uses for these captured values are the
following:
- `num_teams` or `thread_limit` clause of an immediately nested
`omp.teams` operation.
- If the operation is the top-level `omp.target` of a target SPMD kernel:
- `num_threads` clause of the nested `omp.parallel` operation.
- Bounds and steps of the nested `omp.loop_nest` operation.
}];
}

def OpenMP_HostEvalClause : OpenMP_HostEvalClauseSkip<>;

//===----------------------------------------------------------------------===//
// V5.2: [3.4] `if` clause
//===----------------------------------------------------------------------===//
Expand Down
31 changes: 25 additions & 6 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {

let methods = [
// Default-implemented methods to be overriden by the corresponding clauses.
InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
"unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
return 0;
}]>,
InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
"unsigned", "numInReductionBlockArgs", (ins), [{}], [{
return 0;
Expand Down Expand Up @@ -54,10 +58,16 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
return 0;
}]>,

// Unified access methods for clause-associated entry block arguments.
// Unified access methods for start indices of clause-associated entry block
// arguments.
InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
"unsigned", "getHostEvalBlockArgsStart", (ins), [{
return 0;
}]>,
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
"unsigned", "getInReductionBlockArgsStart", (ins), [{
return 0;
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
}]>,
InterfaceMethod<"Get start index of block arguments defined by `map`.",
"unsigned", "getMapBlockArgsStart", (ins), [{
Expand Down Expand Up @@ -91,6 +101,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
}]>,

// Unified access methods for clause-associated entry block arguments.
InterfaceMethod<"Get block arguments defined by `host_eval`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getHostEvalBlockArgs", (ins), [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
return $_op->getRegion(0).getArguments().slice(
iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
}]>,
InterfaceMethod<"Get block arguments defined by `in_reduction`.",
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
"getInReductionBlockArgs", (ins), [{
Expand Down Expand Up @@ -147,10 +165,11 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {

let verify = [{
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
unsigned expectedArgs = iface.numInReductionBlockArgs() +
iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
unsigned expectedArgs = iface.numHostEvalBlockArgs() +
iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
iface.numUseDevicePtrBlockArgs();
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
return $_op->emitOpError() << "expected at least " << expectedArgs
<< " entry block argument(s)";
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,7 @@ struct ReductionParseArgs {
: vars(vars), types(types), byref(byref), syms(syms) {}
};
struct AllRegionParseArgs {
std::optional<MapParseArgs> hostEvalArgs;
std::optional<ReductionParseArgs> inReductionArgs;
std::optional<MapParseArgs> mapArgs;
std::optional<PrivateParseArgs> privateArgs;
Expand Down Expand Up @@ -647,6 +648,11 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
AllRegionParseArgs args) {
llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;

if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
args.hostEvalArgs)))
return parser.emitError(parser.getCurrentLocation())
<< "invalid `host_eval` format";

if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
args.inReductionArgs)))
return parser.emitError(parser.getCurrentLocation())
Expand Down Expand Up @@ -812,6 +818,7 @@ struct ReductionPrintArgs {
: vars(vars), types(types), byref(byref), syms(syms) {}
};
struct AllRegionPrintArgs {
std::optional<MapPrintArgs> hostEvalArgs;
std::optional<ReductionPrintArgs> inReductionArgs;
std::optional<MapPrintArgs> mapArgs;
std::optional<PrivatePrintArgs> privateArgs;
Expand Down Expand Up @@ -902,6 +909,8 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
MLIRContext *ctx = op->getContext();

printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
args.hostEvalArgs);
printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
args.inReductionArgs);
printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
Expand Down
Loading