Skip to content

Commit 633d918

Browse files
authored
[mlir] introduce transform.collect_matching (#76724)
Introduce a new match combinator into the transform dialect. This operation collects all operations that are yielded by a satisfactory match into its results. This is a simpler version of `foreach_match` that can be inserted directly into existing transform scripts.
1 parent 4f7c402 commit 633d918

File tree

4 files changed

+279
-18
lines changed

4 files changed

+279
-18
lines changed

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,39 @@ def NumAssociationsOp : TransformDialectOp<"num_associations",
460460
let hasVerifier = 1;
461461
}
462462

463+
def CollectMatchingOp : TransformDialectOp<"collect_matching", [
464+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
465+
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
466+
DeclareOpInterfaceMethods<TransformOpInterface>]> {
467+
let summary = "Collects all payload ops that match the given named matcher";
468+
let description = [{
469+
Collects operations or other payload IR objects nested under `root`
470+
(inclusive) that match the given matcher expressed as a named sequence. The
471+
matcher sequence must accept exactly one argument that it is not allowed to
472+
modify. It must yield as many values as this op has results. Each of the
473+
yielded values must be associated with exactly one payload object. If any
474+
operation in the matcher sequence produces a silenceable failure, the
475+
matcher advances to the next payload operation in the walk order without
476+
finishing the sequence.
477+
478+
The i-th result of this operation is constructed by concatenating the i-th
479+
yielded payload IR objects of all successful matcher sequence applications.
480+
All results are guaranteed to be mapped to the same number of payload IR
481+
objects.
482+
483+
The operation succeeds unless the matcher sequence produced a definite
484+
failure for any invocation.
485+
}];
486+
487+
let arguments = (ins TransformHandleTypeInterface:$root,
488+
SymbolRefAttr:$matcher);
489+
let results = (outs Variadic<Transform_AnyHandleOrParamType>:$results);
490+
491+
let assemblyFormat = [{
492+
$matcher `in` $root attr-dict `:` functional-type($root, $results)
493+
}];
494+
}
495+
463496
def ForeachMatchOp : TransformDialectOp<"foreach_match", [
464497
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
465498
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
@@ -674,7 +707,7 @@ def GetParentOp : TransformDialectOp<"get_parent_op",
674707

675708
def GetProducerOfOperand : TransformDialectOp<"get_producer_of_operand",
676709
[DeclareOpInterfaceMethods<TransformOpInterface>,
677-
NavigationTransformOpTrait, MemoryEffectsOpInterface]> {
710+
NavigationTransformOpTrait, MatchOpInterface, MemoryEffectsOpInterface]> {
678711
let summary = "Get handle to the producer of this operation's operand number";
679712
let description = [{
680713
The handle defined by this Transform op corresponds to operation that

mlir/lib/Dialect/Transform/IR/TransformOps.cpp

Lines changed: 133 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/Verifier.h"
2323
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2424
#include "mlir/Interfaces/FunctionImplementation.h"
25+
#include "mlir/Interfaces/FunctionInterfaces.h"
2526
#include "mlir/Pass/Pass.h"
2627
#include "mlir/Pass/PassManager.h"
2728
#include "mlir/Pass/PassRegistry.h"
@@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
783784
}
784785

785786
//===----------------------------------------------------------------------===//
786-
// ForeachMatchOp
787+
// CollectMatchingOp
787788
//===----------------------------------------------------------------------===//
788789

789790
/// Applies matcher operations from the given `block` assigning `op` as the
@@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
822823
return DiagnosedSilenceableFailure::success();
823824
}
824825

826+
/// Returns `true` if both types implement one of the interfaces provided as
827+
/// template parameters.
828+
template <typename... Tys>
829+
static bool implementSameInterface(Type t1, Type t2) {
830+
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
831+
}
832+
833+
/// Returns `true` if both types implement one of the transform dialect
834+
/// interfaces.
835+
static bool implementSameTransformInterface(Type t1, Type t2) {
836+
return implementSameInterface<transform::TransformHandleTypeInterface,
837+
transform::TransformParamTypeInterface,
838+
transform::TransformValueHandleTypeInterface>(
839+
t1, t2);
840+
}
841+
842+
//===----------------------------------------------------------------------===//
843+
// CollectMatchingOp
844+
//===----------------------------------------------------------------------===//
845+
846+
DiagnosedSilenceableFailure
847+
transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter,
848+
transform::TransformResults &results,
849+
transform::TransformState &state) {
850+
auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
851+
getOperation(), getMatcher());
852+
if (matcher.isExternal()) {
853+
return emitDefiniteFailure()
854+
<< "unresolved external symbol " << getMatcher();
855+
}
856+
857+
SmallVector<SmallVector<MappedValue>, 2> rawResults;
858+
rawResults.resize(getOperation()->getNumResults());
859+
std::optional<DiagnosedSilenceableFailure> maybeFailure;
860+
for (Operation *root : state.getPayloadOps(getRoot())) {
861+
WalkResult walkResult = root->walk([&](Operation *op) {
862+
DEBUG_MATCHER({
863+
DBGS_MATCHER() << "matching ";
864+
op->print(llvm::dbgs(),
865+
OpPrintingFlags().assumeVerified().skipRegions());
866+
llvm::dbgs() << " @" << op << "\n";
867+
});
868+
869+
// Try matching.
870+
SmallVector<SmallVector<MappedValue>> mappings;
871+
DiagnosedSilenceableFailure diag =
872+
matchBlock(matcher.getFunctionBody().front(), op, state, mappings);
873+
if (diag.isDefiniteFailure())
874+
return WalkResult::interrupt();
875+
if (diag.isSilenceableFailure()) {
876+
DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName()
877+
<< " failed: " << diag.getMessage());
878+
return WalkResult::advance();
879+
}
880+
881+
// If succeeded, collect results.
882+
for (auto &&[i, mapping] : llvm::enumerate(mappings)) {
883+
if (mapping.size() != 1) {
884+
maybeFailure.emplace(emitSilenceableError()
885+
<< "result #" << i << ", associated with "
886+
<< mapping.size()
887+
<< " payload objects, expected 1");
888+
return WalkResult::interrupt();
889+
}
890+
rawResults[i].push_back(mapping[0]);
891+
}
892+
return WalkResult::advance();
893+
});
894+
if (walkResult.wasInterrupted())
895+
return std::move(*maybeFailure);
896+
assert(!maybeFailure && "failure set but the walk was not interrupted");
897+
898+
for (auto &&[opResult, rawResult] :
899+
llvm::zip_equal(getOperation()->getResults(), rawResults)) {
900+
results.setMappedValues(opResult, rawResult);
901+
}
902+
}
903+
return DiagnosedSilenceableFailure::success();
904+
}
905+
906+
void transform::CollectMatchingOp::getEffects(
907+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
908+
onlyReadsHandle(getRoot(), effects);
909+
producesHandle(getResults(), effects);
910+
onlyReadsPayload(effects);
911+
}
912+
913+
LogicalResult transform::CollectMatchingOp::verifySymbolUses(
914+
SymbolTableCollection &symbolTable) {
915+
auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>(
916+
symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher()));
917+
if (!matcherSymbol ||
918+
!isa<TransformOpInterface>(matcherSymbol.getOperation()))
919+
return emitError() << "unresolved matcher symbol " << getMatcher();
920+
921+
ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes();
922+
if (argumentTypes.size() != 1 ||
923+
!isa<TransformHandleTypeInterface>(argumentTypes[0])) {
924+
return emitError()
925+
<< "expected the matcher to take one operation handle argument";
926+
}
927+
if (!matcherSymbol.getArgAttr(
928+
0, transform::TransformDialect::kArgReadOnlyAttrName)) {
929+
return emitError() << "expected the matcher argument to be marked readonly";
930+
}
931+
932+
ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes();
933+
if (resultTypes.size() != getOperation()->getNumResults()) {
934+
return emitError()
935+
<< "expected the matcher to yield as many values as op has results ("
936+
<< getOperation()->getNumResults() << "), got "
937+
<< resultTypes.size();
938+
}
939+
940+
for (auto &&[i, matcherType, resultType] :
941+
llvm::enumerate(resultTypes, getOperation()->getResultTypes())) {
942+
if (implementSameTransformInterface(matcherType, resultType))
943+
continue;
944+
945+
return emitError()
946+
<< "mismatching type interfaces for matcher result and op result #"
947+
<< i;
948+
}
949+
950+
return success();
951+
}
952+
953+
//===----------------------------------------------------------------------===//
954+
// ForeachMatchOp
955+
//===----------------------------------------------------------------------===//
956+
825957
DiagnosedSilenceableFailure
826958
transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
827959
transform::TransformResults &results,
@@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() {
9781110
return success();
9791111
}
9801112

