Skip to content

Commit 4a89c1b

Browse files
committed
[mlir] make transform.foreach_match forward arguments
It may be useful to have access to additional handles or parameters when performing matches and actions in `foreach_match`, for example, to parameterize the matcher by rank or restrict it in a non-trivial way. Enable `foreach_match` to forward additional handles from operands to matcher symbols and from action symbols to results.
1 parent 5445a35 commit 4a89c1b

File tree

6 files changed

+384
-69
lines changed

6 files changed

+384
-69
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,8 @@ def CollectMatchingOp : TransformDialectOp<"collect_matching", [
512512
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
513513
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
514514
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
515-
DeclareOpInterfaceMethods<TransformOpInterface>]> {
515+
DeclareOpInterfaceMethods<TransformOpInterface,
516+
["allowsRepeatedHandleOperands"]>]> {
516517
let summary = "Applies named sequences when a named matcher succeeds";
517518
let description = [{
518519
Given a pair of co-indexed lists of transform dialect symbols (such as
@@ -528,25 +529,31 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
528529
the following matchers are not applied to the same payload operation. If the
529530
action succeeds, the next payload operation in walk order is matched. If it
530531
fails, both silenceable and definite errors are propagated as the result of
531-
this op.
532-
533-
The matcher symbol must take one operand of a type that implements the same
534-
transform dialect interface as the `root` operand (a check is performed at
535-
application time to see if the associated payload satisfies the constraints
536-
of the actual type). It must not consume the operand as multiple matchers
532+
this op; propagation of silenceable errors is postponed until the end of the
533+
walk.
534+
535+
The matcher symbol must take at least one operand of a type that implements
536+
the same transform dialect interface as the `root` operand (a check is
537+
performed at application time to see if the associated payload satisfies the
538+
constraints of the actual type), and may take additional operands with a
539+
similar type requirement. It must not consume operands as multiple matchers
537540
may be applied. The matcher may produce any number of results. The action
538541
symbol paired with the matcher must take the same number of arguments as the
539542
matcher has results, and these arguments must implement the same transform
540543
dialect interfaces, but not necessarily have the exact same type (again, a
541544
check is performed at application time to see if the associated payload
542-
satisfies the constraints of actual types on both sides). The action symbol
543-
may not have results. The actions are expected to only modify payload
544-
operations nested in the `root` payload operations associated with the
545-
operand of this transform operation. Furhermore, the actions may not modify
546-
operations outside of the currently matched payload operation, e.g., they
547-
may not modify sibling or parent operations. If such behavior is desired,
548-
the parent must be matched first and the nested operations obtained by
549-
traversing the IR from the parent. This is due to the matching being
545+
satisfies the constraints of actual types on both sides).
546+
547+
The action symbol may have results that are accumulated from all actions and
548+
returned from the `foreach_match` operation on success. Unless the
549+
`flatten_results` attribute is present, each action result must be
550+
associated with exactly one payload entity. The actions are expected to only
551+
modify payload operations nested in the `root` payload operations associated
552+
with the operand of this transform operation. Furthermore, the actions may
553+
not modify operations outside of the currently matched payload operation,
554+
e.g., they may not modify sibling or parent operations. If such behavior is
555+
desired, the parent must be matched first and the nested operations obtained
556+
by traversing the IR from the parent. This is due to the matching being
550557
performed as a post-order IR walk.
551558

552559
This operation consumes the operand and produces a new handle associated
@@ -573,19 +580,26 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
573580
produced a definite failure.
574581
}];
575582

576-
let arguments = (ins TransformHandleTypeInterface:$root,
577-
UnitAttr:$restrict_root,
578-
SymbolRefArrayAttr:$matchers,
579-
SymbolRefArrayAttr:$actions);
580-
let results = (outs TransformHandleTypeInterface:$updated);
583+
let arguments =
584+
(ins TransformHandleTypeInterface:$root,
585+
Variadic<Transform_AnyHandleOrParamType>:$forwarded_inputs,
586+
UnitAttr:$restrict_root,
587+
UnitAttr:$flatten_results,
588+
SymbolRefArrayAttr:$matchers,
589+
SymbolRefArrayAttr:$actions);
590+
let results =
591+
(outs TransformHandleTypeInterface:$updated,
592+
Variadic<Transform_AnyHandleOrParamType>:$forwarded_outputs);
581593

582594
let assemblyFormat = [{
583-
(`restrict_root` $restrict_root^)?
595+
oilist( `restrict_root` $restrict_root
596+
| `flatten_results` $flatten_results
597+
)
584598
`in`
585-
$root
599+
$root (`,` $forwarded_inputs^)?
586600
custom<ForeachMatchSymbols>($matchers, $actions)
587601
attr-dict
588-
`:` functional-type($root, $updated)
602+
`:` functional-type(operands, results)
589603
}];
590604

591605
let hasVerifier = 1;

mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ void getPotentialTopLevelEffects(
5252
/// Verification hook for TransformOpInterface.
5353
LogicalResult verifyTransformOpInterface(Operation *op);
5454

55+
/// Appends the entities associated with the given transform values in `state`
56+
/// to the pre-existing list of mappings. The array of mappings must have as
57+
/// many elements as values. If `flatten` is set, multiple values may be
58+
/// associated with each transform value, and this always succeeds. Otherwise,
59+
/// checks that each value has exactly one mapping associated and return failure
60+
/// otherwise.
61+
LogicalResult appendValueMappings(
62+
MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
63+
ValueRange values, const transform::TransformState &state,
64+
bool flatten = true);
65+
5566
/// Populates `mappings` with mapped values associated with the given transform
5667
/// IR values in the given `state`.
5768
void prepareValueMappings(
@@ -317,6 +328,8 @@ class TransformState {
317328
}
318329
LogicalResult mapBlockArgument(BlockArgument argument,
319330
ArrayRef<MappedValue> values);
331+
LogicalResult mapBlockArguments(Block::BlockArgListType arguments,
332+
ArrayRef<SmallVector<MappedValue>> mapping);
320333

321334
// Forward declarations to support limited visibility.
322335
class RegionScope;

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 114 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -834,19 +834,23 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
834834
// CollectMatchingOp
835835
//===----------------------------------------------------------------------===//
836836

837-
/// Applies matcher operations from the given `block` assigning `op` as the
838-
/// payload of the block's first argument. Updates `state` accordingly. If any
839-
/// of the matcher produces a silenceable failure, discards it (printing the
840-
/// content to the debug output stream) and returns failure. If any of the
841-
/// matchers produces a definite failure, reports it and returns failure. If all
842-
/// matchers in the block succeed, populates `mappings` with the payload
843-
/// entities associated with the block terminator operands.
837+
/// Applies matcher operations from the given `block` using
838+
/// `blockArgumentMapping` to initialize block arguments. Updates `state`
839+
/// accordingly. If any of the matcher produces a silenceable failure, discards
840+
/// it (printing the content to the debug output stream) and returns failure. If
841+
/// any of the matchers produces a definite failure, reports it and returns
842+
/// failure. If all matchers in the block succeed, populates `mappings` with the
843+
/// payload entities associated with the block terminator operands. Note that
844+
/// `mappings` will be cleared before that.
844845
static DiagnosedSilenceableFailure
845-
matchBlock(Block &block, Operation *op, transform::TransformState &state,
846+
matchBlock(Block &block,
847+
ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
848+
transform::TransformState &state,
846849
SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
847850
assert(block.getParent() && "cannot match using a detached block");
848851
auto matchScope = state.make_region_scope(*block.getParent());
849-
if (failed(state.mapBlockArgument(block.getArgument(0), {op})))
852+
if (failed(
853+
state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
850854
return DiagnosedSilenceableFailure::definiteFailure();
851855

852856
for (Operation &match : block.without_terminator()) {
@@ -866,6 +870,9 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
866870
// Remember the values mapped to the terminator operands so we can
867871
// forward them to the action.
868872
ValueRange yieldedValues = block.getTerminator()->getOperands();
873+
// Our contract with the caller is that the mappings will contain only the
874+
// newly mapped values, clear the rest.
875+
mappings.clear();
869876
transform::detail::prepareValueMappings(mappings, yieldedValues, state);
870877
return DiagnosedSilenceableFailure::success();
871878
}
@@ -915,8 +922,11 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
915922

916923
// Try matching.
917924
SmallVector<SmallVector<MappedValue>> mappings;
918-
DiagnosedSilenceableFailure diag =
919-
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
925+
SmallVector<transform::MappedValue> inputMapping({op});
926+
DiagnosedSilenceableFailure diag = matchBlock(
927+
matcher.getFunctionBody().front(),
928+
ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
929+
mappings);
920930
if (diag.isDefiniteFailure())
921931
return WalkResult::interrupt();
922932
if (diag.isSilenceableFailure()) {
@@ -1001,6 +1011,9 @@ LogicalResult transform::CollectMatchingOp::verifySymbolUses(
10011011
// ForeachMatchOp
10021012
//===----------------------------------------------------------------------===//
10031013

1014+
// This is fine because nothing is actually consumed by this op.
1015+
bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1016+
10041017
DiagnosedSilenceableFailure
10051018
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10061019
transform::TransformResults &results,
@@ -1030,6 +1043,18 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10301043

10311044
DiagnosedSilenceableFailure overallDiag =
10321045
DiagnosedSilenceableFailure::success();
1046+
1047+
SmallVector<SmallVector<MappedValue>> matchInputMapping;
1048+
SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1049+
SmallVector<SmallVector<MappedValue>> actionResultMapping;
1050+
// Explicitly add the mapping for the first block argument (the op being
1051+
// matched).
1052+
matchInputMapping.emplace_back();
1053+
transform::detail::prepareValueMappings(matchInputMapping,
1054+
getForwardedInputs(), state);
1055+
SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1056+
actionResultMapping.resize(getForwardedOutputs().size());
1057+
10331058
for (Operation *root : state.getPayloadOps(getRoot())) {
10341059
WalkResult walkResult = root->walk([&](Operation *op) {
10351060
// If getRestrictRoot is not present, skip over the root op itself so we
@@ -1044,11 +1069,14 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10441069
llvm::dbgs() << " @" << op << "\n";
10451070
});
10461071

1072+
firstMatchArgument.clear();
1073+
firstMatchArgument.push_back(op);
1074+
10471075
// Try all the match/action pairs until the first successful match.
10481076
for (auto [matcher, action] : matchActionPairs) {
1049-
SmallVector<SmallVector<MappedValue>> mappings;
10501077
DiagnosedSilenceableFailure diag =
1051-
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
1078+
matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1079+
state, matchOutputMapping);
10521080
if (diag.isDefiniteFailure())
10531081
return WalkResult::interrupt();
10541082
if (diag.isSilenceableFailure()) {
@@ -1058,10 +1086,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10581086
}
10591087

10601088
auto scope = state.make_region_scope(action.getFunctionBody());
1061-
for (auto &&[arg, map] : llvm::zip_equal(
1062-
action.getFunctionBody().front().getArguments(), mappings)) {
1063-
if (failed(state.mapBlockArgument(arg, map)))
1064-
return WalkResult::interrupt();
1089+
if (failed(state.mapBlockArguments(
1090+
action.getFunctionBody().front().getArguments(),
1091+
matchOutputMapping))) {
1092+
return WalkResult::interrupt();
10651093
}
10661094

10671095
for (Operation &transform :
@@ -1082,6 +1110,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10821110
continue;
10831111
}
10841112
}
1113+
if (failed(detail::appendValueMappings(
1114+
MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1115+
action.getFunctionBody().front().getTerminator()->getOperands(),
1116+
state, getFlattenResults()))) {
1117+
emitDefiniteFailure()
1118+
<< "action @" << action.getName()
1119+
<< " has results associated with multiple payload entities, "
1120+
"but flattening was not requested";
1121+
return WalkResult::interrupt();
1122+
}
10851123
break;
10861124
}
10871125
return WalkResult::advance();
@@ -1096,6 +1134,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10961134
// by actions, are invalidated.
10971135
results.set(llvm::cast<OpResult>(getUpdated()),
10981136
state.getPayloadOps(getRoot()));
1137+
for (auto &&[result, mapping] :
1138+
llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1139+
results.setMappedValues(result, mapping);
1140+
}
10991141
return overallDiag;
11001142
}
11011143

@@ -1108,7 +1150,8 @@ void transform::ForeachMatchOp::getEffects(
11081150
}
11091151

11101152
consumesHandle(getRoot(), effects);
1111-
producesHandle(getUpdated(), effects);
1153+
onlyReadsHandle(getForwardedInputs(), effects);
1154+
producesHandle(getResults(), effects);
11121155
modifiesPayload(effects);
11131156
}
11141157

@@ -1224,6 +1267,7 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
12241267
StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
12251268
for (auto &&[matcher, action] :
12261269
llvm::zip_equal(getMatchers(), getActions())) {
1270+
// Presence and typing.
12271271
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
12281272
symbolTable.lookupNearestSymbolFrom(getOperation(),
12291273
cast<SymbolRefAttr>(matcher)));
@@ -1250,8 +1294,41 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
12501294
return failure();
12511295
}
12521296

1253-
ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes();
1254-
ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes();
1297+
// Input -> matcher forwarding.
1298+
TypeRange operandTypes = getOperandTypes();
1299+
TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1300+
if (operandTypes.size() != matcherArguments.size()) {
1301+
InFlightDiagnostic diag =
1302+
emitError() << "the number of operands (" << operandTypes.size()
1303+
<< ") doesn't match the number of matcher arguments ("
1304+
<< matcherArguments.size() << ") for " << matcher;
1305+
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1306+
return diag;
1307+
}
1308+
for (auto &&[i, operand, argument] :
1309+
llvm::enumerate(operandTypes, matcherArguments)) {
1310+
if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1311+
InFlightDiagnostic diag =
1312+
emitOpError()
1313+
<< "does not expect matcher symbol to consume its operand #" << i;
1314+
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1315+
return diag;
1316+
}
1317+
1318+
if (implementSameTransformInterface(operand, argument))
1319+
continue;
1320+
1321+
InFlightDiagnostic diag =
1322+
emitError()
1323+
<< "mismatching type interfaces for operand and matcher argument #"
1324+
<< i << " of matcher " << matcher;
1325+
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1326+
return diag;
1327+
}
1328+
1329+
// Matcher -> action forwarding.
1330+
TypeRange matcherResults = matcherSymbol.getResultTypes();
1331+
TypeRange actionArguments = actionSymbol.getArgumentTypes();
12551332
if (matcherResults.size() != actionArguments.size()) {
12561333
return emitError() << "mismatching number of matcher results and "
12571334
"action arguments between "
@@ -1265,31 +1342,31 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
12651342

12661343
return emitError() << "mismatching type interfaces for matcher result "
12671344
"and action argument #"
1268-
<< i;
1345+
<< i << "of matcher " << matcher << " and action "
1346+
<< action;
12691347
}
12701348

1271-
if (!actionSymbol.getResultTypes().empty()) {
1349+
// Action -> result forwarding.
1350+
TypeRange actionResults = actionSymbol.getResultTypes();
1351+
auto resultTypes = TypeRange(getResultTypes()).drop_front();
1352+
if (actionResults.size() != resultTypes.size()) {
12721353
InFlightDiagnostic diag =
1273-
emitError() << "action symbol is not expected to have results";
1354+
emitError() << "the number of action results ("
1355+
<< actionResults.size() << ") for " << action
1356+
<< " doesn't match the number of extra op results ("
1357+
<< resultTypes.size() << ")";
12741358
diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
12751359
return diag;
12761360
}
1361+
for (auto &&[i, resultType, actionType] :
1362+
llvm::enumerate(resultTypes, actionResults)) {
1363+
if (implementSameTransformInterface(resultType, actionType))
1364+
continue;
12771365

1278-
if (matcherSymbol.getArgumentTypes().size() != 1 ||
1279-
!implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0],
1280-
getRoot().getType())) {
1281-
InFlightDiagnostic diag =
1282-
emitOpError() << "expects matcher symbol to have one argument with "
1283-
"the same transform interface as the first operand";
1284-
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1285-
return diag;
1286-
}
1287-
1288-
if (matcherSymbol.getArgAttr(0, consumedAttr)) {
12891366
InFlightDiagnostic diag =
1290-
emitOpError()
1291-
<< "does not expect matcher symbol to consume its operand";
1292-
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1367+
emitError() << "mismatching type interfaces for action result #" << i
1368+
<< " of action " << action << " and op result";
1369+
diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
12931370
return diag;
12941371
}
12951372
}

0 commit comments

Comments
 (0)