Skip to content

[mlir] make transform.foreach_match forward arguments #89920

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 39 additions & 23 deletions mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,10 @@ def CollectMatchingOp : TransformDialectOp<"collect_matching", [
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
DeclareOpInterfaceMethods<TransformOpInterface,
["allowsRepeatedHandleOperands"]>,
DeclareOpInterfaceMethods<OpAsmOpInterface,
["getAsmResultNames"]>]> {
let summary = "Applies named sequences when a named matcher succeeds";
let description = [{
Given a pair of co-indexed lists of transform dialect symbols (such as
Expand All @@ -528,25 +531,31 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
the following matchers are not applied to the same payload operation. If the
action succeeds, the next payload operation in walk order is matched. If it
fails, both silenceable and definite errors are propagated as the result of
this op.

The matcher symbol must take one operand of a type that implements the same
transform dialect interface as the `root` operand (a check is performed at
application time to see if the associated payload satisfies the constraints
of the actual type). It must not consume the operand as multiple matchers
this op; propagation of silenceable errors is postponed until the end of the
walk.

The matcher symbol must take at least one operand of a type that implements
the same transform dialect interface as the `root` operand (a check is
performed at application time to see if the associated payload satisfies the
constraints of the actual type), and may take additional operands with a
similar type requirement. It must not consume operands as multiple matchers
may be applied. The matcher may produce any number of results. The action
symbol paired with the matcher must take the same number of arguments as the
matcher has results, and these arguments must implement the same transform
dialect interfaces, but not necessarily have the exact same type (again, a
check is performed at application time to see if the associated payload
satisfies the constraints of actual types on both sides). The action symbol
may not have results. The actions are expected to only modify payload
operations nested in the `root` payload operations associated with the
operand of this transform operation. Furhermore, the actions may not modify
operations outside of the currently matched payload operation, e.g., they
may not modify sibling or parent operations. If such behavior is desired,
the parent must be matched first and the nested operations obtained by
traversing the IR from the parent. This is due to the matching being
satisfies the constraints of actual types on both sides).

The action symbol may have results that are accumulated from all actions and
returned from the `foreach_match` operation on success. Unless the
`flatten_results` attribute is present, each action result must be
associated with exactly one payload entity. The actions are expected to only
modify payload operations nested in the `root` payload operations associated
with the operand of this transform operation. Furthermore, the actions may
not modify operations outside of the currently matched payload operation,
e.g., they may not modify sibling or parent operations. If such behavior is
desired, the parent must be matched first and the nested operations obtained
by traversing the IR from the parent. This is due to the matching being
performed as a post-order IR walk.

This operation consumes the operand and produces a new handle associated
Expand All @@ -573,19 +582,26 @@ def ForeachMatchOp : TransformDialectOp<"foreach_match", [
produced a definite failure.
}];

let arguments = (ins TransformHandleTypeInterface:$root,
UnitAttr:$restrict_root,
SymbolRefArrayAttr:$matchers,
SymbolRefArrayAttr:$actions);
let results = (outs TransformHandleTypeInterface:$updated);
let arguments =
(ins TransformHandleTypeInterface:$root,
Variadic<Transform_AnyHandleOrParamType>:$forwarded_inputs,
UnitAttr:$restrict_root,
UnitAttr:$flatten_results,
SymbolRefArrayAttr:$matchers,
SymbolRefArrayAttr:$actions);
let results =
(outs TransformHandleTypeInterface:$updated,
Variadic<Transform_AnyHandleOrParamType>:$forwarded_outputs);

let assemblyFormat = [{
(`restrict_root` $restrict_root^)?
oilist( `restrict_root` $restrict_root
| `flatten_results` $flatten_results
)
`in`
$root
$root (`,` $forwarded_inputs^)?
custom<ForeachMatchSymbols>($matchers, $actions)
attr-dict
`:` functional-type($root, $updated)
`:` functional-type(operands, results)
}];

let hasVerifier = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ void getPotentialTopLevelEffects(
/// Verification hook for TransformOpInterface.
LogicalResult verifyTransformOpInterface(Operation *op);

/// Appends the entities associated with the given transform values in `state`
/// to the pre-existing list of mappings. The array of mappings must have as
/// many elements as values. If `flatten` is set, multiple values may be
/// associated with each transform value, and this always succeeds. Otherwise,
/// checks that each value has exactly one mapping associated and return failure
/// otherwise.
LogicalResult appendValueMappings(
MutableArrayRef<SmallVector<transform::MappedValue>> mappings,
ValueRange values, const transform::TransformState &state,
bool flatten = true);

/// Populates `mappings` with mapped values associated with the given transform
/// IR values in the given `state`.
void prepareValueMappings(
Expand Down Expand Up @@ -317,6 +328,8 @@ class TransformState {
}
LogicalResult mapBlockArgument(BlockArgument argument,
ArrayRef<MappedValue> values);
LogicalResult mapBlockArguments(Block::BlockArgListType arguments,
ArrayRef<SmallVector<MappedValue>> mapping);

// Forward declarations to support limited visibility.
class RegionScope;
Expand Down
160 changes: 123 additions & 37 deletions mlir/lib/Dialect/Transform/IR/TransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Verifier.h"
Expand Down Expand Up @@ -834,19 +835,23 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
// CollectMatchingOp
//===----------------------------------------------------------------------===//

/// Applies matcher operations from the given `block` assigning `op` as the
/// payload of the block's first argument. Updates `state` accordingly. If any
/// of the matcher produces a silenceable failure, discards it (printing the
/// content to the debug output stream) and returns failure. If any of the
/// matchers produces a definite failure, reports it and returns failure. If all
/// matchers in the block succeed, populates `mappings` with the payload
/// entities associated with the block terminator operands.
/// Applies matcher operations from the given `block` using
/// `blockArgumentMapping` to initialize block arguments. Updates `state`
/// accordingly. If any of the matcher produces a silenceable failure, discards
/// it (printing the content to the debug output stream) and returns failure. If
/// any of the matchers produces a definite failure, reports it and returns
/// failure. If all matchers in the block succeed, populates `mappings` with the
/// payload entities associated with the block terminator operands. Note that
/// `mappings` will be cleared before that.
static DiagnosedSilenceableFailure
matchBlock(Block &block, Operation *op, transform::TransformState &state,
matchBlock(Block &block,
ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping,
transform::TransformState &state,
SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) {
assert(block.getParent() && "cannot match using a detached block");
auto matchScope = state.make_region_scope(*block.getParent());
if (failed(state.mapBlockArgument(block.getArgument(0), {op})))
if (failed(
state.mapBlockArguments(block.getArguments(), blockArgumentMapping)))
return DiagnosedSilenceableFailure::definiteFailure();

for (Operation &match : block.without_terminator()) {
Expand All @@ -866,6 +871,9 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
// Remember the values mapped to the terminator operands so we can
// forward them to the action.
ValueRange yieldedValues = block.getTerminator()->getOperands();
// Our contract with the caller is that the mappings will contain only the
// newly mapped values, clear the rest.
mappings.clear();
transform::detail::prepareValueMappings(mappings, yieldedValues, state);
return DiagnosedSilenceableFailure::success();
}
Expand Down Expand Up @@ -915,8 +923,11 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,

// Try matching.
SmallVector<SmallVector<MappedValue>> mappings;
DiagnosedSilenceableFailure diag =
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
SmallVector<transform::MappedValue> inputMapping({op});
DiagnosedSilenceableFailure diag = matchBlock(
matcher.getFunctionBody().front(),
ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state,
mappings);
if (diag.isDefiniteFailure())
return WalkResult::interrupt();
if (diag.isSilenceableFailure()) {
Expand Down Expand Up @@ -1001,6 +1012,9 @@ LogicalResult transform::CollectMatchingOp::verifySymbolUses(
// ForeachMatchOp
//===----------------------------------------------------------------------===//

// This is fine because nothing is actually consumed by this op.
bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; }

DiagnosedSilenceableFailure
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
Expand Down Expand Up @@ -1030,6 +1044,18 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,

DiagnosedSilenceableFailure overallDiag =
DiagnosedSilenceableFailure::success();

SmallVector<SmallVector<MappedValue>> matchInputMapping;
SmallVector<SmallVector<MappedValue>> matchOutputMapping;
SmallVector<SmallVector<MappedValue>> actionResultMapping;
// Explicitly add the mapping for the first block argument (the op being
// matched).
matchInputMapping.emplace_back();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we need this outide of the call to prepareValueMappings given that prepareValueMappings immediately resizes?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because we also need to map the op that is being matched before the payloads that are forwarded. Added a comment.

transform::detail::prepareValueMappings(matchInputMapping,
getForwardedInputs(), state);
SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front();
actionResultMapping.resize(getForwardedOutputs().size());

for (Operation *root : state.getPayloadOps(getRoot())) {
WalkResult walkResult = root->walk([&](Operation *op) {
// If getRestrictRoot is not present, skip over the root op itself so we
Expand All @@ -1044,11 +1070,14 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
llvm::dbgs() << " @" << op << "\n";
});

firstMatchArgument.clear();
firstMatchArgument.push_back(op);

// Try all the match/action pairs until the first successful match.
for (auto [matcher, action] : matchActionPairs) {
SmallVector<SmallVector<MappedValue>> mappings;
DiagnosedSilenceableFailure diag =
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
matchBlock(matcher.getFunctionBody().front(), matchInputMapping,
state, matchOutputMapping);
if (diag.isDefiniteFailure())
return WalkResult::interrupt();
if (diag.isSilenceableFailure()) {
Expand All @@ -1058,10 +1087,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
}

auto scope = state.make_region_scope(action.getFunctionBody());
for (auto &&[arg, map] : llvm::zip_equal(
action.getFunctionBody().front().getArguments(), mappings)) {
if (failed(state.mapBlockArgument(arg, map)))
return WalkResult::interrupt();
if (failed(state.mapBlockArguments(
action.getFunctionBody().front().getArguments(),
matchOutputMapping))) {
return WalkResult::interrupt();
}

for (Operation &transform :
Expand All @@ -1082,6 +1111,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
continue;
}
}
if (failed(detail::appendValueMappings(
MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping),
action.getFunctionBody().front().getTerminator()->getOperands(),
state, getFlattenResults()))) {
emitDefiniteFailure()
<< "action @" << action.getName()
<< " has results associated with multiple payload entities, "
"but flattening was not requested";
return WalkResult::interrupt();
}
break;
}
return WalkResult::advance();
Expand All @@ -1096,9 +1135,21 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
// by actions, are invalidated.
results.set(llvm::cast<OpResult>(getUpdated()),
state.getPayloadOps(getRoot()));
for (auto &&[result, mapping] :
llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) {
results.setMappedValues(result, mapping);
}
return overallDiag;
}

void transform::ForeachMatchOp::getAsmResultNames(
OpAsmSetValueNameFn setNameFn) {
setNameFn(getUpdated(), "updated_root");
for (Value v : getForwardedOutputs()) {
setNameFn(v, "yielded");
}
}

void transform::ForeachMatchOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
// Bail if invalid.
Expand All @@ -1108,7 +1159,8 @@ void transform::ForeachMatchOp::getEffects(
}

consumesHandle(getRoot(), effects);
producesHandle(getUpdated(), effects);
onlyReadsHandle(getForwardedInputs(), effects);
producesHandle(getResults(), effects);
modifiesPayload(effects);
}

Expand Down Expand Up @@ -1224,6 +1276,7 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName);
for (auto &&[matcher, action] :
llvm::zip_equal(getMatchers(), getActions())) {
// Presence and typing.
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
symbolTable.lookupNearestSymbolFrom(getOperation(),
cast<SymbolRefAttr>(matcher)));
Expand All @@ -1250,8 +1303,41 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(
return failure();
}

ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes();
ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes();
// Input -> matcher forwarding.
TypeRange operandTypes = getOperandTypes();
TypeRange matcherArguments = matcherSymbol.getArgumentTypes();
if (operandTypes.size() != matcherArguments.size()) {
InFlightDiagnostic diag =
emitError() << "the number of operands (" << operandTypes.size()
<< ") doesn't match the number of matcher arguments ("
<< matcherArguments.size() << ") for " << matcher;
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
return diag;
}
for (auto &&[i, operand, argument] :
llvm::enumerate(operandTypes, matcherArguments)) {
if (matcherSymbol.getArgAttr(i, consumedAttr)) {
InFlightDiagnostic diag =
emitOpError()
<< "does not expect matcher symbol to consume its operand #" << i;
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
return diag;
}

if (implementSameTransformInterface(operand, argument))
continue;

InFlightDiagnostic diag =
emitError()
<< "mismatching type interfaces for operand and matcher argument #"
<< i << " of matcher " << matcher;
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
return diag;
}

// Matcher -> action forwarding.
TypeRange matcherResults = matcherSymbol.getResultTypes();
TypeRange actionArguments = actionSymbol.getArgumentTypes();
if (matcherResults.size() != actionArguments.size()) {
return emitError() << "mismatching number of matcher results and "
"action arguments between "
Expand All @@ -1265,31 +1351,31 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses(

return emitError() << "mismatching type interfaces for matcher result "
"and action argument #"
<< i;
<< i << "of matcher " << matcher << " and action "
<< action;
}

if (!actionSymbol.getResultTypes().empty()) {
// Action -> result forwarding.
TypeRange actionResults = actionSymbol.getResultTypes();
auto resultTypes = TypeRange(getResultTypes()).drop_front();
if (actionResults.size() != resultTypes.size()) {
InFlightDiagnostic diag =
emitError() << "action symbol is not expected to have results";
emitError() << "the number of action results ("
<< actionResults.size() << ") for " << action
<< " doesn't match the number of extra op results ("
<< resultTypes.size() << ")";
diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
return diag;
}
for (auto &&[i, resultType, actionType] :
llvm::enumerate(resultTypes, actionResults)) {
if (implementSameTransformInterface(resultType, actionType))
continue;

if (matcherSymbol.getArgumentTypes().size() != 1 ||
!implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0],
getRoot().getType())) {
InFlightDiagnostic diag =
emitOpError() << "expects matcher symbol to have one argument with "
"the same transform interface as the first operand";
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
return diag;
}

if (matcherSymbol.getArgAttr(0, consumedAttr)) {
InFlightDiagnostic diag =
emitOpError()
<< "does not expect matcher symbol to consume its operand";
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration";
emitError() << "mismatching type interfaces for action result #" << i
<< " of action " << action << " and op result";
diag.attachNote(actionSymbol->getLoc()) << "symbol declaration";
return diag;
}
}
Expand Down
Loading