Skip to content

Commit 23d38a7

Browse files
committed
Added SetQuery, LetQuery, new implementation for matchers
1 parent cb380e7 commit 23d38a7

File tree

15 files changed

+367
-243
lines changed

15 files changed

+367
-243
lines changed

mlir/include/mlir/IR/Matchers.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,15 @@ struct NameOpMatcher {
5959
NameOpMatcher(StringRef name) : name(name) {}
6060
bool match(Operation *op) { return op->getName().getStringRef() == name; }
6161

62-
StringRef name;
62+
std::string name;
6363
};
6464

6565
/// The matcher that matches operations that have the specified attribute name.
6666
struct AttrOpMatcher {
6767
AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
6868
bool match(Operation *op) { return op->hasAttr(attrName); }
6969

70-
StringRef attrName;
70+
std::string attrName;
7171
};
7272

7373
/// The matcher that matches operations that have the `ConstantLike` trait, and

mlir/include/mlir/Query/Matcher/ExtraMatchers.h

Lines changed: 135 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616
#include "MatchFinder.h"
1717
#include "MatchersInternal.h"
18+
#include "mlir/IR/Region.h"
19+
#include "mlir/Query/Query.h"
20+
#include "llvm/Support/raw_ostream.h"
1821

1922
namespace mlir {
2023

@@ -24,80 +27,161 @@ namespace extramatcher {
2427

2528
namespace detail {
2629

27-
class DefinitionsMatcher {
30+
class BackwardSliceMatcher {
2831
public:
29-
DefinitionsMatcher(matcher::DynMatcher &&InnerMatcher, unsigned Hops)
30-
: InnerMatcher(std::move(InnerMatcher)), Hops(Hops) {}
32+
BackwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
33+
: innerMatcher(std::move(innerMatcher)), hops(hops) {}
3134

3235
private:
33-
bool matches(Operation *op, matcher::BoundOperationsGraphBuilder &Bound,
34-
unsigned TempHops) {
35-
36-
llvm::DenseSet<mlir::Value> Ccache;
37-
llvm::SmallVector<std::pair<Operation *, size_t>, 4> TempStorage;
38-
TempStorage.push_back({op, TempHops});
39-
while (!TempStorage.empty()) {
40-
auto [CurrentOp, RemainingHops] = TempStorage.pop_back_val();
41-
42-
matcher::BoundOperationNode *CurrentNode =
43-
Bound.addNode(CurrentOp, true, true);
44-
if (RemainingHops == 0) {
45-
continue;
46-
}
36+
bool matches(Operation *op, SetVector<Operation *> &backwardSlice,
37+
QueryOptions &options, unsigned tempHops) {
4738

48-
for (auto Operand : CurrentOp->getOperands()) {
49-
if (auto DefiningOp = Operand.getDefiningOp()) {
50-
Bound.addEdge(CurrentOp, DefiningOp);
51-
if (!Ccache.contains(Operand)) {
52-
Ccache.insert(Operand);
53-
TempStorage.emplace_back(DefiningOp, RemainingHops - 1);
54-
}
55-
} else if (auto BlockArg = Operand.dyn_cast<BlockArgument>()) {
56-
auto *Block = BlockArg.getOwner();
39+
bool validSlice = true;
40+
if (op->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
41+
return false;
42+
}
5743

58-
if (Block->isEntryBlock() &&
59-
isa<FunctionOpInterface>(Block->getParentOp())) {
60-
continue;
44+
auto processValue = [&](Value value) {
45+
if (tempHops == 0) {
46+
return;
47+
}
48+
if (auto *definingOp = value.getDefiningOp()) {
49+
if (backwardSlice.count(definingOp) == 0)
50+
matches(definingOp, backwardSlice, options, tempHops - 1);
51+
} else if (auto blockArg = dyn_cast<BlockArgument>(value)) {
52+
if (options.omitBlockArguments)
53+
return;
54+
Block *block = blockArg.getOwner();
55+
56+
Operation *parentOp = block->getParentOp();
57+
58+
if (parentOp && backwardSlice.count(parentOp) == 0) {
59+
if (parentOp->getNumRegions() == 1 &&
60+
parentOp->getRegion(0).getBlocks().size() == 1) {
61+
validSlice = false;
62+
return;
63+
};
64+
matches(parentOp, backwardSlice, options, tempHops - 1);
65+
}
66+
} else {
67+
validSlice = false;
68+
return;
69+
}
70+
};
71+
72+
if (!options.omitUsesFromAbove) {
73+
llvm::for_each(op->getRegions(), [&](Region &region) {
74+
SmallPtrSet<Region *, 4> descendents;
75+
region.walk(
76+
[&](Region *childRegion) { descendents.insert(childRegion); });
77+
region.walk([&](Operation *op) {
78+
for (OpOperand &operand : op->getOpOperands()) {
79+
if (!descendents.contains(operand.get().getParentRegion()))
80+
processValue(operand.get());
81+
if (!validSlice)
82+
return;
6183
}
84+
});
85+
});
86+
}
6287

63-
Operation *ParentOp = BlockArg.getOwner()->getParentOp();
64-
if (ParentOp) {
65-
Bound.addEdge(CurrentOp, ParentOp);
66-
if (!!Ccache.contains(BlockArg)) {
67-
Ccache.insert(BlockArg);
68-
TempStorage.emplace_back(ParentOp, RemainingHops - 1);
69-
}
70-
}
71-
}
88+
llvm::for_each(op->getOperands(), [&](Value operand) {
89+
processValue(operand);
90+
if (!validSlice)
91+
return;
92+
});
93+
backwardSlice.insert(op);
94+
if (!validSlice) {
95+
return false;
96+
}
97+
return true;
98+
}
99+
100+
public:
101+
bool match(Operation *op, SetVector<Operation *> &backwardSlice,
102+
QueryOptions &options) {
103+
if (innerMatcher.match(op) && matches(op, backwardSlice, options, hops)) {
104+
if (!options.inclusive) {
105+
backwardSlice.remove(op);
72106
}
107+
return true;
73108
}
74-
// We need at least 1 defining op
75-
return Ccache.size() >= 2;
109+
return false;
76110
}
77111

112+
private:
113+
matcher::DynMatcher innerMatcher;
114+
unsigned hops;
115+
};
116+
117+
class ForwardSliceMatcher {
78118
public:
79-
bool match(Operation *op, matcher::BoundOperationsGraphBuilder &Bound) {
80-
if (InnerMatcher.match(op) && matches(op, Bound, Hops)) {
119+
ForwardSliceMatcher(matcher::DynMatcher &&innerMatcher, unsigned hops)
120+
: innerMatcher(std::move(innerMatcher)), hops(hops) {}
121+
122+
private:
123+
bool matches(Operation *op, SetVector<Operation *> &forwardSlice,
124+
QueryOptions &options, unsigned tempHops) {
125+
126+
if (tempHops == 0) {
127+
forwardSlice.insert(op);
128+
return true;
129+
}
130+
131+
for (Region &region : op->getRegions())
132+
for (Block &block : region)
133+
for (Operation &blockOp : block)
134+
if (forwardSlice.count(&blockOp) == 0)
135+
matches(&blockOp, forwardSlice, options, tempHops - 1);
136+
for (Value result : op->getResults()) {
137+
for (Operation *userOp : result.getUsers())
138+
if (forwardSlice.count(userOp) == 0)
139+
matches(userOp, forwardSlice, options, tempHops - 1);
140+
}
141+
142+
forwardSlice.insert(op);
143+
return true;
144+
}
145+
146+
public:
147+
bool match(Operation *op, SetVector<Operation *> &forwardSlice,
148+
QueryOptions &options) {
149+
if (innerMatcher.match(op) && matches(op, forwardSlice, options, hops)) {
150+
if (!options.inclusive) {
151+
forwardSlice.remove(op);
152+
}
153+
SmallVector<Operation *, 0> v(forwardSlice.takeVector());
154+
forwardSlice.insert(v.rbegin(), v.rend());
81155
return true;
82156
}
83157
return false;
84158
}
85159

86160
private:
87-
matcher::DynMatcher InnerMatcher;
88-
unsigned Hops;
161+
matcher::DynMatcher innerMatcher;
162+
unsigned hops;
89163
};
164+
90165
} // namespace detail
91166

92-
inline detail::DefinitionsMatcher
93-
definedBy(mlir::query::matcher::DynMatcher InnerMatcher) {
94-
return detail::DefinitionsMatcher(std::move(InnerMatcher), 1);
167+
inline detail::BackwardSliceMatcher
168+
definedBy(mlir::query::matcher::DynMatcher innerMatcher) {
169+
return detail::BackwardSliceMatcher(std::move(innerMatcher), 1);
170+
}
171+
172+
inline detail::BackwardSliceMatcher
173+
getDefinitions(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
174+
return detail::BackwardSliceMatcher(std::move(innerMatcher), hops);
175+
}
176+
177+
inline detail::ForwardSliceMatcher
178+
usedBy(mlir::query::matcher::DynMatcher innerMatcher) {
179+
return detail::ForwardSliceMatcher(std::move(innerMatcher), 1);
95180
}
96181

97-
inline detail::DefinitionsMatcher
98-
getDefinitions(mlir::query::matcher::DynMatcher InnerMatcher, unsigned Hops) {
99-
assert(Hops > 0 && "hops must be >= 1");
100-
return detail::DefinitionsMatcher(std::move(InnerMatcher), Hops);
182+
inline detail::ForwardSliceMatcher
183+
getUses(mlir::query::matcher::DynMatcher innerMatcher, unsigned hops) {
184+
return detail::ForwardSliceMatcher(std::move(innerMatcher), hops);
101185
}
102186

103187
} // namespace extramatcher

mlir/include/mlir/Query/Matcher/MatchFinder.h

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,18 @@ namespace mlir::query::matcher {
2323
class MatchFinder {
2424
public:
2525
// Returns all operations that match the given matcher.
26-
static BoundOperationsGraphBuilder getMatches(Operation *root,
27-
DynMatcher matcher) {
28-
29-
BoundOperationsGraphBuilder Bound;
26+
static SetVector<Operation *>
27+
getMatches(Operation *root, QueryOptions &options, DynMatcher matcher) {
28+
SetVector<Operation *> backwardSlice;
3029
root->walk([&](Operation *subOp) {
3130
if (matcher.match(subOp)) {
32-
matcher::BoundOperationNode *currentNode = Bound.addNode(subOp);
33-
} else if (matcher.match(subOp, Bound)) {
31+
backwardSlice.insert(subOp);
32+
} else {
33+
matcher.match(subOp, backwardSlice, options);
3434
////
3535
}
3636
});
37-
return Bound;
37+
return backwardSlice;
3838
}
3939
};
4040

mlir/include/mlir/Query/Matcher/MatchersInternal.h

Lines changed: 16 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -17,59 +17,13 @@
1717
#include <unordered_set>
1818
#include <vector>
1919

20-
namespace mlir::query::matcher {
21-
22-
struct BoundOperationNode {
23-
Operation *op;
24-
std::vector<BoundOperationNode *> Parents;
25-
std::vector<BoundOperationNode *> Children;
26-
27-
bool IsRootNode;
28-
bool DetailedPrinting;
29-
30-
BoundOperationNode(Operation *op, bool IsRootNode = false,
31-
bool DetailedPrinting = false)
32-
: op(op), IsRootNode(IsRootNode), DetailedPrinting(DetailedPrinting) {}
33-
};
20+
namespace mlir {
21+
namespace query {
22+
struct QueryOptions;
23+
}
24+
} // namespace mlir
3425

35-
class BoundOperationsGraphBuilder {
36-
public:
37-
BoundOperationNode *addNode(Operation *op, bool IsRootNode = false,
38-
bool DetailedPrinting = false) {
39-
auto It = Nodes.find(op);
40-
if (It != Nodes.end()) {
41-
return It->second.get();
42-
}
43-
auto Node =
44-
std::make_unique<BoundOperationNode>(op, IsRootNode, DetailedPrinting);
45-
BoundOperationNode *NodePtr = Node.get();
46-
Nodes[op] = std::move(Node);
47-
return NodePtr;
48-
}
49-
50-
void addEdge(Operation *parentOp, Operation *childOp) {
51-
BoundOperationNode *ParentNode = addNode(parentOp, false, false);
52-
BoundOperationNode *ChildNode = addNode(childOp, false, false);
53-
54-
ParentNode->Children.push_back(ChildNode);
55-
ChildNode->Parents.push_back(ParentNode);
56-
}
57-
58-
BoundOperationNode *getNode(Operation *op) const {
59-
auto It = Nodes.find(op);
60-
return It != Nodes.end() ? It->second.get() : nullptr;
61-
}
62-
63-
const llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> &
64-
getNodes() const {
65-
return Nodes;
66-
}
67-
68-
private:
69-
llvm::MapVector<Operation *, std::unique_ptr<BoundOperationNode>> Nodes;
70-
};
71-
72-
// Type traIt to detect if a matcher has a match(Operation*) method
26+
namespace mlir::query::matcher {
7327
template <typename T, typename = void>
7428
struct has_simple_match : std::false_type {};
7529

@@ -78,15 +32,14 @@ struct has_simple_match<T, std::void_t<decltype(std::declval<T>().match(
7832
std::declval<Operation *>()))>>
7933
: std::true_type {};
8034

81-
// Type traIt to detect if a matcher has a match(Operation*,
82-
// BoundOperationsGraphBuilder&) method
8335
template <typename T, typename = void>
8436
struct has_bound_match : std::false_type {};
8537

8638
template <typename T>
8739
struct has_bound_match<T, std::void_t<decltype(std::declval<T>().match(
8840
std::declval<Operation *>(),
89-
std::declval<BoundOperationsGraphBuilder &>()))>>
41+
std::declval<SetVector<Operation *> &>(),
42+
std::declval<QueryOptions &>()))>>
9043
: std::true_type {};
9144

9245
// Generic interface for matchers on an MLIR operation.
@@ -95,7 +48,8 @@ class MatcherInterface
9548
public:
9649
virtual ~MatcherInterface() = default;
9750
virtual bool match(Operation *op) = 0;
98-
virtual bool match(Operation *op, BoundOperationsGraphBuilder &bound) = 0;
51+
virtual bool match(Operation *op, SetVector<Operation *> &matchedOps,
52+
QueryOptions &options) = 0;
9953
};
10054

10155
// MatcherFnImpl takes a matcher function object and implements
@@ -111,9 +65,10 @@ class MatcherFnImpl : public MatcherInterface {
11165
return false;
11266
}
11367

114-
bool match(Operation *op, BoundOperationsGraphBuilder &bound) override {
68+
bool match(Operation *op, SetVector<Operation *> &matchedOps,
69+
QueryOptions &options) override {
11570
if constexpr (has_bound_match<MatcherFn>::value)
116-
return matcherFn.match(op, bound);
71+
return matcherFn.match(op, matchedOps, options);
11772
return false;
11873
}
11974

@@ -138,8 +93,9 @@ class DynMatcher {
13893
}
13994

14095
bool match(Operation *op) const { return implementation->match(op); }
141-
bool match(Operation *op, BoundOperationsGraphBuilder &bound) const {
142-
return implementation->match(op, bound);
96+
bool match(Operation *op, SetVector<Operation *> &matchedOps,
97+
QueryOptions &options) const {
98+
return implementation->match(op, matchedOps, options);
14399
}
144100

145101
void setFunctionName(StringRef name) { functionName = name.str(); }

mlir/include/mlir/Query/Matcher/VariantValue.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ class VariantValue {
100100

101101
// String representation of the type of the value.
102102
std::string getTypeAsString() const;
103+
explicit operator bool() const { return hasValue(); }
104+
bool hasValue() const { return type != ValueType::Nothing; }
103105

104106
private:
105107
void reset();

0 commit comments

Comments
 (0)