Skip to content

Commit e4b04b3

Browse files
authored
[mlir] make transform.foreach_match forward arguments (#89920)
It may be useful to have access to additional handles or parameters when performing matches and actions in `foreach_match`, for example, to parameterize the matcher by rank or restrict it in a non-trivial way. Enable `foreach_match` to forward additional handles from operands to matcher symbols and from action symbols to results.
1 parent edbe6eb commit e4b04b3

File tree

6 files changed

+395
-69
lines changed

6 files changed

+395
-69
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 39 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,10 @@ def CollectMatchingOp : TransformDialectOp<"collect_matching", [
512512
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
513513
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
514514
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
515-
DeclareOpInterfaceMethods<TransformOpInterface>]> {
515+
DeclareOpInterfaceMethods<TransformOpInterface,
516+
["allowsRepeatedHandleOperands"]>,
517+
DeclareOpInterfaceMethods<OpAsmOpInterface,
518+
["getAsmResultNames"]>]> {
516519
let summary = "Applies named sequences when a named matcher succeeds";
517520
let description = [{
518521
Given a pair of co-indexed lists of transform dialect symbols (such as
@@ -528,25 +531,31 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
528531
the following matchers are not applied to the same payload operation. If the
529532
action succeeds, the next payload operation in walk order is matched. If it
530533
fails, both silenceable and definite errors are propagated as the result of
531-
this op.
532-
533-
The matcher symbol must take one operand of a type that implements the same
534-
transform dialect interface as the `root` operand (a check is performed at
535-
application time to see if the associated payload satisfies the constraints
536-
of the actual type). It must not consume the operand as multiple matchers
534+
this op; propagation of silenceable errors is postponed until the end of the
535+
walk.
536+
537+
The matcher symbol must take at least one operand of a type that implements
538+
the same transform dialect interface as the `root` operand (a check is
539+
performed at application time to see if the associated payload satisfies the
540+
constraints of the actual type), and may take additional operands with a
541+
similar type requirement. It must not consume operands as multiple matchers
537542
may be applied. The matcher may produce any number of results. The action
538543
symbol paired with the matcher must take the same number of arguments as the
539544
matcher has results, and these arguments must implement the same transform
540545
dialect interfaces, but not necessarily have the exact same type (again, a
541546
check is performed at application time to see if the associated payload
542-
satisfies the constraints of actual types on both sides). The action symbol
543-
may not have results. The actions are expected to only modify payload
544-
operations nested in the `root` payload operations associated with the
545-
operand of this transform operation. Furhermore, the actions may not modify
546-
operations outside of the currently matched payload operation, e.g., they
547-
may not modify sibling or parent operations. If such behavior is desired,
548-
the parent must be matched first and the nested operations obtained by
549-
traversing the IR from the parent. This is due to the matching being
547+
satisfies the constraints of actual types on both sides).
548+
549+
The action symbol may have results that are accumulated from all actions and
550+
returned from the `foreach_match` operation on success. Unless the
551+
`flatten_results` attribute is present, each action result must be
552+
associated with exactly one payload entity. The actions are expected to only
553+
modify payload operations nested in the `root` payload operations associated
554+
with the operand of this transform operation. Furthermore, the actions may
555+
not modify operations outside of the currently matched payload operation,
556+
e.g., they may not modify sibling or parent operations. If such behavior is
557+
desired, the parent must be matched first and the nested operations obtained
558+
by traversing the IR from the parent. This is due to the matching being
550559
performed as a post-order IR walk.
551560

552561
This operation consumes the operand and produces a new handle associated
@@ -573,19 +582,26 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
573582
produced a definite failure.
574583
}];
575584

576-
let arguments = (ins TransformHandleTypeInterface:$root,
577-
UnitAttr:$restrict_root,
578-
SymbolRefArrayAttr:$matchers,
579-
SymbolRefArrayAttr:$actions);
580-
let results = (outs TransformHandleTypeInterface:$updated);
585+
let arguments =
586+
(ins TransformHandleTypeInterface:$root,
587+
Variadic<Transform_AnyHandleOrParamType>:$forwarded_inputs,
588+
UnitAttr:$restrict_root,
589+
UnitAttr:$flatten_results,
590+
SymbolRefArrayAttr:$matchers,
591+
SymbolRefArrayAttr:$actions);
592+
let results =
593+
(outs TransformHandleTypeInterface:$updated,
594+
Variadic<Transform_AnyHandleOrParamType>:$forwarded_outputs);
581595

582596
let assemblyFormat = [{
583-
(`restrict_root` $restrict_root^)?
597+
oilist( `restrict_root` $restrict_root
598+
| `flatten_results` $flatten_results
599+
)
584600
`in`
585-
$root
601+
$root (`,` $forwarded_inputs^)?
586602
custom<ForeachMatchSymbols>($matchers, $actions)
587603
attr-dict
588-
`:` functional-type($root, $updated)
604+
`:` functional-type(operands, results)
589605
}];
590606

591607
let hasVerifier = 1;

mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ void getPotentialTopLevelEffects(
5252
/// Verification hook for TransformOpInterface.
5353
LogicalResult verifyTransformOpInterface(Operation *op);
5454

55+
/// Appends the entities associated with the given transform values in `state`
56+
/// to the pre-existing list of mappings. The array of mappings must have as
57+
/// many elements as values. If `flatten` is set, multiple values may be
58+
/// associated with each transform value, and this always succeeds. Otherwise,
59+
/// checks that each value has exactly one mapping associated and return failure
60+
/// otherwise.
61+
LogicalResult appendValueMappings(
62+
MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
63+
ValueRange values, const transform::TransformState &state,
64+
bool flatten = true);
65+
5566
/// Populates `mappings` with mapped values associated with the given transform
5667
/// IR values in the given `state`.
5768
void prepareValueMappings(
@@ -317,6 +328,8 @@ class TransformState {
317328
}
318329
LogicalResult mapBlockArgument(BlockArgument argument,
319330
ArrayRef<MappedValue> values);
331+
LogicalResult mapBlockArguments(Block::BlockArgListType arguments,
332+
ArrayRef<SmallVector<MappedValue>> mapping);
320333

321334
// Forward declarations to support limited visibility.
322335
class RegionScope;

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 123 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/IR/BuiltinAttributes.h"
2020
#include "mlir/IR/Diagnostics.h"
2121
#include "mlir/IR/Dominance.h"
22+
#include "mlir/IR/OpImplementation.h"
2223
#include "mlir/IR/OperationSupport.h"
2324
#include "mlir/IR/PatternMatch.h"
2425
#include "mlir/IR/Verifier.h"
@@ -834,19 +835,23 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
834835
// CollectMatchingOp
835836
//===----------------------------------------------------------------------===//
836837

837-
/// Applies matcher operations from the given `block` assigning `op` as the
838-
/// payload of the block's first argument. Updates `state` accordingly. If any
839-
/// of the matcher produces a silenceable failure, discards it (printing the
840-
/// content to the debug output stream) and returns failure. If any of the
841-
/// matchers produces a definite failure, reports it and returns failure. If all
842-
/// matchers in the block succeed, populates `mappings` with the payload
843-
/// entities associated with the block terminator operands.
838+
/// Applies matcher operations from the given `block` using
839+
/// `blockArgumentMapping` to initialize block arguments. Updates `state`
840+
/// accordingly. If any of the matcher produces a silenceable failure, discards
841+
/// it (printing the content to the debug output stream) and returns failure. If
842+
/// any of the matchers produces a definite failure, reports it and returns
843+
/// failure. If all matchers in the block succeed, populates `mappings` with the
844+
/// payload entities associated with the block terminator operands. Note that
845+
/// `mappings` will be cleared before that.
844846
static DiagnosedSilenceableFailure
845-
matchBlock(Block &block, Operation *op, transform::TransformState &state,
847+
matchBlock(Block &block,
848+
ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
849+
transform::TransformState &state,
846850
SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
847851
assert(block.getParent() && "cannot match using a detached block");
848852
auto matchScope = state.make_region_scope(*block.getParent());
849-
if (failed(state.mapBlockArgument(block.getArgument(0), {op})))
853+
if (failed(
854+
state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
850855
return DiagnosedSilenceableFailure::definiteFailure();
851856

852857
for (Operation &match : block.without_terminator()) {
@@ -866,6 +871,9 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
866871
// Remember the values mapped to the terminator operands so we can
867872
// forward them to the action.
868873
ValueRange yieldedValues = block.getTerminator()->getOperands();
874+
// Our contract with the caller is that the mappings will contain only the
875+
// newly mapped values, clear the rest.
876+
mappings.clear();
869877
transform::detail::prepareValueMappings(mappings, yieldedValues, state);
870878
return DiagnosedSilenceableFailure::success();
871879
}
@@ -915,8 +923,11 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
915923

916924
// Try matching.
917925
SmallVector<SmallVector<MappedValue>> mappings;
918-
DiagnosedSilenceableFailure diag =
919-
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
926+
SmallVector<transform::MappedValue> inputMapping({op});
927+
DiagnosedSilenceableFailure diag = matchBlock(
928+
matcher.getFunctionBody().front(),
929+
ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
930+
mappings);
920931
if (diag.isDefiniteFailure())
921932
return WalkResult::interrupt();
922933
if (diag.isSilenceableFailure()) {
@@ -1001,6 +1012,9 @@ LogicalResult transform::CollectMatchingOp::verifySymbolUses(
10011012
// ForeachMatchOp
10021013
//===----------------------------------------------------------------------===//
10031014

1015+
// This is fine because nothing is actually consumed by this op.
1016+
bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1017+
10041018
DiagnosedSilenceableFailure
10051019
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10061020
transform::TransformResults &results,
@@ -1030,6 +1044,18 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10301044

10311045
DiagnosedSilenceableFailure overallDiag =
10321046
DiagnosedSilenceableFailure::success();
1047+
1048+
SmallVector<SmallVector<MappedValue>> matchInputMapping;
1049+
SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1050+
SmallVector<SmallVector<MappedValue>> actionResultMapping;
1051+
// Explicitly add the mapping for the first block argument (the op being
1052+
// matched).
1053+
matchInputMapping.emplace_back();
1054+
transform::detail::prepareValueMappings(matchInputMapping,
1055+
getForwardedInputs(), state);
1056+
SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1057+
actionResultMapping.resize(getForwardedOutputs().size());
1058+
10331059
for (Operation *root : state.getPayloadOps(getRoot())) {
10341060
WalkResult walkResult = root->walk([&](Operation *op) {
10351061
// If getRestrictRoot is not present, skip over the root op itself so we
@@ -1044,11 +1070,14 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10441070
llvm::dbgs() << " @" << op << "\n";
10451071
});
10461072

1073+
firstMatchArgument.clear();
1074+
firstMatchArgument.push_back(op);
1075+
10471076
// Try all the match/action pairs until the first successful match.
10481077
for (auto [matcher, action] : matchActionPairs) {
1049-
SmallVector<SmallVector<MappedValue>> mappings;
10501078
DiagnosedSilenceableFailure diag =
1051-
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
1079+
matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1080+
state, matchOutputMapping);
10521081
if (diag.isDefiniteFailure())
10531082
return WalkResult::interrupt();
10541083
if (diag.isSilenceableFailure()) {
@@ -1058,10 +1087,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10581087
}
10591088

10601089
auto scope = state.make_region_scope(action.getFunctionBody());
1061-
for (auto &&[arg, map] : llvm::zip_equal(
1062-
action.getFunctionBody().front().getArguments(), mappings)) {
1063-
if (failed(state.mapBlockArgument(arg, map)))
1064-
return WalkResult::interrupt();
1090+
if (failed(state.mapBlockArguments(
1091+
action.getFunctionBody().front().getArguments(),
1092+
matchOutputMapping))) {
1093+
return WalkResult::interrupt();
10651094
}
10661095

10671096
for (Operation &transform :
@@ -1082,6 +1111,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10821111
continue;
10831112
}
10841113
}
1114+
if (failed(detail::appendValueMappings(
1115+
MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1116+
action.getFunctionBody().front().getTerminator()->getOperands(),
1117+
state, getFlattenResults()))) {
1118+
emitDefiniteFailure()
1119+
<< "action @" << action.getName()
1120+
<< " has results associated with multiple payload entities, "
1121+
"but flattening was not requested";
1122+
return WalkResult::interrupt();
1123+
}
10851124
break;
10861125
}
10871126
return WalkResult::advance();
@@ -1096,9 +1135,21 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10961135
// by actions, are invalidated.
10971136
results.set(llvm::cast<OpResult>(getUpdated()),
10981137
state.getPayloadOps(getRoot()));
1138+
for (auto &&[result, mapping] :
1139+
llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1140+
results.setMappedValues(result, mapping);
1141+
}
10991142
return overallDiag;
11001143
}
11011144

1145+
void transform::ForeachMatchOp::getAsmResultNames(
1146+
OpAsmSetValueNameFn setNameFn) {
1147+
setNameFn(getUpdated(), "updated_root");
1148+
for (Value v : getForwardedOutputs()) {
1149+
setNameFn(v, "yielded");
1150+
}
1151+
}
1152+
11021153
void transform::ForeachMatchOp::getEffects(
11031154
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
11041155
// Bail if invalid.
@@ -1108,7 +1159,8 @@ void transform::ForeachMatchOp::getEffects(
11081159
}
11091160

11101161
consumesHandle(getRoot(), effects);
1111-
producesHandle(getUpdated(), effects);
1162+
onlyReadsHandle(getForwardedInputs(), effects);
1163+
producesHandle(getResults(), effects);
11121164
modifiesPayload(effects);
11131165
}
11141166

@@ -1224,6 +1276,7 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
12241276
StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
12251277
for (auto &&[matcher, action] :
12261278
llvm::zip_equal(getMatchers(), getActions())) {
1279+
// Presence and typing.
12271280
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
12281281
symbolTable.lookupNearestSymbolFrom(getOperation(),
12291282
cast<SymbolRefAttr>(matcher)));
@@ -1250,8 +1303,41 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
12501303
return failure();
12511304
}
12521305

1253-
ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes();
1254-
ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes();
1306+
// Input -> matcher forwarding.
1307+
TypeRange operandTypes = getOperandTypes();
1308+
TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1309+
if (operandTypes.size() != matcherArguments.size()) {
1310+
InFlightDiagnostic diag =
1311+
emitError() << "the number of operands (" << operandTypes.size()
1312+
<< ") doesn't match the number of matcher arguments ("
1313+
<< matcherArguments.size() << ") for " << matcher;
1314+
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1315+
return diag;
1316+
}
1317+
for (auto &&[i, operand, argument] :
1318+
llvm::enumerate(operandTypes, matcherArguments)) {
1319+
if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1320+
InFlightDiagnostic diag =
1321+
emitOpError()
1322+
<< "does not expect matcher symbol to consume its operand #" << i;
1323+
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1324+
return diag;
1325+
}
1326+
1327+
if (implementSameTransformInterface(operand, argument))
1328+
continue;
1329+
1330+
InFlightDiagnostic diag =
1331+
emitError()
1332+
<< "mismatching type interfaces for operand and matcher argument #"
1333+
<< i << " of matcher " << matcher;
1334+
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1335+
return diag;
1336+
}
1337+
1338+
// Matcher -> action forwarding.
1339+
TypeRange matcherResults = matcherSymbol.getResultTypes();
1340+
TypeRange actionArguments = actionSymbol.getArgumentTypes();
12551341
if (matcherResults.size() != actionArguments.size()) {
12561342
return emitError() << "mismatching number of matcher results and "
12571343
"action arguments between "
@@ -1265,31 +1351,31 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
12651351

12661352
return emitError() << "mismatching type interfaces for matcher result "
12671353
"and action argument #"
1268-
<< i;
1354+
<< i << "of matcher " << matcher << " and action "
1355+
<< action;
12691356
}
12701357

1271-
if (!actionSymbol.getResultTypes().empty()) {
1358+
// Action -> result forwarding.
1359+
TypeRange actionResults = actionSymbol.getResultTypes();
1360+
auto resultTypes = TypeRange(getResultTypes()).drop_front();
1361+
if (actionResults.size() != resultTypes.size()) {
12721362
InFlightDiagnostic diag =
1273-
emitError() << "action symbol is not expected to have results";
1363+
emitError() << "the number of action results ("
1364+
<< actionResults.size() << ") for " << action
1365+
<< " doesn't match the number of extra op results ("
1366+
<< resultTypes.size() << ")";
12741367
diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
12751368
return diag;
12761369
}
1370+
for (auto &&[i, resultType, actionType] :
1371+
llvm::enumerate(resultTypes, actionResults)) {
1372+
if (implementSameTransformInterface(resultType, actionType))
1373+
continue;
12771374

1278-
if (matcherSymbol.getArgumentTypes().size() != 1 ||
1279-
!implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0],
1280-
getRoot().getType())) {
1281-
InFlightDiagnostic diag =
1282-
emitOpError() << "expects matcher symbol to have one argument with "
1283-
"the same transform interface as the first operand";
1284-
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1285-
return diag;
1286-
}
1287-
1288-
if (matcherSymbol.getArgAttr(0, consumedAttr)) {
12891375
InFlightDiagnostic diag =
1290-
emitOpError()
1291-
<< "does not expect matcher symbol to consume its operand";
1292-
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1376+
emitError() << "mismatching type interfaces for action result #" << i
1377+
<< " of action " << action << " and op result";
1378+
diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
12931379
return diag;
12941380
}
12951381
}

0 commit comments

Comments
 (0)