1
1
// ===- MatchersInternal.h - Structural query framework ----------*- C++ -*-===//
2
2
//
3
- // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3
+ // Part of the LLVM Project, under the Apache License v2.0 wIth LLVM Exceptions.
4
4
// See https://llvm.org/LICENSE.txt for license information.
5
- // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
- //
7
- // ===----------------------------------------------------------------------===//
8
- //
9
- // Implements the base layer of the matcher framework.
10
- //
11
- // Matchers are methods that return a Matcher which provides a method
12
- // match(Operation *op)
13
- //
14
- // The matcher functions are defined in include/mlir/IR/Matchers.h.
15
- // This file contains the wrapper classes needed to construct matchers for
16
- // mlir-query.
5
+ // SPDX-License-Identifier: Apache-2.0 WItH LLVM-exception
17
6
//
18
7
// ===----------------------------------------------------------------------===//
19
8
22
11
23
12
#include " mlir/IR/Matchers.h"
24
13
#include " llvm/ADT/IntrusiveRefCntPtr.h"
14
+ #include " llvm/ADT/MapVector.h"
15
+ #include < memory>
16
+ #include < stack>
17
+ #include < unordered_set>
18
+ #include < vector>
25
19
26
20
namespace mlir ::query::matcher {
27
21
22
+ struct BoundOperationNode {
23
+ Operation *op;
24
+ std::vector<BoundOperationNode *> Parents;
25
+ std::vector<BoundOperationNode *> Children;
26
+
27
+ bool IsRootNode;
28
+ bool DetailedPrinting;
29
+
30
+ BoundOperationNode (Operation *op, bool IsRootNode = false ,
31
+ bool DetailedPrinting = false )
32
+ : op(op), IsRootNode(IsRootNode), DetailedPrinting(DetailedPrinting) {}
33
+ };
34
+
35
+ class BoundOperationsGraphBuilder {
36
+ public:
37
+ BoundOperationNode *addNode (Operation *op, bool IsRootNode = false ,
38
+ bool DetailedPrinting = false ) {
39
+ auto It = Nodes.find (op);
40
+ if (It != Nodes.end ()) {
41
+ return It->second .get ();
42
+ }
43
+ auto Node =
44
+ std::make_unique<BoundOperationNode>(op, IsRootNode, DetailedPrinting);
45
+ BoundOperationNode *NodePtr = Node.get ();
46
+ Nodes[op] = std::move (Node);
47
+ return NodePtr;
48
+ }
49
+
50
+ void addEdge (Operation *parentOp, Operation *childOp) {
51
+ BoundOperationNode *ParentNode = addNode (parentOp, false , false );
52
+ BoundOperationNode *ChildNode = addNode (childOp, false , false );
53
+
54
+ ParentNode->Children .push_back (ChildNode);
55
+ ChildNode->Parents .push_back (ParentNode);
56
+ }
57
+
58
+ BoundOperationNode *getNode (Operation *op) const {
59
+ auto It = Nodes.find (op);
60
+ return It != Nodes.end () ? It->second .get () : nullptr ;
61
+ }
62
+
63
+ const llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> &
64
+ getNodes () const {
65
+ return Nodes;
66
+ }
67
+
68
+ private:
69
+ llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> Nodes;
70
+ };
71
+
72
+ // Type traIt to detect if a matcher has a match(Operation*) method
73
+ template <typename T, typename = void >
74
+ struct has_simple_match : std::false_type {};
75
+
76
+ template <typename T>
77
+ struct has_simple_match <T, std::void_t <decltype (std::declval<T>().match(
78
+ std::declval<Operation *>()))>>
79
+ : std::true_type {};
80
+
81
+ // Type traIt to detect if a matcher has a match(Operation*,
82
+ // BoundOperationsGraphBuilder&) method
83
+ template <typename T, typename = void >
84
+ struct has_bound_match : std::false_type {};
85
+
86
+ template <typename T>
87
+ struct has_bound_match <T, std::void_t <decltype (std::declval<T>().match(
88
+ std::declval<Operation *>(),
89
+ std::declval<BoundOperationsGraphBuilder &>()))>>
90
+ : std::true_type {};
91
+
28
92
// Generic interface for matchers on an MLIR operation.
29
93
class MatcherInterface
30
94
: public llvm::ThreadSafeRefCountedBase<MatcherInterface> {
31
95
public:
32
96
virtual ~MatcherInterface () = default ;
33
-
34
97
virtual bool match (Operation *op) = 0;
98
+ virtual bool match (Operation *op, BoundOperationsGraphBuilder &bound) = 0;
35
99
};
36
100
37
101
// MatcherFnImpl takes a matcher function object and implements
@@ -40,40 +104,56 @@ template <typename MatcherFn>
40
104
class MatcherFnImpl : public MatcherInterface {
41
105
public:
42
106
MatcherFnImpl (MatcherFn &matcherFn) : matcherFn(matcherFn) {}
43
- bool match (Operation *op) override { return matcherFn.match (op); }
107
+
108
+ bool match (Operation *op) override {
109
+ if constexpr (has_simple_match<MatcherFn>::value)
110
+ return matcherFn.match (op);
111
+ return false ;
112
+ }
113
+
114
+ bool match (Operation *op, BoundOperationsGraphBuilder &bound) override {
115
+ if constexpr (has_bound_match<MatcherFn>::value)
116
+ return matcherFn.match (op, bound);
117
+ return false ;
118
+ }
44
119
45
120
private:
46
121
MatcherFn matcherFn;
47
122
};
48
123
49
- // Matcher wraps a MatcherInterface implementation and provides a match()
50
- // method that redirects calls to the underlying implementation.
124
+ // Matcher wraps a MatcherInterface implementation and provides match()
125
+ // methods that redirect calls to the underlying implementation.
51
126
class DynMatcher {
52
127
public:
53
128
// Takes ownership of the provided implementation pointer.
54
- DynMatcher (MatcherInterface *implementation)
55
- : implementation(implementation) {}
129
+ DynMatcher (MatcherInterface *implementation, StringRef matcherName )
130
+ : implementation(implementation), matcherName(matcherName.str()) {}
56
131
57
132
template <typename MatcherFn>
58
133
static std::unique_ptr<DynMatcher>
59
- constructDynMatcherFromMatcherFn (MatcherFn &matcherFn) {
134
+ constructDynMatcherFromMatcherFn (MatcherFn &matcherFn,
135
+ StringRef matcherName) {
60
136
auto impl = std::make_unique<MatcherFnImpl<MatcherFn>>(matcherFn);
61
- return std::make_unique<DynMatcher>(impl.release ());
137
+ return std::make_unique<DynMatcher>(impl.release (), matcherName );
62
138
}
63
139
64
140
bool match (Operation *op) const { return implementation->match (op); }
141
+ bool match (Operation *op, BoundOperationsGraphBuilder &bound) const {
142
+ return implementation->match (op, bound);
143
+ }
65
144
66
- void setFunctionName (StringRef name) { functionName = name.str (); };
67
-
68
- bool hasFunctionName () const { return !functionName.empty (); };
69
-
70
- StringRef getFunctionName () const { return functionName ; };
145
+ void setFunctionName (StringRef name) { functionName = name.str (); }
146
+ void setMatcherName (StringRef name) { matcherName = name. str (); }
147
+ bool hasFunctionName () const { return !functionName.empty (); }
148
+ StringRef getFunctionName () const { return functionName; }
149
+ StringRef getMatcherName () const { return matcherName ; }
71
150
72
151
private:
73
152
llvm::IntrusiveRefCntPtr<MatcherInterface> implementation;
153
+ std::string matcherName;
74
154
std::string functionName;
75
155
};
76
156
77
157
} // namespace mlir::query::matcher
78
158
79
- #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
159
+ #endif // MLIR_TOOLS_MLIRQUERY_MATCHER_MATCHERSINTERNAL_H
0 commit comments