Skip to content

Commit b750167

Browse files
committed
[mlir] make transform.foreach_match forward arguments
It may be useful to have access to additional handles or parameters when performing matches and actions in `foreach_match`, for example, to parameterize the matcher by rank or restrict it in a non-trivial way. Enable `foreach_match` to forward additional handles from operands to matcher symbols and from action symbols to results.
1 parent 69703b1 commit b750167

File tree

6 files changed

+380
-69
lines changed

6 files changed

+380
-69
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,8 @@ def CollectMatchingOp : TransformDialectOp<"collect_matching", [
512512
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
513513
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
514514
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
515-
DeclareOpInterfaceMethods<TransformOpInterface>]> {
515+
DeclareOpInterfaceMethods<TransformOpInterface,
516+
["allowsRepeatedHandleOperands"]>]> {
516517
let summary = "Applies named sequences when a named matcher succeeds";
517518
let description = [{
518519
Given a pair of co-indexed lists of transform dialect symbols (such as
@@ -528,25 +529,31 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
528529
the following matchers are not applied to the same payload operation. If the
529530
action succeeds, the next payload operation in walk order is matched. If it
530531
fails, both silenceable and definite errors are propagated as the result of
531-
this op.
532-
533-
The matcher symbol must take one operand of a type that implements the same
534-
transform dialect interface as the `root` operand (a check is performed at
535-
application time to see if the associated payload satisfies the constraints
536-
of the actual type). It must not consume the operand as multiple matchers
532+
this op; propagation of silenceable errors is postponed until the end of the
533+
walk.
534+
535+
The matcher symbol must take at least one operand of a type that implements
536+
the same transform dialect interface as the `root` operand (a check is
537+
performed at application time to see if the associated payload satisfies the
538+
constraints of the actual type), and may take additional operands with a
539+
similar type requirement. It must not consume operands as multiple matchers
537540
may be applied. The matcher may produce any number of results. The action
538541
symbol paired with the matcher must take the same number of arguments as the
539542
matcher has results, and these arguments must implement the same transform
540543
dialect interfaces, but not necessarily have the exact same type (again, a
541544
check is performed at application time to see if the associated payload
542-
satisfies the constraints of actual types on both sides). The action symbol
543-
may not have results. The actions are expected to only modify payload
544-
operations nested in the `root` payload operations associated with the
545-
operand of this transform operation. Furhermore, the actions may not modify
546-
operations outside of the currently matched payload operation, e.g., they
547-
may not modify sibling or parent operations. If such behavior is desired,
548-
the parent must be matched first and the nested operations obtained by
549-
traversing the IR from the parent. This is due to the matching being
545+
satisfies the constraints of actual types on both sides).
546+
547+
The action symbol may have results that are accumulated from all actions and
548+
returned from the `foreach_match` operation on success. Unless the
549+
`flatten_results` attribute is present, each action result must be
550+
associated with exactly one payload entity. The actions are expected to only
551+
modify payload operations nested in the `root` payload operations associated
552+
with the operand of this transform operation. Furthermore, the actions may
553+
not modify operations outside of the currently matched payload operation,
554+
e.g., they may not modify sibling or parent operations. If such behavior is
555+
desired, the parent must be matched first and the nested operations obtained
556+
by traversing the IR from the parent. This is due to the matching being
550557
performed as a post-order IR walk.
551558

552559
This operation consumes the operand and produces a new handle associated
@@ -573,19 +580,26 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
573580
produced a definite failure.
574581
}];
575582

576-
let arguments = (ins TransformHandleTypeInterface:$root,
577-
UnitAttr:$restrict_root,
578-
SymbolRefArrayAttr:$matchers,
579-
SymbolRefArrayAttr:$actions);
580-
let results = (outs TransformHandleTypeInterface:$updated);
583+
let arguments =
584+
(ins TransformHandleTypeInterface:$root,
585+
Variadic<Transform_AnyHandleOrParamType>:$forwarded_inputs,
586+
UnitAttr:$restrict_root,
587+
UnitAttr:$flatten_results,
588+
SymbolRefArrayAttr:$matchers,
589+
SymbolRefArrayAttr:$actions);
590+
let results =
591+
(outs TransformHandleTypeInterface:$updated,
592+
Variadic<Transform_AnyHandleOrParamType>:$forwarded_outputs);
581593

582594
let assemblyFormat = [{
583-
(`restrict_root` $restrict_root^)?
595+
oilist( `restrict_root` $restrict_root
596+
| `flatten_results` $flatten_results
597+
)
584598
`in`
585-
$root
599+
$root (`,` $forwarded_inputs^)?
586600
custom<ForeachMatchSymbols>($matchers, $actions)
587601
attr-dict
588-
`:` functional-type($root, $updated)
602+
`:` functional-type(operands, results)
589603
}];
590604

591605
let hasVerifier = 1;

mlir/include/mlir/Dialect/Transform/Interfaces/TransformInterfaces.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,17 @@ void getPotentialTopLevelEffects(
5252
/// Verification hook for TransformOpInterface.
5353
LogicalResult verifyTransformOpInterface(Operation *op);
5454

55+
/// Appends the entities associated with the given transform. values in `state`
56+
/// to the pre-existing list of mappings. The array of mappings must have as
57+
/// many elements as values. If `flatten` is set, multiple values may be
58+
/// associated with each transform value, and this always succeeds. Otherwise,
59+
/// checks that each value has exactly one mapping associated and return failure
60+
/// otherwise.
61+
LogicalResult appendValueMappings(
62+
MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
63+
ValueRange values, const transform::TransformState &state,
64+
bool flatten = true);
65+
5566
/// Populates `mappings` with mapped values associated with the given transform
5667
/// IR values in the given `state`.
5768
void prepareValueMappings(
@@ -317,6 +328,8 @@ class TransformState {
317328
}
318329
LogicalResult mapBlockArgument(BlockArgument argument,
319330
ArrayRef<MappedValue> values);
331+
LogicalResult mapBlockArguments(Block::BlockArgListType arguments,
332+
ArrayRef<SmallVector<MappedValue>> mapping);
320333

321334
// Forward declarations to support limited visibility.
322335
class RegionScope;

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 110 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -834,19 +834,23 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
834834
// CollectMatchingOp
835835
//===----------------------------------------------------------------------===//
836836

837-
/// Applies matcher operations from the given `block` assigning `op` as the
838-
/// payload of the block's first argument. Updates `state` accordingly. If any
839-
/// of the matcher produces a silenceable failure, discards it (printing the
840-
/// content to the debug output stream) and returns failure. If any of the
841-
/// matchers produces a definite failure, reports it and returns failure. If all
842-
/// matchers in the block succeed, populates `mappings` with the payload
843-
/// entities associated with the block terminator operands.
837+
/// Applies matcher operations from the given `block` using
838+
/// `blockArgumentMapping` to initialize block arguments. Updates `state`
839+
/// accordingly. If any of the matcher produces a silenceable failure, discards
840+
/// it (printing the content to the debug output stream) and returns failure. If
841+
/// any of the matchers produces a definite failure, reports it and returns
842+
/// failure. If all matchers in the block succeed, populates `mappings` with the
843+
/// payload entities associated with the block terminator operands. Note that
844+
/// `mappings` will be cleared before that.
844845
static DiagnosedSilenceableFailure
845-
matchBlock(Block &block, Operation *op, transform::TransformState &state,
846+
matchBlock(Block &block,
847+
ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
848+
transform::TransformState &state,
846849
SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
847850
assert(block.getParent() && "cannot match using a detached block");
848851
auto matchScope = state.make_region_scope(*block.getParent());
849-
if (failed(state.mapBlockArgument(block.getArgument(0), {op})))
852+
if (failed(
853+
state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
850854
return DiagnosedSilenceableFailure::definiteFailure();
851855

852856
for (Operation &match : block.without_terminator()) {
@@ -866,6 +870,7 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
866870
// Remember the values mapped to the terminator operands so we can
867871
// forward them to the action.
868872
ValueRange yieldedValues = block.getTerminator()->getOperands();
873+
mappings.clear();
869874
transform::detail::prepareValueMappings(mappings, yieldedValues, state);
870875
return DiagnosedSilenceableFailure::success();
871876
}
@@ -915,8 +920,11 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
915920

916921
// Try matching.
917922
SmallVector<SmallVector<MappedValue>> mappings;
918-
DiagnosedSilenceableFailure diag =
919-
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
923+
SmallVector<transform::MappedValue> inputMapping({op});
924+
DiagnosedSilenceableFailure diag = matchBlock(
925+
matcher.getFunctionBody().front(),
926+
ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
927+
mappings);
920928
if (diag.isDefiniteFailure())
921929
return WalkResult::interrupt();
922930
if (diag.isSilenceableFailure()) {
@@ -1001,6 +1009,9 @@ LogicalResult transform::CollectMatchingOp::verifySymbolUses(
10011009
// ForeachMatchOp
10021010
//===----------------------------------------------------------------------===//
10031011

1012+
// This is fine because nothing is actually consumed by this op.
1013+
bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }
1014+
10041015
DiagnosedSilenceableFailure
10051016
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10061017
transform::TransformResults &results,
@@ -1030,6 +1041,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10301041

10311042
DiagnosedSilenceableFailure overallDiag =
10321043
DiagnosedSilenceableFailure::success();
1044+
1045+
SmallVector<SmallVector<MappedValue>> matchInputMapping;
1046+
SmallVector<SmallVector<MappedValue>> matchOutputMapping;
1047+
SmallVector<SmallVector<MappedValue>> actionResultMapping;
1048+
matchInputMapping.emplace_back();
1049+
transform::detail::prepareValueMappings(matchInputMapping,
1050+
getForwardedInputs(), state);
1051+
SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
1052+
actionResultMapping.resize(getForwardedOutputs().size());
1053+
10331054
for (Operation *root : state.getPayloadOps(getRoot())) {
10341055
WalkResult walkResult = root->walk([&](Operation *op) {
10351056
// If getRestrictRoot is not present, skip over the root op itself so we
@@ -1044,11 +1065,14 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10441065
llvm::dbgs() << " @" << op << "\n";
10451066
});
10461067

1068+
firstMatchArgument.clear();
1069+
firstMatchArgument.push_back(op);
1070+
10471071
// Try all the match/action pairs until the first successful match.
10481072
for (auto [matcher, action] : matchActionPairs) {
1049-
SmallVector<SmallVector<MappedValue>> mappings;
10501073
DiagnosedSilenceableFailure diag =
1051-
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
1074+
matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
1075+
state, matchOutputMapping);
10521076
if (diag.isDefiniteFailure())
10531077
return WalkResult::interrupt();
10541078
if (diag.isSilenceableFailure()) {
@@ -1058,10 +1082,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10581082
}
10591083

10601084
auto scope = state.make_region_scope(action.getFunctionBody());
1061-
for (auto &&[arg, map] : llvm::zip_equal(
1062-
action.getFunctionBody().front().getArguments(), mappings)) {
1063-
if (failed(state.mapBlockArgument(arg, map)))
1064-
return WalkResult::interrupt();
1085+
if (failed(state.mapBlockArguments(
1086+
action.getFunctionBody().front().getArguments(),
1087+
matchOutputMapping))) {
1088+
return WalkResult::interrupt();
10651089
}
10661090

10671091
for (Operation &transform :
@@ -1082,6 +1106,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10821106
continue;
10831107
}
10841108
}
1109+
if (failed(detail::appendValueMappings(
1110+
MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
1111+
action.getFunctionBody().front().getTerminator()->getOperands(),
1112+
state, getFlattenResults()))) {
1113+
emitDefiniteFailure()
1114+
<< "action @" << action.getName()
1115+
<< " has results associated with multiple payload entities, "
1116+
"but flattening was not requested";
1117+
return WalkResult::interrupt();
1118+
}
10851119
break;
10861120
}
10871121
return WalkResult::advance();
@@ -1096,6 +1130,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
10961130
// by actions, are invalidated.
10971131
results.set(llvm::cast<OpResult>(getUpdated()),
10981132
state.getPayloadOps(getRoot()));
1133+
for (auto &&[result, mapping] :
1134+
llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
1135+
results.setMappedValues(result, mapping);
1136+
}
10991137
return overallDiag;
11001138
}
11011139

@@ -1108,7 +1146,8 @@ void transform::ForeachMatchOp::getEffects(
11081146
}
11091147

11101148
consumesHandle(getRoot(), effects);
1111-
producesHandle(getUpdated(), effects);
1149+
onlyReadsHandle(getForwardedInputs(), effects);
1150+
producesHandle(getResults(), effects);
11121151
modifiesPayload(effects);
11131152
}
11141153

@@ -1224,6 +1263,7 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
12241263
StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
12251264
for (auto &&[matcher, action] :
12261265
llvm::zip_equal(getMatchers(), getActions())) {
1266+
// Presence and typing.
12271267
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
12281268
symbolTable.lookupNearestSymbolFrom(getOperation(),
12291269
cast<SymbolRefAttr>(matcher)));
@@ -1250,8 +1290,41 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
12501290
return failure();
12511291
}
12521292

1253-
ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes();
1254-
ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes();
1293+
// Input -> matcher forwarding.
1294+
TypeRange operandTypes = getOperandTypes();
1295+
TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
1296+
if (operandTypes.size() != matcherArguments.size()) {
1297+
InFlightDiagnostic diag =
1298+
emitError() << "the number of operands (" << operandTypes.size()
1299+
<< ") doesn't match the number of matcher arguments ("
1300+
<< matcherArguments.size() << ") for " << matcher;
1301+
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1302+
return diag;
1303+
}
1304+
for (auto &&[i, operand, argument] :
1305+
llvm::enumerate(operandTypes, matcherArguments)) {
1306+
if (matcherSymbol.getArgAttr(i, consumedAttr)) {
1307+
InFlightDiagnostic diag =
1308+
emitOpError()
1309+
<< "does not expect matcher symbol to consume its operand #" << i;
1310+
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1311+
return diag;
1312+
}
1313+
1314+
if (implementSameTransformInterface(operand, argument))
1315+
continue;
1316+
1317+
InFlightDiagnostic diag =
1318+
emitError()
1319+
<< "mismatching type interfaces for operand and matcher argument #"
1320+
<< i << " of matcher " << matcher;
1321+
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1322+
return diag;
1323+
}
1324+
1325+
// Matcher -> action forwarding.
1326+
TypeRange matcherResults = matcherSymbol.getResultTypes();
1327+
TypeRange actionArguments = actionSymbol.getArgumentTypes();
12551328
if (matcherResults.size() != actionArguments.size()) {
12561329
return emitError() << "mismatching number of matcher results and "
12571330
"action arguments between "
@@ -1265,31 +1338,31 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
12651338

12661339
return emitError() << "mismatching type interfaces for matcher result "
12671340
"and action argument #"
1268-
<< i;
1341+
<< i << "of matcher " << matcher << " and action "
1342+
<< action;
12691343
}
12701344

1271-
if (!actionSymbol.getResultTypes().empty()) {
1345+
// Action -> result forwarding.
1346+
TypeRange actionResults = actionSymbol.getResultTypes();
1347+
auto resultTypes = TypeRange(getResultTypes()).drop_front();
1348+
if (actionResults.size() != resultTypes.size()) {
12721349
InFlightDiagnostic diag =
1273-
emitError() << "action symbol is not expected to have results";
1350+
emitError() << "the number of action results ("
1351+
<< actionResults.size() << ") for " << action
1352+
<< " doesn't match the number of extra op results ("
1353+
<< resultTypes.size() << ")";
12741354
diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
12751355
return diag;
12761356
}
1357+
for (auto &&[i, resultType, actionType] :
1358+
llvm::enumerate(resultTypes, actionResults)) {
1359+
if (implementSameTransformInterface(resultType, actionType))
1360+
continue;
12771361

1278-
if (matcherSymbol.getArgumentTypes().size() != 1 ||
1279-
!implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0],
1280-
getRoot().getType())) {
1281-
InFlightDiagnostic diag =
1282-
emitOpError() << "expects matcher symbol to have one argument with "
1283-
"the same transform interface as the first operand";
1284-
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1285-
return diag;
1286-
}
1287-
1288-
if (matcherSymbol.getArgAttr(0, consumedAttr)) {
12891362
InFlightDiagnostic diag =
1290-
emitOpError()
1291-
<< "does not expect matcher symbol to consume its operand";
1292-
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
1363+
emitError() << "mismatching type interfaces for action result #" << i
1364+
<< " of action " << action << " and op result";
1365+
diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
12931366
return diag;
12941367
}
12951368
}

0 commit comments

Comments
 (0)