Skip to content

Commit c0aa1f0

Browse files
authored
Revert "[mlir] Improve mlir-query by adding matcher combinators (#141423)"
This reverts commit 12611a7.
1 parent 4359e55 commit c0aa1f0

15 files changed

+17
-471
lines changed

mlir/include/mlir/Query/Matcher/Marshallers.h

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,6 @@ class MatcherDescriptor {
108108
const llvm::ArrayRef<ParserValue> args,
109109
Diagnostics *error) const = 0;
110110

111-
// If the matcher is variadic, it can take any number of arguments.
112-
virtual bool isVariadic() const = 0;
113-
114111
// Returns the number of arguments accepted by the matcher.
115112
virtual unsigned getNumArgs() const = 0;
116113

@@ -143,8 +140,6 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
143140
return marshaller(matcherFunc, matcherName, nameRange, args, error);
144141
}
145142

146-
bool isVariadic() const override { return false; }
147-
148143
unsigned getNumArgs() const override { return argKinds.size(); }
149144

150145
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
@@ -158,54 +153,6 @@ class FixedArgCountMatcherDescriptor : public MatcherDescriptor {
158153
const std::vector<ArgKind> argKinds;
159154
};
160155

161-
class VariadicOperatorMatcherDescriptor : public MatcherDescriptor {
162-
public:
163-
using VarOp = DynMatcher::VariadicOperator;
164-
VariadicOperatorMatcherDescriptor(unsigned minCount, unsigned maxCount,
165-
VarOp varOp, StringRef matcherName)
166-
: minCount(minCount), maxCount(maxCount), varOp(varOp),
167-
matcherName(matcherName) {}
168-
169-
VariantMatcher create(SourceRange nameRange, ArrayRef<ParserValue> args,
170-
Diagnostics *error) const override {
171-
if (args.size() < minCount || maxCount < args.size()) {
172-
addError(error, nameRange, ErrorType::RegistryWrongArgCount,
173-
{llvm::Twine("requires between "), llvm::Twine(minCount),
174-
llvm::Twine(" and "), llvm::Twine(maxCount),
175-
llvm::Twine(" args, got "), llvm::Twine(args.size())});
176-
return VariantMatcher();
177-
}
178-
179-
std::vector<VariantMatcher> innerArgs;
180-
for (int64_t i = 0, e = args.size(); i != e; ++i) {
181-
const ParserValue &arg = args[i];
182-
const VariantValue &value = arg.value;
183-
if (!value.isMatcher()) {
184-
addError(error, arg.range, ErrorType::RegistryWrongArgType,
185-
{llvm::Twine(i + 1), llvm::Twine("matcher: "),
186-
llvm::Twine(value.getTypeAsString())});
187-
return VariantMatcher();
188-
}
189-
innerArgs.push_back(value.getMatcher());
190-
}
191-
return VariantMatcher::VariadicOperatorMatcher(varOp, std::move(innerArgs));
192-
}
193-
194-
bool isVariadic() const override { return true; }
195-
196-
unsigned getNumArgs() const override { return 0; }
197-
198-
void getArgKinds(unsigned argNo, std::vector<ArgKind> &kinds) const override {
199-
kinds.push_back(ArgKind(ArgKind::Matcher));
200-
}
201-
202-
private:
203-
const unsigned minCount;
204-
const unsigned maxCount;
205-
const VarOp varOp;
206-
const StringRef matcherName;
207-
};
208-
209156
// Helper function to check if argument count matches expected count
210157
inline bool checkArgCount(SourceRange nameRange, size_t expectedArgCount,
211158
llvm::ArrayRef<ParserValue> args,
@@ -277,14 +224,6 @@ makeMatcherAutoMarshall(ReturnType (*matcherFunc)(ArgTypes...),
277224
reinterpret_cast<void (*)()>(matcherFunc), matcherName, argKinds);
278225
}
279226

280-
// Variadic operator overload.
281-
template <unsigned MinCount, unsigned MaxCount>
282-
std::unique_ptr<MatcherDescriptor>
283-
makeMatcherAutoMarshall(VariadicOperatorMatcherFunc<MinCount, MaxCount> func,
284-
StringRef matcherName) {
285-
return std::make_unique<VariadicOperatorMatcherDescriptor>(
286-
MinCount, MaxCount, func.varOp, matcherName);
287-
}
288227
} // namespace mlir::query::matcher::internal
289228

