15
15
16
16
#include " MatchFinder.h"
17
17
#include " MatchersInternal.h"
18
+ #include " mlir/IR/Region.h"
19
+ #include " mlir/Query/Query.h"
20
+ #include " llvm/Support/raw_ostream.h"
18
21
19
22
namespace mlir {
20
23
@@ -24,80 +27,161 @@ namespace extramatcher {
24
27
25
28
namespace detail {
26
29
27
- class DefinitionsMatcher {
30
+ class BackwardSliceMatcher {
28
31
public:
29
- DefinitionsMatcher (matcher::DynMatcher &&InnerMatcher , unsigned Hops )
30
- : InnerMatcher (std::move(InnerMatcher )), Hops(Hops ) {}
32
+ BackwardSliceMatcher (matcher::DynMatcher &&innerMatcher , unsigned hops )
33
+ : innerMatcher (std::move(innerMatcher )), hops(hops ) {}
31
34
32
35
private:
33
- bool matches (Operation *op, matcher::BoundOperationsGraphBuilder &Bound,
34
- unsigned TempHops) {
35
-
36
- llvm::DenseSet<mlir::Value> Ccache;
37
- llvm::SmallVector<std::pair<Operation *, size_t >, 4 > TempStorage;
38
- TempStorage.push_back ({op, TempHops});
39
- while (!TempStorage.empty ()) {
40
- auto [CurrentOp, RemainingHops] = TempStorage.pop_back_val ();
41
-
42
- matcher::BoundOperationNode *CurrentNode =
43
- Bound.addNode (CurrentOp, true , true );
44
- if (RemainingHops == 0 ) {
45
- continue ;
46
- }
36
+ bool matches (Operation *op, SetVector<Operation *> &backwardSlice,
37
+ QueryOptions &options, unsigned tempHops) {
47
38
48
- for (auto Operand : CurrentOp->getOperands ()) {
49
- if (auto DefiningOp = Operand.getDefiningOp ()) {
50
- Bound.addEdge (CurrentOp, DefiningOp);
51
- if (!Ccache.contains (Operand)) {
52
- Ccache.insert (Operand);
53
- TempStorage.emplace_back (DefiningOp, RemainingHops - 1 );
54
- }
55
- } else if (auto BlockArg = Operand.dyn_cast <BlockArgument>()) {
56
- auto *Block = BlockArg.getOwner ();
39
+ bool validSlice = true ;
40
+ if (op->hasTrait <OpTrait::IsIsolatedFromAbove>()) {
41
+ return false ;
42
+ }
57
43
58
- if (Block->isEntryBlock () &&
59
- isa<FunctionOpInterface>(Block->getParentOp ())) {
60
- continue ;
44
+ auto processValue = [&](Value value) {
45
+ if (tempHops == 0 ) {
46
+ return ;
47
+ }
48
+ if (auto *definingOp = value.getDefiningOp ()) {
49
+ if (backwardSlice.count (definingOp) == 0 )
50
+ matches (definingOp, backwardSlice, options, tempHops - 1 );
51
+ } else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
52
+ if (options.omitBlockArguments )
53
+ return ;
54
+ Block *block = blockArg.getOwner ();
55
+
56
+ Operation *parentOp = block->getParentOp ();
57
+
58
+ if (parentOp && backwardSlice.count (parentOp) == 0 ) {
59
+ if (parentOp->getNumRegions () == 1 &&
60
+ parentOp->getRegion (0 ).getBlocks ().size () == 1 ) {
61
+ validSlice = false ;
62
+ return ;
63
+ };
64
+ matches (parentOp, backwardSlice, options, tempHops - 1 );
65
+ }
66
+ } else {
67
+ validSlice = false ;
68
+ return ;
69
+ }
70
+ };
71
+
72
+ if (!options.omitUsesFromAbove ) {
73
+ llvm::for_each (op->getRegions (), [&](Region ®ion) {
74
+ SmallPtrSet<Region *, 4 > descendents;
75
+ region.walk (
76
+ [&](Region *childRegion) { descendents.insert (childRegion); });
77
+ region.walk ([&](Operation *op) {
78
+ for (OpOperand &operand : op->getOpOperands ()) {
79
+ if (!descendents.contains (operand.get ().getParentRegion ()))
80
+ processValue (operand.get ());
81
+ if (!validSlice)
82
+ return ;
61
83
}
84
+ });
85
+ });
86
+ }
62
87
63
- Operation *ParentOp = BlockArg.getOwner ()->getParentOp ();
64
- if (ParentOp) {
65
- Bound.addEdge (CurrentOp, ParentOp);
66
- if (!!Ccache.contains (BlockArg)) {
67
- Ccache.insert (BlockArg);
68
- TempStorage.emplace_back (ParentOp, RemainingHops - 1 );
69
- }
70
- }
71
- }
88
+ llvm::for_each (op->getOperands (), [&](Value operand) {
89
+ processValue (operand);
90
+ if (!validSlice)
91
+ return ;
92
+ });
93
+ backwardSlice.insert (op);
94
+ if (!validSlice) {
95
+ return false ;
96
+ }
97
+ return true ;
98
+ }
99
+
100
+ public:
101
+ bool match (Operation *op, SetVector<Operation *> &backwardSlice,
102
+ QueryOptions &options) {
103
+ if (innerMatcher.match (op) && matches (op, backwardSlice, options, hops)) {
104
+ if (!options.inclusive ) {
105
+ backwardSlice.remove (op);
72
106
}
107
+ return true ;
73
108
}
74
- // We need at least 1 defining op
75
- return Ccache.size () >= 2 ;
109
+ return false ;
76
110
}
77
111
112
+ private:
113
+ matcher::DynMatcher innerMatcher;
114
+ unsigned hops;
115
+ };
116
+
117
+ class ForwardSliceMatcher {
78
118
public:
79
- bool match (Operation *op, matcher::BoundOperationsGraphBuilder &Bound) {
80
- if (InnerMatcher.match (op) && matches (op, Bound, Hops)) {
119
+ ForwardSliceMatcher (matcher::DynMatcher &&innerMatcher, unsigned hops)
120
+ : innerMatcher(std::move(innerMatcher)), hops(hops) {}
121
+
122
+ private:
123
+ bool matches (Operation *op, SetVector<Operation *> &forwardSlice,
124
+ QueryOptions &options, unsigned tempHops) {
125
+
126
+ if (tempHops == 0 ) {
127
+ forwardSlice.insert (op);
128
+ return true ;
129
+ }
130
+
131
+ for (Region ®ion : op->getRegions ())
132
+ for (Block &block : region)
133
+ for (Operation &blockOp : block)
134
+ if (forwardSlice.count (&blockOp) == 0 )
135
+ matches (&blockOp, forwardSlice, options, tempHops - 1 );
136
+ for (Value result : op->getResults ()) {
137
+ for (Operation *userOp : result.getUsers ())
138
+ if (forwardSlice.count (userOp) == 0 )
139
+ matches (userOp, forwardSlice, options, tempHops - 1 );
140
+ }
141
+
142
+ forwardSlice.insert (op);
143
+ return true ;
144
+ }
145
+
146
+ public:
147
+ bool match (Operation *op, SetVector<Operation *> &forwardSlice,
148
+ QueryOptions &options) {
149
+ if (innerMatcher.match (op) && matches (op, forwardSlice, options, hops)) {
150
+ if (!options.inclusive ) {
151
+ forwardSlice.remove (op);
152
+ }
153
+ SmallVector<Operation *, 0 > v (forwardSlice.takeVector ());
154
+ forwardSlice.insert (v.rbegin (), v.rend ());
81
155
return true ;
82
156
}
83
157
return false ;
84
158
}
85
159
86
160
private:
87
- matcher::DynMatcher InnerMatcher ;
88
- unsigned Hops ;
161
+ matcher::DynMatcher innerMatcher ;
162
+ unsigned hops ;
89
163
};
164
+
90
165
} // namespace detail
91
166
92
- inline detail::DefinitionsMatcher
93
- definedBy (mlir::query::matcher::DynMatcher InnerMatcher) {
94
- return detail::DefinitionsMatcher (std::move (InnerMatcher), 1 );
167
+ inline detail::BackwardSliceMatcher
168
+ definedBy (mlir::query::matcher::DynMatcher innerMatcher) {
169
+ return detail::BackwardSliceMatcher (std::move (innerMatcher), 1 );
170
+ }
171
+
172
+ inline detail::BackwardSliceMatcher
173
+ getDefinitions (mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
174
+ return detail::BackwardSliceMatcher (std::move (innerMatcher), hops);
175
+ }
176
+
177
+ inline detail::ForwardSliceMatcher
178
+ usedBy (mlir::query::matcher::DynMatcher innerMatcher) {
179
+ return detail::ForwardSliceMatcher (std::move (innerMatcher), 1 );
95
180
}
96
181
97
- inline detail::DefinitionsMatcher
98
- getDefinitions (mlir::query::matcher::DynMatcher InnerMatcher, unsigned Hops) {
99
- assert (Hops > 0 && " hops must be >= 1" );
100
- return detail::DefinitionsMatcher (std::move (InnerMatcher), Hops);
182
+ inline detail::ForwardSliceMatcher
183
+ getUses (mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
184
+ return detail::ForwardSliceMatcher (std::move (innerMatcher), hops);
101
185
}
102
186
103
187
} // namespace extramatcher
0 commit comments