Skip to content

Commit cb380e7

Browse files
committed
MLIR-QUERY DefinitionsMatcher implementation & DAG
- included printing logic for DAG - sfinae for match methods
1 parent b5df0e7 commit cb380e7

File tree

10 files changed

+463
-57
lines changed

10 files changed

+463
-57
lines changed
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
//===- Matchers.h - Various common matchers ---------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file provides extra matchers that are very useful for mlir-query
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_IR_EXTRAMATCHERS_H
14+
#define MLIR_IR_EXTRAMATCHERS_H
15+
16+
#include "MatchFinder.h"
17+
#include "MatchersInternal.h"
18+
19+
namespace mlir {
20+
21+
namespace query {
22+
23+
namespace extramatcher {
24+
25+
namespace detail {
26+
27+
class DefinitionsMatcher {
28+
public:
29+
DefinitionsMatcher(matcher::DynMatcher &&InnerMatcher, unsigned Hops)
30+
: InnerMatcher(std::move(InnerMatcher)), Hops(Hops) {}
31+
32+
private:
33+
bool matches(Operation *op, matcher::BoundOperationsGraphBuilder &Bound,
34+
unsigned TempHops) {
35+
36+
llvm::DenseSet<mlir::Value> Ccache;
37+
llvm::SmallVector<std::pair<Operation *, size_t>, 4> TempStorage;
38+
TempStorage.push_back({op, TempHops});
39+
while (!TempStorage.empty()) {
40+
auto [CurrentOp, RemainingHops] = TempStorage.pop_back_val();
41+
42+
matcher::BoundOperationNode *CurrentNode =
43+
Bound.addNode(CurrentOp, true, true);
44+
if (RemainingHops == 0) {
45+
continue;
46+
}
47+
48+
for (auto Operand : CurrentOp->getOperands()) {
49+
if (auto DefiningOp = Operand.getDefiningOp()) {
50+
Bound.addEdge(CurrentOp, DefiningOp);
51+
if (!Ccache.contains(Operand)) {
52+
Ccache.insert(Operand);
53+
TempStorage.emplace_back(DefiningOp, RemainingHops - 1);
54+
}
55+
} else if (auto BlockArg = Operand.dyn_cast<BlockArgument>()) {
56+
auto *Block = BlockArg.getOwner();
57+
58+
if (Block->isEntryBlock() &&
59+
isa<FunctionOpInterface>(Block->getParentOp())) {
60+
continue;
61+
}
62+
63+
Operation *ParentOp = BlockArg.getOwner()->getParentOp();
64+
if (ParentOp) {
65+
Bound.addEdge(CurrentOp, ParentOp);
66+
if (!!Ccache.contains(BlockArg)) {
67+
Ccache.insert(BlockArg);
68+
TempStorage.emplace_back(ParentOp, RemainingHops - 1);
69+
}
70+
}
71+
}
72+
}
73+
}
74+
// We need at least 1 defining op
75+
return Ccache.size() >= 2;
76+
}
77+
78+
public:
79+
bool match(Operation *op, matcher::BoundOperationsGraphBuilder &Bound) {
80+
if (InnerMatcher.match(op) && matches(op, Bound, Hops)) {
81+
return true;
82+
}
83+
return false;
84+
}
85+
86+
private:
87+
matcher::DynMatcher InnerMatcher;
88+
unsigned Hops;
89+
};
90+
} // namespace detail
91+
92+
inline detail::DefinitionsMatcher
93+
definedBy(mlir::query::matcher::DynMatcher InnerMatcher) {
94+
return detail::DefinitionsMatcher(std::move(InnerMatcher), 1);
95+
}
96+
97+
inline detail::DefinitionsMatcher
98+
getDefinitions(mlir::query::matcher::DynMatcher InnerMatcher, unsigned Hops) {
99+
assert(Hops > 0 && "hops must be >= 1");
100+
return detail::DefinitionsMatcher(std::move(InnerMatcher), Hops);
101+
}
102+
103+
} // namespace extramatcher
104+
105+
} // namespace query
106+
107+
} // namespace mlir
108+
109+
#endif // MLIR_IR_EXTRAMATCHERS_H

mlir/include/mlir/Query/Matcher/Marshallers.h

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,21 @@ struct ArgTypeTraits<llvm::StringRef> {
5050
}
5151
};
5252