290229
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MARSHALLERS_H

mlir/include/mlir/Query/Matcher/MatchFinder.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,7 @@
2121

2222
namespace mlir::query::matcher {
2323

24-
/// Finds and collects matches from the IR. After construction
25-
/// `collectMatches` can be used to traverse the IR and apply
26-
/// matchers.
24+
/// A class that provides utilities to find operations in the IR.
2725
class MatchFinder {
2826

2927
public:

mlir/include/mlir/Query/Matcher/MatchersInternal.h

Lines changed: 4 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
//
99
// Implements the base layer of the matcher framework.
1010
//
11-
// Matchers are methods that return a Matcher which provides a
12-
// `match(...)` method whose parameters define the context of the match.
13-
// Support includes simple (unary) matchers as well as matcher combinators
14-
// (anyOf, allOf, etc.)
11+
// Matchers are methods that return a Matcher which provides a method one of the
12+
// following methods: match(Operation *op), match(Operation *op,
13+
// SetVector<Operation *> &matchedOps)
1514
//
15+
// The matcher functions are defined in include/mlir/IR/Matchers.h.
1616
// This file contains the wrapper classes needed to construct matchers for
1717
// mlir-query.
1818
//
@@ -25,15 +25,6 @@
2525
#include "llvm/ADT/IntrusiveRefCntPtr.h"
2626

2727
namespace mlir::query::matcher {
28-
class DynMatcher;
29-
namespace internal {
30-
31-
bool allOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
32-
ArrayRef<DynMatcher> innerMatchers);
33-
bool anyOfVariadicOperator(Operation *op, SetVector<Operation *> *matchedOps,
34-
ArrayRef<DynMatcher> innerMatchers);
35-
36-
} // namespace internal
3728

3829
// Defaults to false if T has no match() method with the signature:
3930
// match(Operation* op).
@@ -93,27 +84,6 @@ class MatcherFnImpl : public MatcherInterface {
9384
MatcherFn matcherFn;
9485
};
9586

96-
// VariadicMatcher takes a vector of Matchers and returns true if any Matchers
97-
// match the given operation.
98-
using VariadicOperatorFunction = bool (*)(Operation *op,
99-
SetVector<Operation *> *matchedOps,
100-
ArrayRef<DynMatcher> innerMatchers);
101-
102-
template <VariadicOperatorFunction Func>
103-
class VariadicMatcher : public MatcherInterface {
104-
public:
105-
VariadicMatcher(std::vector<DynMatcher> matchers)
106-
: matchers(std::move(matchers)) {}
107-
108-
bool match(Operation *op) override { return Func(op, nullptr, matchers); }
109-
bool match(Operation *op, SetVector<Operation *> &matchedOps) override {
110-
return Func(op, &matchedOps, matchers);
111-
}
112-
113-
private:
114-
std::vector<DynMatcher> matchers;
115-
};
116-
11787
// Matcher wraps a MatcherInterface implementation and provides match()
11888
// methods that redirect calls to the underlying implementation.
11989
class DynMatcher {
@@ -122,31 +92,6 @@ class DynMatcher {
12292
DynMatcher(MatcherInterface *implementation)
12393
: implementation(implementation) {}
12494

125-
// Construct from a variadic function.
126-
enum VariadicOperator {
127-
// Matches operations for which all provided matchers match.
128-
AllOf,
129-
// Matches operations for which at least one of the provided matchers
130-
// matches.
131-
AnyOf
132-
};
133-
134-
static std::unique_ptr<DynMatcher>
135-
constructVariadic(VariadicOperator Op,
136-
std::vector<DynMatcher> innerMatchers) {
137-
switch (Op) {
138-
case AllOf:
139-
return std::make_unique<DynMatcher>(
140-
new VariadicMatcher<internal::allOfVariadicOperator>(
141-
std::move(innerMatchers)));
142-
case AnyOf:
143-
return std::make_unique<DynMatcher>(
144-
new VariadicMatcher<internal::anyOfVariadicOperator>(
145-
std::move(innerMatchers)));
146-
}
147-
llvm_unreachable("Invalid Op value.");
148-
}
149-
15095
template <typename MatcherFn>
15196
static std::unique_ptr<DynMatcher>
15297
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
@@ -168,59 +113,6 @@ class DynMatcher {
168113
std::string functionName;
169114
};
170115

171-
// VariadicOperatorMatcher related types.
172-
template <typename... Ps>
173-
class VariadicOperatorMatcher {
174-
public:
175-
VariadicOperatorMatcher(DynMatcher::VariadicOperator varOp, Ps &&...params)
176-
: varOp(varOp), params(std::forward<Ps>(params)...) {}
177-
178-
operator std::unique_ptr<DynMatcher>() const & {
179-
return DynMatcher::constructVariadic(
180-
varOp, getMatchers(std::index_sequence_for<Ps...>()));
181-
}
182-
183-
operator std::unique_ptr<DynMatcher>() && {
184-
return DynMatcher::constructVariadic(
185-
varOp, std::move(*this).getMatchers(std::index_sequence_for<Ps...>()));
186-
}
187-
188-
private:
189-
// Helper method to unpack the tuple into a vector.
190-
template <std::size_t... Is>
191-
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) const & {
192-
return {DynMatcher(std::get<Is>(params))...};
193-
}
194-
195-
template <std::size_t... Is>
196-
std::vector<DynMatcher> getMatchers(std::index_sequence<Is...>) && {
197-
return {DynMatcher(std::get<Is>(std::move(params)))...};
198-
}
199-
200-
const DynMatcher::VariadicOperator varOp;
201-
std::tuple<Ps...> params;
202-
};
203-
204-
// Overloaded function object to generate VariadicOperatorMatcher objects from
205-
// arbitrary matchers.
206-
template <unsigned MinCount, unsigned MaxCount>
207-
struct VariadicOperatorMatcherFunc {
208-
DynMatcher::VariadicOperator varOp;
209-
210-
template <typename... Ms>
211-
VariadicOperatorMatcher<Ms...> operator()(Ms &&...Ps) const {
212-
static_assert(MinCount <= sizeof...(Ms) && sizeof...(Ms) <= MaxCount,
213-
"invalid number of parameters for variadic matcher");
214-
return VariadicOperatorMatcher<Ms...>(varOp, std::forward<Ms>(Ps)...);
215-
}
216-
};
217-
218-
namespace internal {
219-
const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
220-
anyOf = {DynMatcher::AnyOf};
221-
const VariadicOperatorMatcherFunc<1, std::numeric_limits<unsigned>::max()>
222-
allOf = {DynMatcher::AllOf};
223-
} // namespace internal
224116
} // namespace mlir::query::matcher
225117

226118
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H

mlir/include/mlir/Query/Matcher/SliceMatchers.h

Lines changed: 5 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
//
77
//===----------------------------------------------------------------------===//
88
//
9-
// This file defines slicing-analysis matchers that extend and abstract the
10-
// core implementations from `SliceAnalysis.h`.
9+
// This file provides matchers for MLIRQuery that peform slicing analysis
1110
//
1211
//===----------------------------------------------------------------------===//
1312

@@ -17,9 +16,9 @@
1716
#include "mlir/Analysis/SliceAnalysis.h"
1817
#include "mlir/IR/Operation.h"
1918

20-
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
21-
/// if `innerMatcher` matches. The traversal stops once the desired depth level
22-
/// is reached.
19+
/// A matcher encapsulating `getBackwardSlice` method from SliceAnalysis.h.
20+
/// Additionally, it limits the slice computation to a certain depth level using
21+
/// a custom filter.
2322
///
2423
/// Example: starting from node 9, assuming the matcher
2524
/// computes the slice for the first two depth levels:
@@ -120,77 +119,6 @@ bool BackwardSliceMatcher<Matcher>::matches(
120119
: backwardSlice.size() >= 1;
121120
}
122121

123-
/// Computes the backward-slice of all transitive defs reachable from `rootOp`,
124-
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
125-
template <typename BaseMatcher, typename Filter>
126-
class PredicateBackwardSliceMatcher {
127-
public:
128-
PredicateBackwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
129-
bool inclusive, bool omitBlockArguments,
130-
bool omitUsesFromAbove)
131-
: innerMatcher(std::move(innerMatcher)),
132-
filterMatcher(std::move(filterMatcher)), inclusive(inclusive),
133-
omitBlockArguments(omitBlockArguments),
134-
omitUsesFromAbove(omitUsesFromAbove) {}
135-
136-
bool match(Operation *rootOp, SetVector<Operation *> &backwardSlice) {
137-
backwardSlice.clear();
138-
BackwardSliceOptions options;
139-
options.inclusive = inclusive;
140-
options.omitUsesFromAbove = omitUsesFromAbove;
141-
options.omitBlockArguments = omitBlockArguments;
142-
if (innerMatcher.match(rootOp)) {
143-
options.filter = [&](Operation *subOp) {
144-
return !filterMatcher.match(subOp);
145-
};
146-
LogicalResult result = getBackwardSlice(rootOp, &backwardSlice, options);
147-
assert(result.succeeded() && "expected backward slice to succeed");
148-
(void)result;
149-
return options.inclusive ? backwardSlice.size() > 1
150-
: backwardSlice.size() >= 1;
151-
}
152-
return false;
153-
}
154-
155-
private:
156-
BaseMatcher innerMatcher;
157-
Filter filterMatcher;
158-
bool inclusive;
159-
bool omitBlockArguments;
160-
bool omitUsesFromAbove;
161-
};
162-
163-
/// Computes the forward-slice of all users reachable from `rootOp`,
164-
/// if `innerMatcher` matches. Traversal stops where `filterMatcher` matches.
165-
template <typename BaseMatcher, typename Filter>
166-
class PredicateForwardSliceMatcher {
167-
public:
168-
PredicateForwardSliceMatcher(BaseMatcher innerMatcher, Filter filterMatcher,
169-
bool inclusive)
170-
: innerMatcher(std::move(innerMatcher)),
171-
filterMatcher(std::move(filterMatcher)), inclusive(inclusive) {}
172-
173-
bool match(Operation *rootOp, SetVector<Operation *> &forwardSlice) {
174-
forwardSlice.clear();
175-
ForwardSliceOptions options;
176-
options.inclusive = inclusive;
177-
if (innerMatcher.match(rootOp)) {
178-
options.filter = [&](Operation *subOp) {
179-
return !filterMatcher.match(subOp);
180-
};
181-
getForwardSlice(rootOp, &forwardSlice, options);
182-
return options.inclusive ? forwardSlice.size() > 1
183-
: forwardSlice.size() >= 1;
184-
}
185-
return false;
186-
}
187-
188-
private:
189-
BaseMatcher innerMatcher;
190-
Filter filterMatcher;
191-
bool inclusive;
192-
};
193-
194122
/// Matches transitive defs of a top-level operation up to N levels.
195123
template <typename Matcher>
196124
inline BackwardSliceMatcher<Matcher>
@@ -202,7 +130,7 @@ m_GetDefinitions(Matcher innerMatcher, int64_t maxDepth, bool inclusive,
202130
omitUsesFromAbove);
203131
}
204132