981-
/// Returns `true` if both types implement one of the interfaces provided as
982-
/// template parameters.
983-
template <typename... Tys>
984-
static bool implementSameInterface(Type t1, Type t2) {
985-
return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false);
986-
}
987-
988-
/// Returns `true` if both types implement one of the transform dialect
989-
/// interfaces.
990-
static bool implementSameTransformInterface(Type t1, Type t2) {
991-
return implementSameInterface<transform::TransformHandleTypeInterface,
992-
transform::TransformParamTypeInterface,
993-
transform::TransformValueHandleTypeInterface>(
994-
t1, t2);
995-
}
996-
9971113
/// Checks that the attributes of the function-like operation have correct
9981114
/// consumption effect annotations. If `alsoVerifyInternal`, checks for
9991115
/// annotations being present even if they can be inferred from the body.

mlir/test/Dialect/Transform/ops-invalid.mlir

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,3 +704,71 @@ transform.sequence failures(propagate) {
704704
// expected-error @below {{expected the type of the parameter attribute ('i64') to match the parameter type ('i32')}}
705705
transform.num_associations %arg0 : (!transform.any_op) -> !transform.param<i32>
706706
}
707+
708+
// -----
709+
710+
module attributes { transform.with_named_sequence } {
711+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
712+
// expected-error @below {{unresolved matcher symbol @missing_symbol}}
713+
transform.collect_matching @missing_symbol in %arg0 : (!transform.any_op) -> !transform.any_op
714+
transform.yield
715+
}
716+
}
717+
718+
// -----
719+
720+
module attributes { transform.with_named_sequence } {
721+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
722+
// expected-error @below {{expected the matcher to take one operation handle argument}}
723+
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
724+
transform.yield
725+
}
726+
727+
transform.named_sequence @matcher() {
728+
transform.yield
729+
}
730+
}
731+
732+
// -----
733+
734+
735+
module attributes { transform.with_named_sequence } {
736+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
737+
// expected-error @below {{expected the matcher argument to be marked readonly}}
738+
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
739+
transform.yield
740+
}
741+
742+
transform.named_sequence @matcher(%arg0: !transform.any_op) {
743+
transform.yield
744+
}
745+
}
746+
747+
748+
// -----
749+
750+
module attributes { transform.with_named_sequence } {
751+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
752+
// expected-error @below {{expected the matcher to yield as many values as op has results (1), got 0}}
753+
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
754+
transform.yield
755+
}
756+
757+
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) {
758+
transform.yield
759+
}
760+
}
761+
762+
// -----
763+
764+
module attributes { transform.with_named_sequence } {
765+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
766+
// expected-error @below {{mismatching type interfaces for matcher result and op result #0}}
767+
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_value
768+
transform.yield
769+
}
770+
771+
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
772+
transform.yield %arg0 : !transform.any_op
773+
}
774+
}