53+
template <>
54+
struct ArgTypeTraits<unsigned> {
55+
static bool hasCorrectType(const VariantValue &value) {
56+
return value.isUnsigned();
57+
}
58+
59+
static unsigned get(const VariantValue &value) { return value.getUnsigned(); }
60+
61+
static ArgKind getKind() { return ArgKind::Unsigned; }
62+
63+
static std::optional<std::string> getBestGuess(const VariantValue &) {
64+
return std::nullopt;
65+
}
66+
};
67+
5368
template <>
5469
struct ArgTypeTraits<DynMatcher> {
5570

@@ -166,7 +181,7 @@ matcherMarshallFixedImpl(void (*matcherFunc)(), llvm::StringRef matcherName,
166181
ReturnType fnPointer = reinterpret_cast<FuncType>(matcherFunc)(
167182
ArgTypeTraits<ArgTypes>::get(args[Is].value)...);
168183
return VariantMatcher::SingleMatcher(
169-
*DynMatcher::constructDynMatcherFromMatcherFn(fnPointer));
184+
*DynMatcher::constructDynMatcherFromMatcherFn(fnPointer, matcherName));
170185
}
171186

172187
return VariantMatcher();

mlir/include/mlir/Query/Matcher/MatchFinder.h

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,26 @@
1515
#define MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERFINDER_H
1616

1717
#include "MatchersInternal.h"
18+
#include "mlir/IR/Operation.h"
1819

1920
namespace mlir::query::matcher {
2021

2122
// MatchFinder is used to find all operations that match a given matcher.
2223
class MatchFinder {
2324
public:
2425
// Returns all operations that match the given matcher.
25-
static std::vector<Operation *> getMatches(Operation *root,
26-
DynMatcher matcher) {
27-
std::vector<Operation *> matches;
26+
static BoundOperationsGraphBuilder getMatches(Operation *root,
27+
DynMatcher matcher) {
2828

29-
// Simple match finding with walk.
29+
BoundOperationsGraphBuilder Bound;
3030
root->walk([&](Operation *subOp) {
31-
if (matcher.match(subOp))
32-
matches.push_back(subOp);
31+
if (matcher.match(subOp)) {
32+
matcher::BoundOperationNode *currentNode = Bound.addNode(subOp);
33+
} else if (matcher.match(subOp, Bound)) {
34+
////
35+
}
3336
});
34-
35-
return matches;
37+
return Bound;
3638
}
3739
};
3840

Lines changed: 107 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,8 @@
11
//===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
22
//
3-
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3+
// Part of the LLVM Project, under the Apache License v2.0 wIth LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
5-
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6-
//
7-
//===----------------------------------------------------------------------===//
8-
//
9-
// Implements the base layer of the matcher framework.
10-
//
11-
// Matchers are methods that return a Matcher which provides a method
12-
// match(Operation *op)
13-
//
14-
// The matcher functions are defined in include/mlir/IR/Matchers.h.
15-
// This file contains the wrapper classes needed to construct matchers for
16-
// mlir-query.
5+
// SPDX-License-Identifier: Apache-2.0 WItH LLVM-exception
176
//
187
//===----------------------------------------------------------------------===//
198

@@ -22,16 +11,91 @@
2211

2312
#include "mlir/IR/Matchers.h"
2413
#include "llvm/ADT/IntrusiveRefCntPtr.h"
14+
#include "llvm/ADT/MapVector.h"
15+
#include <memory>
16+
#include <stack>
17+
#include <unordered_set>
18+
#include <vector>
2519

2620
namespace mlir::query::matcher {
2721

22+
struct BoundOperationNode {
23+
Operation *op;
24+
std::vector<BoundOperationNode *> Parents;
25+
std::vector<BoundOperationNode *> Children;
26+
27+
bool IsRootNode;
28+
bool DetailedPrinting;
29+
30+
BoundOperationNode(Operation *op, bool IsRootNode = false,
31+
bool DetailedPrinting = false)
32+
: op(op), IsRootNode(IsRootNode), DetailedPrinting(DetailedPrinting) {}
33+
};
34+
35+
class BoundOperationsGraphBuilder {
36+
public:
37+
BoundOperationNode *addNode(Operation *op, bool IsRootNode = false,
38+
bool DetailedPrinting = false) {
39+
auto It = Nodes.find(op);
40+
if (It != Nodes.end()) {
41+
return It->second.get();
42+
}
43+
auto Node =
44+
std::make_unique<BoundOperationNode>(op, IsRootNode, DetailedPrinting);
45+
BoundOperationNode *NodePtr = Node.get();
46+
Nodes[op] = std::move(Node);
47+
return NodePtr;
48+
}
49+
50+
void addEdge(Operation *parentOp, Operation *childOp) {
51+
BoundOperationNode *ParentNode = addNode(parentOp, false, false);
52+
BoundOperationNode *ChildNode = addNode(childOp, false, false);
53+
54+
ParentNode->Children.push_back(ChildNode);
55+
ChildNode->Parents.push_back(ParentNode);
56+
}
57+
58+
BoundOperationNode *getNode(Operation *op) const {
59+
auto It = Nodes.find(op);
60+
return It != Nodes.end() ? It->second.get() : nullptr;
61+
}
62+
63+
const llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> &
64+
getNodes() const {
65+
return Nodes;
66+
}
67+
68+
private:
69+
llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> Nodes;
70+
};
71+
72+
// Type traIt to detect if a matcher has a match(Operation*) method
73+
template <typename T, typename = void>
74+
struct has_simple_match : std::false_type {};
75+
76+
template <typename T>
77+
struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
78+
std::declval<Operation *>()))>>
79+
: std::true_type {};
80+
81+
// Type traIt to detect if a matcher has a match(Operation*,
82+
// BoundOperationsGraphBuilder&) method
83+
template <typename T, typename = void>
84+
struct has_bound_match : std::false_type {};
85+
86+
template <typename T>
87+
struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
88+
std::declval<Operation *>(),
89+
std::declval<BoundOperationsGraphBuilder &>()))>>
90+
: std::true_type {};
91+
2892
// Generic interface for matchers on an MLIR operation.
2993
class MatcherInterface
3094
: public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
3195
public:
3296
virtual ~MatcherInterface() = default;
33-
3497
virtual bool match(Operation *op) = 0;
98+
virtual bool match(Operation *op, BoundOperationsGraphBuilder &bound) = 0;
3599
};
36100

