-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] make transform.foreach_match forward arguments #89920
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,6 +19,7 @@ | |
#include "mlir/IR/BuiltinAttributes.h" | ||
#include "mlir/IR/Diagnostics.h" | ||
#include "mlir/IR/Dominance.h" | ||
#include "mlir/IR/OpImplementation.h" | ||
#include "mlir/IR/OperationSupport.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/IR/Verifier.h" | ||
|
@@ -834,19 +835,23 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { | |
// CollectMatchingOp | ||
//===----------------------------------------------------------------------===// | ||
|
||
/// Applies matcher operations from the given `block` assigning `op` as the | ||
/// payload of the block's first argument. Updates `state` accordingly. If any | ||
/// of the matcher produces a silenceable failure, discards it (printing the | ||
/// content to the debug output stream) and returns failure. If any of the | ||
/// matchers produces a definite failure, reports it and returns failure. If all | ||
/// matchers in the block succeed, populates `mappings` with the payload | ||
/// entities associated with the block terminator operands. | ||
/// Applies matcher operations from the given `block` using | ||
/// `blockArgumentMapping` to initialize block arguments. Updates `state` | ||
/// accordingly. If any of the matcher produces a silenceable failure, discards | ||
/// it (printing the content to the debug output stream) and returns failure. If | ||
/// any of the matchers produces a definite failure, reports it and returns | ||
/// failure. If all matchers in the block succeed, populates `mappings` with the | ||
/// payload entities associated with the block terminator operands. Note that | ||
/// `mappings` will be cleared before that. | ||
static DiagnosedSilenceableFailure | ||
matchBlock(Block &block, Operation *op, transform::TransformState &state, | ||
matchBlock(Block &block, | ||
ArrayRef<SmallVector<transform::MappedValue>> blockArgumentMapping, | ||
transform::TransformState &state, | ||
SmallVectorImpl<SmallVector<transform::MappedValue>> &mappings) { | ||
assert(block.getParent() && "cannot match using a detached block"); | ||
auto matchScope = state.make_region_scope(*block.getParent()); | ||
if (failed(state.mapBlockArgument(block.getArgument(0), {op}))) | ||
if (failed( | ||
state.mapBlockArguments(block.getArguments(), blockArgumentMapping))) | ||
return DiagnosedSilenceableFailure::definiteFailure(); | ||
|
||
for (Operation &match : block.without_terminator()) { | ||
|
@@ -866,6 +871,9 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state, | |
// Remember the values mapped to the terminator operands so we can | ||
// forward them to the action. | ||
ValueRange yieldedValues = block.getTerminator()->getOperands(); | ||
// Our contract with the caller is that the mappings will contain only the | ||
// newly mapped values, clear the rest. | ||
mappings.clear(); | ||
transform::detail::prepareValueMappings(mappings, yieldedValues, state); | ||
return DiagnosedSilenceableFailure::success(); | ||
} | ||
|
@@ -915,8 +923,11 @@ transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, | |
|
||
// Try matching. | ||
SmallVector<SmallVector<MappedValue>> mappings; | ||
DiagnosedSilenceableFailure diag = | ||
matchBlock(matcher.getFunctionBody().front(), op, state, mappings); | ||
SmallVector<transform::MappedValue> inputMapping({op}); | ||
DiagnosedSilenceableFailure diag = matchBlock( | ||
matcher.getFunctionBody().front(), | ||
ArrayRef<SmallVector<transform::MappedValue>>(inputMapping), state, | ||
mappings); | ||
if (diag.isDefiniteFailure()) | ||
return WalkResult::interrupt(); | ||
if (diag.isSilenceableFailure()) { | ||
|
@@ -1001,6 +1012,9 @@ LogicalResult transform::CollectMatchingOp::verifySymbolUses( | |
// ForeachMatchOp | ||
//===----------------------------------------------------------------------===// | ||
|
||
// This is fine because nothing is actually consumed by this op. | ||
bool transform::ForeachMatchOp::allowsRepeatedHandleOperands() { return true; } | ||
|
||
DiagnosedSilenceableFailure | ||
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, | ||
transform::TransformResults &results, | ||
|
@@ -1030,6 +1044,18 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, | |
|
||
DiagnosedSilenceableFailure overallDiag = | ||
DiagnosedSilenceableFailure::success(); | ||
|
||
SmallVector<SmallVector<MappedValue>> matchInputMapping; | ||
SmallVector<SmallVector<MappedValue>> matchOutputMapping; | ||
SmallVector<SmallVector<MappedValue>> actionResultMapping; | ||
// Explicitly add the mapping for the first block argument (the op being | ||
// matched). | ||
matchInputMapping.emplace_back(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do we need this outide of the call to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, because we also need to map the op that is being matched before the payloads that are forwarded. Added a comment. |
||
transform::detail::prepareValueMappings(matchInputMapping, | ||
getForwardedInputs(), state); | ||
SmallVector<MappedValue> &firstMatchArgument = matchInputMapping.front(); | ||
actionResultMapping.resize(getForwardedOutputs().size()); | ||
|
||
for (Operation *root : state.getPayloadOps(getRoot())) { | ||
WalkResult walkResult = root->walk([&](Operation *op) { | ||
// If getRestrictRoot is not present, skip over the root op itself so we | ||
|
@@ -1044,11 +1070,14 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, | |
llvm::dbgs() << " @" << op << "\n"; | ||
}); | ||
|
||
firstMatchArgument.clear(); | ||
firstMatchArgument.push_back(op); | ||
|
||
// Try all the match/action pairs until the first successful match. | ||
for (auto [matcher, action] : matchActionPairs) { | ||
SmallVector<SmallVector<MappedValue>> mappings; | ||
DiagnosedSilenceableFailure diag = | ||
matchBlock(matcher.getFunctionBody().front(), op, state, mappings); | ||
matchBlock(matcher.getFunctionBody().front(), matchInputMapping, | ||
state, matchOutputMapping); | ||
if (diag.isDefiniteFailure()) | ||
return WalkResult::interrupt(); | ||
if (diag.isSilenceableFailure()) { | ||
|
@@ -1058,10 +1087,10 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, | |
} | ||
|
||
auto scope = state.make_region_scope(action.getFunctionBody()); | ||
for (auto &&[arg, map] : llvm::zip_equal( | ||
action.getFunctionBody().front().getArguments(), mappings)) { | ||
if (failed(state.mapBlockArgument(arg, map))) | ||
return WalkResult::interrupt(); | ||
if (failed(state.mapBlockArguments( | ||
action.getFunctionBody().front().getArguments(), | ||
matchOutputMapping))) { | ||
return WalkResult::interrupt(); | ||
} | ||
|
||
for (Operation &transform : | ||
|
@@ -1082,6 +1111,16 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, | |
continue; | ||
} | ||
} | ||
if (failed(detail::appendValueMappings( | ||
MutableArrayRef<SmallVector<MappedValue>>(actionResultMapping), | ||
action.getFunctionBody().front().getTerminator()->getOperands(), | ||
state, getFlattenResults()))) { | ||
emitDefiniteFailure() | ||
<< "action @" << action.getName() | ||
<< " has results associated with multiple payload entities, " | ||
"but flattening was not requested"; | ||
return WalkResult::interrupt(); | ||
} | ||
break; | ||
} | ||
return WalkResult::advance(); | ||
|
@@ -1096,9 +1135,21 @@ transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, | |
// by actions, are invalidated. | ||
results.set(llvm::cast<OpResult>(getUpdated()), | ||
state.getPayloadOps(getRoot())); | ||
for (auto &&[result, mapping] : | ||
llvm::zip_equal(getForwardedOutputs(), actionResultMapping)) { | ||
results.setMappedValues(result, mapping); | ||
} | ||
return overallDiag; | ||
} | ||
|
||
void transform::ForeachMatchOp::getAsmResultNames( | ||
OpAsmSetValueNameFn setNameFn) { | ||
setNameFn(getUpdated(), "updated_root"); | ||
for (Value v : getForwardedOutputs()) { | ||
setNameFn(v, "yielded"); | ||
} | ||
} | ||
|
||
void transform::ForeachMatchOp::getEffects( | ||
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { | ||
// Bail if invalid. | ||
|
@@ -1108,7 +1159,8 @@ void transform::ForeachMatchOp::getEffects( | |
} | ||
|
||
consumesHandle(getRoot(), effects); | ||
producesHandle(getUpdated(), effects); | ||
onlyReadsHandle(getForwardedInputs(), effects); | ||
producesHandle(getResults(), effects); | ||
modifiesPayload(effects); | ||
} | ||
|
||
|
@@ -1224,6 +1276,7 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses( | |
StringAttr::get(getContext(), TransformDialect::kArgConsumedAttrName); | ||
for (auto &&[matcher, action] : | ||
llvm::zip_equal(getMatchers(), getActions())) { | ||
// Presence and typing. | ||
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( | ||
symbolTable.lookupNearestSymbolFrom(getOperation(), | ||
cast<SymbolRefAttr>(matcher))); | ||
|
@@ -1250,8 +1303,41 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses( | |
return failure(); | ||
} | ||
|
||
ArrayRef<Type> matcherResults = matcherSymbol.getResultTypes(); | ||
ArrayRef<Type> actionArguments = actionSymbol.getArgumentTypes(); | ||
// Input -> matcher forwarding. | ||
TypeRange operandTypes = getOperandTypes(); | ||
TypeRange matcherArguments = matcherSymbol.getArgumentTypes(); | ||
if (operandTypes.size() != matcherArguments.size()) { | ||
InFlightDiagnostic diag = | ||
emitError() << "the number of operands (" << operandTypes.size() | ||
<< ") doesn't match the number of matcher arguments (" | ||
<< matcherArguments.size() << ") for " << matcher; | ||
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; | ||
return diag; | ||
} | ||
for (auto &&[i, operand, argument] : | ||
llvm::enumerate(operandTypes, matcherArguments)) { | ||
if (matcherSymbol.getArgAttr(i, consumedAttr)) { | ||
InFlightDiagnostic diag = | ||
emitOpError() | ||
<< "does not expect matcher symbol to consume its operand #" << i; | ||
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; | ||
return diag; | ||
} | ||
|
||
if (implementSameTransformInterface(operand, argument)) | ||
continue; | ||
|
||
InFlightDiagnostic diag = | ||
emitError() | ||
<< "mismatching type interfaces for operand and matcher argument #" | ||
<< i << " of matcher " << matcher; | ||
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; | ||
return diag; | ||
} | ||
|
||
// Matcher -> action forwarding. | ||
TypeRange matcherResults = matcherSymbol.getResultTypes(); | ||
TypeRange actionArguments = actionSymbol.getArgumentTypes(); | ||
if (matcherResults.size() != actionArguments.size()) { | ||
return emitError() << "mismatching number of matcher results and " | ||
"action arguments between " | ||
|
@@ -1265,31 +1351,31 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses( | |
|
||
return emitError() << "mismatching type interfaces for matcher result " | ||
"and action argument #" | ||
<< i; | ||
<< i << "of matcher " << matcher << " and action " | ||
<< action; | ||
} | ||
|
||
if (!actionSymbol.getResultTypes().empty()) { | ||
// Action -> result forwarding. | ||
TypeRange actionResults = actionSymbol.getResultTypes(); | ||
auto resultTypes = TypeRange(getResultTypes()).drop_front(); | ||
if (actionResults.size() != resultTypes.size()) { | ||
InFlightDiagnostic diag = | ||
emitError() << "action symbol is not expected to have results"; | ||
emitError() << "the number of action results (" | ||
<< actionResults.size() << ") for " << action | ||
<< " doesn't match the number of extra op results (" | ||
<< resultTypes.size() << ")"; | ||
diag.attachNote(actionSymbol->getLoc()) << "symbol declaration"; | ||
return diag; | ||
} | ||
for (auto &&[i, resultType, actionType] : | ||
llvm::enumerate(resultTypes, actionResults)) { | ||
if (implementSameTransformInterface(resultType, actionType)) | ||
continue; | ||
|
||
if (matcherSymbol.getArgumentTypes().size() != 1 || | ||
!implementSameTransformInterface(matcherSymbol.getArgumentTypes()[0], | ||
getRoot().getType())) { | ||
InFlightDiagnostic diag = | ||
emitOpError() << "expects matcher symbol to have one argument with " | ||
"the same transform interface as the first operand"; | ||
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; | ||
return diag; | ||
} | ||
|
||
if (matcherSymbol.getArgAttr(0, consumedAttr)) { | ||
InFlightDiagnostic diag = | ||
emitOpError() | ||
<< "does not expect matcher symbol to consume its operand"; | ||
diag.attachNote(matcherSymbol->getLoc()) << "symbol declaration"; | ||
emitError() << "mismatching type interfaces for action result #" << i | ||
<< " of action " << action << " and op result"; | ||
diag.attachNote(actionSymbol->getLoc()) << "symbol declaration"; | ||
return diag; | ||
} | ||
} | ||
|
Uh oh!
There was an error while loading. Please reload this page.