205-
/// Matches all transitive defs of a top-level operation up to N levels.
133+
/// Matches all transitive defs of a top-level operation up to N levels
206134
template <typename Matcher>
207135
inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
208136
int64_t maxDepth) {
@@ -211,28 +139,6 @@ inline BackwardSliceMatcher<Matcher> m_GetAllDefinitions(Matcher innerMatcher,
211139
false, false);
212140
}
213141

214-
/// Matches all transitive defs of a top-level operation and stops where
215-
/// `filterMatcher` rejects.
216-
template <typename BaseMatcher, typename Filter>
217-
inline PredicateBackwardSliceMatcher<BaseMatcher, Filter>
218-
m_GetDefinitionsByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
219-
bool inclusive, bool omitBlockArguments,
220-
bool omitUsesFromAbove) {
221-
return PredicateBackwardSliceMatcher<BaseMatcher, Filter>(
222-
std::move(innerMatcher), std::move(filterMatcher), inclusive,
223-
omitBlockArguments, omitUsesFromAbove);
224-
}
225-
226-
/// Matches all users of a top-level operation and stops where
227-
/// `filterMatcher` rejects.
228-
template <typename BaseMatcher, typename Filter>
229-
inline PredicateForwardSliceMatcher<BaseMatcher, Filter>
230-
m_GetUsersByPredicate(BaseMatcher innerMatcher, Filter filterMatcher,
231-
bool inclusive) {
232-
return PredicateForwardSliceMatcher<BaseMatcher, Filter>(
233-
std::move(innerMatcher), std::move(filterMatcher), inclusive);
234-
}
235-
236142
} // namespace mlir::query::matcher
237143

238144
#endif // MLIR_TOOLS_MLIRQUERY_MATCHERS_SLICEMATCHERS_H

0 commit comments

Comments
 (0)