mlir/test/Dialect/Transform/test-interpreter.mlir

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2380,3 +2380,47 @@ module @named_inclusion attributes { transform.with_named_sequence } {
23802380
transform.yield
23812381
}
23822382
}
2383+
2384+
// -----
2385+
2386+
module attributes { transform.with_named_sequence } {
2387+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
2388+
// expected-error @below {{result #0, associated with 2 payload objects, expected 1}}
2389+
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
2390+
transform.yield
2391+
}
2392+
2393+
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
2394+
%0 = transform.merge_handles %arg0, %arg0 : !transform.any_op
2395+
transform.yield %0 : !transform.any_op
2396+
}
2397+
}
2398+
2399+
// -----
2400+
2401+
module attributes { transform.with_named_sequence } {
2402+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
2403+
// expected-error @below {{unresolved external symbol @matcher}}
2404+
transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
2405+
transform.yield
2406+
}
2407+
2408+
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
2409+
}
2410+
2411+
// -----
2412+
2413+
module attributes { transform.with_named_sequence } {
2414+
transform.named_sequence @__transform_main(%arg0: !transform.any_op) {
2415+
// expected-remark @below {{matched}}
2416+
%0 = transform.collect_matching @matcher in %arg0 : (!transform.any_op) -> !transform.any_op
2417+
// expected-remark @below {{matched}}
2418+
transform.test_print_remark_at_operand %0, "matched" : !transform.any_op
2419+
transform.yield
2420+
}
2421+
2422+
transform.named_sequence @matcher(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op {
2423+
transform.match.operation_name %arg0 ["transform.test_print_remark_at_operand", "transform.collect_matching"] : !transform.any_op
2424+
transform.yield %arg0 : !transform.any_op
2425+
}
2426+
}

0 commit comments

Comments
 (0)