Skip to content

Commit 0a47845

Browse files
committed
[MLIR][OpenMP] Add the host_eval clause
This patch adds the definition of a new entry block argument-defining `host_eval` clause. This is intended to implement the passthrough approach discussed in [this RFC](https://discourse.llvm.org/t/rfc-openmp-dialect-representation-of-num-teams-thread-limit-and-target-spmd/81106), for supporting host-evaluated clauses that apply to operations nested inside of `omp.target`.
1 parent 76befc8 commit 0a47845

File tree

4 files changed

+74
-7
lines changed

4 files changed

+74
-7
lines changed

mlir/docs/Dialects/OpenMPDialect/_index.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,8 @@ arguments for the region of that MLIR operation. This enables, for example, the
297297
introduction of private copies of the same underlying variable defined outside
298298
the MLIR operation the clause is attached to. Currently, clauses with this
299299
property can be classified into three main categories:
300-
- Map-like clauses: `map`, `use_device_addr` and `use_device_ptr`.
300+
- Map-like clauses: `host_eval`, `map`, `use_device_addr` and
301+
`use_device_ptr`.
301302
- Reduction-like clauses: `in_reduction`, `reduction` and `task_reduction`.
302303
- Privatization clauses: `private`.
303304

mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -444,6 +444,44 @@ class OpenMP_HintClauseSkip<
444444

445445
def OpenMP_HintClause : OpenMP_HintClauseSkip<>;
446446

447+
//===----------------------------------------------------------------------===//
448+
// Not in the spec: Clause-like structure to hold host-evaluated values.
449+
//===----------------------------------------------------------------------===//
450+
451+
class OpenMP_HostEvalClauseSkip<
452+
bit traits = false, bit arguments = false, bit assemblyFormat = false,
453+
bit description = false, bit extraClassDeclaration = false
454+
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
455+
extraClassDeclaration> {
456+
let traits = [
457+
BlockArgOpenMPOpInterface
458+
];
459+
460+
let arguments = (ins
461+
Variadic<AnyType>:$host_eval_vars
462+
);
463+
464+
let extraClassDeclaration = [{
465+
unsigned numHostEvalBlockArgs() {
466+
return getHostEvalVars().size();
467+
}
468+
}];
469+
470+
let description = [{
471+
The optional `host_eval_vars` holds values defined outside of the region of
472+
the `IsolatedFromAbove` operation for which a corresponding entry block
473+
argument is defined. The only legal uses for these captured values are the
474+
following:
475+
- `num_teams` or `thread_limit` clause of an immediately nested
476+
`omp.teams` operation.
477+
- If the operation is the top-level `omp.target` of a target SPMD kernel:
478+
- `num_threads` clause of the nested `omp.parallel` operation.
479+
- Bounds and steps of the nested `omp.loop_nest` operation.
480+
}];
481+
}
482+
483+
def OpenMP_HostEvalClause : OpenMP_HostEvalClauseSkip<>;
484+
447485
//===----------------------------------------------------------------------===//
448486
// V5.2: [3.4] `if` clause
449487
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
2525

2626
let methods = [
2727
// Default-implemented methods to be overriden by the corresponding clauses.
28+
InterfaceMethod<"Get number of block arguments defined by `host_eval`.",
29+
"unsigned", "numHostEvalBlockArgs", (ins), [{}], [{
30+
return 0;
31+
}]>,
2832
InterfaceMethod<"Get number of block arguments defined by `in_reduction`.",
2933
"unsigned", "numInReductionBlockArgs", (ins), [{}], [{
3034
return 0;
@@ -54,10 +58,16 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
5458
return 0;
5559
}]>,
5660

57-
// Unified access methods for clause-associated entry block arguments.
61+
// Unified access methods for start indices of clause-associated entry block
62+
// arguments.
63+
InterfaceMethod<"Get start index of block arguments defined by `host_eval`.",
64+
"unsigned", "getHostEvalBlockArgsStart", (ins), [{
65+
return 0;
66+
}]>,
5867
InterfaceMethod<"Get start index of block arguments defined by `in_reduction`.",
5968
"unsigned", "getInReductionBlockArgsStart", (ins), [{
60-
return 0;
69+
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
70+
return iface.getHostEvalBlockArgsStart() + $_op.numHostEvalBlockArgs();
6171
}]>,
6272
InterfaceMethod<"Get start index of block arguments defined by `map`.",
6373
"unsigned", "getMapBlockArgsStart", (ins), [{
@@ -91,6 +101,14 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
91101
return iface.getUseDeviceAddrBlockArgsStart() + $_op.numUseDeviceAddrBlockArgs();
92102
}]>,
93103

104+
// Unified access methods for clause-associated entry block arguments.
105+
InterfaceMethod<"Get block arguments defined by `host_eval`.",
106+
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
107+
"getHostEvalBlockArgs", (ins), [{
108+
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
109+
return $_op->getRegion(0).getArguments().slice(
110+
iface.getHostEvalBlockArgsStart(), $_op.numHostEvalBlockArgs());
111+
}]>,
94112
InterfaceMethod<"Get block arguments defined by `in_reduction`.",
95113
"::llvm::MutableArrayRef<::mlir::BlockArgument>",
96114
"getInReductionBlockArgs", (ins), [{
@@ -147,10 +165,11 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
147165

148166
let verify = [{
149167
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
150-
unsigned expectedArgs = iface.numInReductionBlockArgs() +
151-
iface.numMapBlockArgs() + iface.numPrivateBlockArgs() +
152-
iface.numReductionBlockArgs() + iface.numTaskReductionBlockArgs() +
153-
iface.numUseDeviceAddrBlockArgs() + iface.numUseDevicePtrBlockArgs();
168+
unsigned expectedArgs = iface.numHostEvalBlockArgs() +
169+
iface.numInReductionBlockArgs() + iface.numMapBlockArgs() +
170+
iface.numPrivateBlockArgs() + iface.numReductionBlockArgs() +
171+
iface.numTaskReductionBlockArgs() + iface.numUseDeviceAddrBlockArgs() +
172+
iface.numUseDevicePtrBlockArgs();
154173
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
155174
return $_op->emitOpError() << "expected at least " << expectedArgs
156175
<< " entry block argument(s)";

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,6 +502,7 @@ struct ReductionParseArgs {
502502
: vars(vars), types(types), byref(byref), syms(syms) {}
503503
};
504504
struct AllRegionParseArgs {
505+
std::optional<MapParseArgs> hostEvalArgs;
505506
std::optional<ReductionParseArgs> inReductionArgs;
506507
std::optional<MapParseArgs> mapArgs;
507508
std::optional<PrivateParseArgs> privateArgs;
@@ -628,6 +629,11 @@ static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
628629
AllRegionParseArgs args) {
629630
llvm::SmallVector<OpAsmParser::Argument> entryBlockArgs;
630631

632+
if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
633+
args.hostEvalArgs)))
634+
return parser.emitError(parser.getCurrentLocation())
635+
<< "invalid `host_eval` format";
636+
631637
if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
632638
args.inReductionArgs)))
633639
return parser.emitError(parser.getCurrentLocation())
@@ -789,6 +795,7 @@ struct ReductionPrintArgs {
789795
: vars(vars), types(types), byref(byref), syms(syms) {}
790796
};
791797
struct AllRegionPrintArgs {
798+
std::optional<MapPrintArgs> hostEvalArgs;
792799
std::optional<ReductionPrintArgs> inReductionArgs;
793800
std::optional<MapPrintArgs> mapArgs;
794801
std::optional<PrivatePrintArgs> privateArgs;
@@ -867,6 +874,8 @@ static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region,
867874
auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
868875
MLIRContext *ctx = op->getContext();
869876

877+
printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
878+
args.hostEvalArgs);
870879
printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
871880
args.inReductionArgs);
872881
printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),

0 commit comments

Comments
 (0)