Skip to content

Commit 270790b

Browse files
committed
[MLIR][OpenMP] Minor improvements to BlockArgOpenMPOpInterface, NFC
This patch introduces a use for the new `getBlockArgsPairs` to avoid having to manually list each applicable clause. Also, the `numClauseBlockArgs()` function is introduced, which simplifies the implementation of the interface's verifier and enables better memory handling within `getBlockArgsPairs`.
1 parent f8b03e2 commit 270790b

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

mlir/docs/Dialects/OpenMPDialect/_index.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,8 @@ accessed:
372372
should be located.
373373
- `get<ClauseName>BlockArgs()`: Returns the list of entry block arguments
374374
defined by the given clause.
375+
- `numClauseBlockArgs()`: Returns the total number of entry block arguments
376+
defined by all clauses.
375377
- `getBlockArgsPairs()`: Returns a list of pairs where the first element is
376378
the outside value, or operand, and the second element is the corresponding
377379
entry block argument.

mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,12 +136,20 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
136136
!foreach(clause, clauses, clause.startMethod),
137137
!foreach(clause, clauses, clause.blockArgsMethod),
138138
[
139+
InterfaceMethod<
140+
"Get the total number of clause-defined entry block arguments",
141+
"unsigned", "numClauseBlockArgs", (ins),
142+
"return " # !interleave(
143+
!foreach(clause, clauses, "$_op." # clause.numArgsMethod.name # "()"),
144+
" + ") # ";"
145+
>,
139146
InterfaceMethod<
140147
"Populate a vector of pairs representing the matching between operands "
141148
"and entry block arguments.", "void", "getBlockArgsPairs",
142149
(ins "::llvm::SmallVectorImpl<std::pair<::mlir::Value, ::mlir::BlockArgument>> &" : $pairs),
143150
[{
144151
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>(*$_op);
152+
pairs.reserve(pairs.size() + iface.numClauseBlockArgs());
145153
}] # !interleave(!foreach(clause, clauses, [{
146154
}] # "if (iface." # clause.numArgsMethod.name # "() > 0) {" # [{
147155
}] # " for (auto [var, arg] : ::llvm::zip_equal(" #
@@ -155,11 +163,7 @@ def BlockArgOpenMPOpInterface : OpInterface<"BlockArgOpenMPOpInterface"> {
155163

156164
let verify = [{
157165
auto iface = ::llvm::cast<BlockArgOpenMPOpInterface>($_op);
158-
}] # "unsigned expectedArgs = "
159-
# !interleave(
160-
!foreach(clause, clauses, "iface." # clause.numArgsMethod.name # "()"),
161-
" + "
162-
) # ";" # [{
166+
unsigned expectedArgs = iface.numClauseBlockArgs();
163167
if ($_op->getRegion(0).getNumArguments() < expectedArgs)
164168
return $_op->emitOpError() << "expected at least " << expectedArgs
165169
<< " entry block argument(s)";

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -550,18 +550,16 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
550550
// corresponding operand. This is semantically equivalent to this wrapper not
551551
// being present.
552552
auto forwardArgs =
553-
[&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
554-
OperandRange operands) {
555-
for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
553+
[&moduleTranslation](omp::BlockArgOpenMPOpInterface blockArgIface) {
554+
llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
555+
blockArgIface.getBlockArgsPairs(blockArgsPairs);
556+
for (auto [var, arg] : blockArgsPairs)
556557
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
557558
};
558559

559560
return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
560561
.Case([&](omp::SimdOp op) {
561-
auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
562-
forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
563-
forwardArgs(blockArgIface.getReductionBlockArgs(),
564-
op.getReductionVars());
562+
forwardArgs(cast<omp::BlockArgOpenMPOpInterface>(*op));
565563
op.emitWarning() << "simd information on composite construct discarded";
566564
return success();
567565
})

0 commit comments

Comments
 (0)