37101
// MatcherFnImpl takes a matcher function object and implements
@@ -40,40 +104,56 @@ template <typename MatcherFn>
40104
class MatcherFnImpl : public MatcherInterface {
41105
public:
42106
MatcherFnImpl(MatcherFn &matcherFn) : matcherFn(matcherFn) {}
43-
bool match(Operation *op) override { return matcherFn.match(op); }
107+
108+
bool match(Operation *op) override {
109+
if constexpr (has_simple_match<MatcherFn>::value)
110+
return matcherFn.match(op);
111+
return false;
112+
}
113+
114+
bool match(Operation *op, BoundOperationsGraphBuilder &bound) override {
115+
if constexpr (has_bound_match<MatcherFn>::value)
116+
return matcherFn.match(op, bound);
117+
return false;
118+
}
44119

45120
private:
46121
MatcherFn matcherFn;
47122
};
48123

49-
// Matcher wraps a MatcherInterface implementation and provides a match()
50-
// method that redirects calls to the underlying implementation.
124+
// Matcher wraps a MatcherInterface implementation and provides match()
125+
// methods that redirect calls to the underlying implementation.
51126
class DynMatcher {
52127
public:
53128
// Takes ownership of the provided implementation pointer.
54-
DynMatcher(MatcherInterface *implementation)
55-
: implementation(implementation) {}
129+
DynMatcher(MatcherInterface *implementation, StringRef matcherName)
130+
: implementation(implementation), matcherName(matcherName.str()) {}
56131

57132
template <typename MatcherFn>
58133
static std::unique_ptr<DynMatcher>
59-
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn) {
134+
constructDynMatcherFromMatcherFn(MatcherFn &matcherFn,
135+
StringRef matcherName) {
60136
auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
61-
return std::make_unique<DynMatcher>(impl.release());
137+
return std::make_unique<DynMatcher>(impl.release(), matcherName);
62138
}
63139

64140
bool match(Operation *op) const { return implementation->match(op); }
141+
bool match(Operation *op, BoundOperationsGraphBuilder &bound) const {
142+
return implementation->match(op, bound);
143+
}
65144

66-
void setFunctionName(StringRef name) { functionName = name.str(); };
67-
68-
bool hasFunctionName() const { return !functionName.empty(); };
69-
70-
StringRef getFunctionName() const { return functionName; };
145+
void setFunctionName(StringRef name) { functionName = name.str(); }
146+
void setMatcherName(StringRef name) { matcherName = name.str(); }
147+
bool hasFunctionName() const { return !functionName.empty(); }
148+
StringRef getFunctionName() const { return functionName; }
149+
StringRef getMatcherName() const { return matcherName; }
71150

72151
private:
73152
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
153+
std::string matcherName;
74154
std::string functionName;
75155
};
76156

77157
} // namespace mlir::query::matcher
78158

79-
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
159+
#endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H

0 commit comments

Comments
 (0)