|
22 | 22 | #include "mlir/IR/Verifier.h"
|
23 | 23 | #include "mlir/Interfaces/ControlFlowInterfaces.h"
|
24 | 24 | #include "mlir/Interfaces/FunctionImplementation.h"
|
| 25 | +#include "mlir/Interfaces/FunctionInterfaces.h" |
25 | 26 | #include "mlir/Pass/Pass.h"
|
26 | 27 | #include "mlir/Pass/PassManager.h"
|
27 | 28 | #include "mlir/Pass/PassRegistry.h"
|
@@ -783,7 +784,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
|
783 | 784 | }
|
784 | 785 |
|
785 | 786 | //===----------------------------------------------------------------------===//
|
786 |
| -// ForeachMatchOp |
| 787 | +// CollectMatchingOp |
787 | 788 | //===----------------------------------------------------------------------===//
|
788 | 789 |
|
789 | 790 | /// Applies matcher operations from the given `block` assigning `op` as the
|
@@ -822,6 +823,137 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state,
|
822 | 823 | return DiagnosedSilenceableFailure::success();
|
823 | 824 | }
|
824 | 825 |
|
| 826 | +/// Returns `true` if both types implement one of the interfaces provided as |
| 827 | +/// template parameters. |
| 828 | +template <typename... Tys> |
| 829 | +static bool implementSameInterface(Type t1, Type t2) { |
| 830 | + return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false); |
| 831 | +} |
| 832 | + |
| 833 | +/// Returns `true` if both types implement one of the transform dialect |
| 834 | +/// interfaces. |
| 835 | +static bool implementSameTransformInterface(Type t1, Type t2) { |
| 836 | + return implementSameInterface<transform::TransformHandleTypeInterface, |
| 837 | + transform::TransformParamTypeInterface, |
| 838 | + transform::TransformValueHandleTypeInterface>( |
| 839 | + t1, t2); |
| 840 | +} |
| 841 | + |
| 842 | +//===----------------------------------------------------------------------===// |
| 843 | +// CollectMatchingOp |
| 844 | +//===----------------------------------------------------------------------===// |
| 845 | + |
| 846 | +DiagnosedSilenceableFailure |
| 847 | +transform::CollectMatchingOp::apply(transform::TransformRewriter &rewriter, |
| 848 | + transform::TransformResults &results, |
| 849 | + transform::TransformState &state) { |
| 850 | + auto matcher = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>( |
| 851 | + getOperation(), getMatcher()); |
| 852 | + if (matcher.isExternal()) { |
| 853 | + return emitDefiniteFailure() |
| 854 | + << "unresolved external symbol " << getMatcher(); |
| 855 | + } |
| 856 | + |
| 857 | + SmallVector<SmallVector<MappedValue>, 2> rawResults; |
| 858 | + rawResults.resize(getOperation()->getNumResults()); |
| 859 | + std::optional<DiagnosedSilenceableFailure> maybeFailure; |
| 860 | + for (Operation *root : state.getPayloadOps(getRoot())) { |
| 861 | + WalkResult walkResult = root->walk([&](Operation *op) { |
| 862 | + DEBUG_MATCHER({ |
| 863 | + DBGS_MATCHER() << "matching "; |
| 864 | + op->print(llvm::dbgs(), |
| 865 | + OpPrintingFlags().assumeVerified().skipRegions()); |
| 866 | + llvm::dbgs() << " @" << op << "\n"; |
| 867 | + }); |
| 868 | + |
| 869 | + // Try matching. |
| 870 | + SmallVector<SmallVector<MappedValue>> mappings; |
| 871 | + DiagnosedSilenceableFailure diag = |
| 872 | + matchBlock(matcher.getFunctionBody().front(), op, state, mappings); |
| 873 | + if (diag.isDefiniteFailure()) |
| 874 | + return WalkResult::interrupt(); |
| 875 | + if (diag.isSilenceableFailure()) { |
| 876 | + DEBUG_MATCHER(DBGS_MATCHER() << "matcher " << matcher.getName() |
| 877 | + << " failed: " << diag.getMessage()); |
| 878 | + return WalkResult::advance(); |
| 879 | + } |
| 880 | + |
| 881 | + // If succeeded, collect results. |
| 882 | + for (auto &&[i, mapping] : llvm::enumerate(mappings)) { |
| 883 | + if (mapping.size() != 1) { |
| 884 | + maybeFailure.emplace(emitSilenceableError() |
| 885 | + << "result #" << i << ", associated with " |
| 886 | + << mapping.size() |
| 887 | + << " payload objects, expected 1"); |
| 888 | + return WalkResult::interrupt(); |
| 889 | + } |
| 890 | + rawResults[i].push_back(mapping[0]); |
| 891 | + } |
| 892 | + return WalkResult::advance(); |
| 893 | + }); |
| 894 | + if (walkResult.wasInterrupted()) |
| 895 | + return std::move(*maybeFailure); |
| 896 | + assert(!maybeFailure && "failure set but the walk was not interrupted"); |
| 897 | + |
| 898 | + for (auto &&[opResult, rawResult] : |
| 899 | + llvm::zip_equal(getOperation()->getResults(), rawResults)) { |
| 900 | + results.setMappedValues(opResult, rawResult); |
| 901 | + } |
| 902 | + } |
| 903 | + return DiagnosedSilenceableFailure::success(); |
| 904 | +} |
| 905 | + |
| 906 | +void transform::CollectMatchingOp::getEffects( |
| 907 | + SmallVectorImpl<MemoryEffects::EffectInstance> &effects) { |
| 908 | + onlyReadsHandle(getRoot(), effects); |
| 909 | + producesHandle(getResults(), effects); |
| 910 | + onlyReadsPayload(effects); |
| 911 | +} |
| 912 | + |
| 913 | +LogicalResult transform::CollectMatchingOp::verifySymbolUses( |
| 914 | + SymbolTableCollection &symbolTable) { |
| 915 | + auto matcherSymbol = dyn_cast_or_null<FunctionOpInterface>( |
| 916 | + symbolTable.lookupNearestSymbolFrom(getOperation(), getMatcher())); |
| 917 | + if (!matcherSymbol || |
| 918 | + !isa<TransformOpInterface>(matcherSymbol.getOperation())) |
| 919 | + return emitError() << "unresolved matcher symbol " << getMatcher(); |
| 920 | + |
| 921 | + ArrayRef<Type> argumentTypes = matcherSymbol.getArgumentTypes(); |
| 922 | + if (argumentTypes.size() != 1 || |
| 923 | + !isa<TransformHandleTypeInterface>(argumentTypes[0])) { |
| 924 | + return emitError() |
| 925 | + << "expected the matcher to take one operation handle argument"; |
| 926 | + } |
| 927 | + if (!matcherSymbol.getArgAttr( |
| 928 | + 0, transform::TransformDialect::kArgReadOnlyAttrName)) { |
| 929 | + return emitError() << "expected the matcher argument to be marked readonly"; |
| 930 | + } |
| 931 | + |
| 932 | + ArrayRef<Type> resultTypes = matcherSymbol.getResultTypes(); |
| 933 | + if (resultTypes.size() != getOperation()->getNumResults()) { |
| 934 | + return emitError() |
| 935 | + << "expected the matcher to yield as many values as op has results (" |
| 936 | + << getOperation()->getNumResults() << "), got " |
| 937 | + << resultTypes.size(); |
| 938 | + } |
| 939 | + |
| 940 | + for (auto &&[i, matcherType, resultType] : |
| 941 | + llvm::enumerate(resultTypes, getOperation()->getResultTypes())) { |
| 942 | + if (implementSameTransformInterface(matcherType, resultType)) |
| 943 | + continue; |
| 944 | + |
| 945 | + return emitError() |
| 946 | + << "mismatching type interfaces for matcher result and op result #" |
| 947 | + << i; |
| 948 | + } |
| 949 | + |
| 950 | + return success(); |
| 951 | +} |
| 952 | + |
| 953 | +//===----------------------------------------------------------------------===// |
| 954 | +// ForeachMatchOp |
| 955 | +//===----------------------------------------------------------------------===// |
| 956 | + |
825 | 957 | DiagnosedSilenceableFailure
|
826 | 958 | transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter,
|
827 | 959 | transform::TransformResults &results,
|
@@ -978,22 +1110,6 @@ LogicalResult transform::ForeachMatchOp::verify() {
|
978 | 1110 | return success();
|
979 | 1111 | }
|
980 | 1112 |
|
981 |
| -/// Returns `true` if both types implement one of the interfaces provided as |
982 |
| -/// template parameters. |
983 |
| -template <typename... Tys> |
984 |
| -static bool implementSameInterface(Type t1, Type t2) { |
985 |
| - return ((isa<Tys>(t1) && isa<Tys>(t2)) || ... || false); |
986 |
| -} |
987 |
| - |
988 |
| -/// Returns `true` if both types implement one of the transform dialect |
989 |
| -/// interfaces. |
990 |
| -static bool implementSameTransformInterface(Type t1, Type t2) { |
991 |
| - return implementSameInterface<transform::TransformHandleTypeInterface, |
992 |
| - transform::TransformParamTypeInterface, |
993 |
| - transform::TransformValueHandleTypeInterface>( |
994 |
| - t1, t2); |
995 |
| -} |
996 |
| - |
997 | 1113 | /// Checks that the attributes of the function-like operation have correct
|
998 | 1114 | /// consumption effect annotations. If `alsoVerifyInternal`, checks for
|
999 | 1115 | /// annotations being present even if they can be inferred from the body.
|
|
0 commit comments