Skip to content

Commit 5efde4c

Browse files
committed
[MLIR][OpenMP] Add the host_eval clause
This patch adds the definition of a new entry block argument-defining `host_eval` clause. This is intended to implement the passthrough approach discussed in [this RFC](https://discourse.llvm.org/t/rfc-openmp-dialect-representation-of-num-teams-thread-limit-and-target-spmd/81106), for supporting host-evaluated clauses that apply to operations nested inside of `omp.target`.
1 parent 03e7862 commit 5efde4c

File tree

4 files changed

+74
-7
lines changed

4 files changed

+74
-7
lines changed

mlir/docs/Dialects/OpenMPDialect/_index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,8 @@ arguments for the region of that MLIR operation. This enables, for example, the
297297
introduction of private copies of the same underlying variable defined outside
298298
the MLIR operation the clause is attached to. Currently, clauses with this
299299
property can be classified into three main categories:
300-
- Map-like clauses: `map`, `use_device_addr` and `use_device_ptr`.
300+
- Map-like clauses: `host_eval` (compiler internal, not defined by the OpenMP
301+
specification), `map`, `use_device_addr` and `use_device_ptr`.
301302
- Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
302303
- Privatization clauses: `private`.
303304

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,44 @@ class OpenMP_HintClauseSkip<
470470

471471
def OpenMP_HintClause : OpenMP_HintClauseSkip<>;
472472

473+
//===----------------------------------------------------------------------===//
474+
// Not in the spec: Clause-like structure to hold host-evaluated values.
475+
//===----------------------------------------------------------------------===//
476+
477+
class OpenMP_HostEvalClauseSkip<
478+
bit traits = false, bit arguments = false, bit assemblyFormat = false,
479+
bit description = false, bit extraClassDeclaration = false
480+
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
481+
extraClassDeclaration> {
482+
let traits = [
483+
BlockArgOpenMPOpInterface, IsolatedFromAbove
484+
];
485+
486+
let arguments = (ins
487+
Variadic<AnyType>:$host_eval_vars
488+
);
489+
490+
let extraClassDeclaration = [{
491+
unsigned numHostEvalBlockArgs() {
492+
return getHostEvalVars().size();
493+
}
494+
}];
495+
496+
let description = [{
497+
The optional `host_eval_vars` holds values defined outside of the region of
498+
the `IsolatedFromAbove` operation for which a corresponding entry block
499+
argument is defined. The only legal uses for these captured values are the
500+
following:
501+
- `num_teams` or `thread_limit` clause of an immediately nested
502+
`omp.teams` operation.
503+
- If the operation is the top-level `omp.target` of a target SPMD kernel:
504+
- `num_threads` clause of the nested `omp.parallel` operation.
505+
- Bounds and steps of the nested `omp.loop_nest` operation.
506+
}];
507+
}
508+
509+
def OpenMP_HostEvalClause : OpenMP_HostEvalClauseSkip<>;
510+
473511
//===----------------------------------------------------------------------===//
474512
// V5.2: [3.4] `if` clause
475513
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
2525

2626
let methods = [
2727
// Default-implemented methods to be overriden by the corresponding clauses.
28+
InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
29+
"unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
30+
return 0;
31+
}]>,
2832
InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
2933
"unsigned", "numInReductionBlockArgs", (ins), [{}], [{
3034
return 0;
@@ -54,10 +58,16 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
5458
return 0;
5559
}]>,
5660

57-
// Unified access methods for clause-associated entry block arguments.
61+
// Unified access methods for start indices of clause-associated entry block
62+
// arguments.
63+
InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
64+
"unsigned", "getHostEvalBlockArgsStart", (ins), [{
65+
return 0;
66+
}]>,
5867
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
5968
"unsigned", "getInReductionBlockArgsStart", (ins), [{
60-
return 0;
69+
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
70+
return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
6171
}]>,
6272
InterfaceMethod<"Get start index of block arguments defined by `map`.",
6373
"unsigned", "getMapBlockArgsStart", (ins), [{
@@ -91,6 +101,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
91101
return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
92102
}]>,
93103

104+
// Unified access methods for clause-associated entry block arguments.
105+
InterfaceMethod<"Get block arguments defined by `host_eval`.",
106+
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
107+
"getHostEvalBlockArgs", (ins), [{
108+
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
109+
return $_op->getRegion(0).getArguments().slice(
110+
iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
111+
}]>,
94112
InterfaceMethod<"Get block arguments defined by `in_reduction`.",
95113
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
96114
"getInReductionBlockArgs", (ins), [{
@@ -147,10 +165,11 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
147165

148166
let verify = [{
149167
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
150-
unsigned expectedArgs = iface.numInReductionBlockArgs() +
151-
iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
152-
iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
153-
iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
168+
unsigned expectedArgs = iface.numHostEvalBlockArgs() +
169+
iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
170+
iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
171+
iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
172+
iface.numUseDevicePtrBlockArgs();
154173
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
155174
return $_op->emitOpError() << "expected at least " << expectedArgs
156175
<< " entry block argument(s)";

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,7 @@ struct ReductionParseArgs {
504504
: vars(vars), types(types), byref(byref), syms(syms) {}
505505
};
506506
struct AllRegionParseArgs {
507+
std::optional<MapParseArgs> hostEvalArgs;
507508
std::optional<ReductionParseArgs> inReductionArgs;
508509
std::optional<MapParseArgs> mapArgs;
509510
std::optional<PrivateParseArgs> privateArgs;
@@ -647,6 +648,11 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
647648
AllRegionParseArgs args) {
648649
llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
649650

651+
if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
652+
args.hostEvalArgs)))
653+
return parser.emitError(parser.getCurrentLocation())
654+
<< "invalid `host_eval` format";
655+
650656
if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
651657
args.inReductionArgs)))
652658
return parser.emitError(parser.getCurrentLocation())
@@ -812,6 +818,7 @@ struct ReductionPrintArgs {
812818
: vars(vars), types(types), byref(byref), syms(syms) {}
813819
};
814820
struct AllRegionPrintArgs {
821+
std::optional<MapPrintArgs> hostEvalArgs;
815822
std::optional<ReductionPrintArgs> inReductionArgs;
816823
std::optional<MapPrintArgs> mapArgs;
817824
std::optional<PrivatePrintArgs> privateArgs;
@@ -902,6 +909,8 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
902909
auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
903910
MLIRContext *ctx = op->getContext();
904911

912+
printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
913+
args.hostEvalArgs);
905914
printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
906915
args.inReductionArgs);
907916
printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),

0 commit comments

Comments
 (0)