Skip to content

Commit 8b9f8db

Browse files
devajithvsjpienaar
authored andcommitted
[mlir][matchers] Add m_Op(StringRef) and m_Attr matchers
This patch introduces support for m_Op with a StringRef argument and m_Attr matchers. These matchers will be very useful for mlir-query that is being developed currently. Submitting this patch separately to reduce the final patch size and make it easier to upstream mlir-query. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D147262
1 parent 718729e commit 8b9f8db

File tree

3 files changed

+84
-0
lines changed

3 files changed

+84
-0
lines changed

mlir/include/mlir/IR/Matchers.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,22 @@ struct constant_op_matcher {
5252
bool match(Operation *op) { return op->hasTrait<OpTrait::ConstantLike>(); }
5353
};
5454

55+
/// The matcher that matches operations that have the specified op name.
56+
struct NameOpMatcher {
57+
NameOpMatcher(StringRef name) : name(name) {}
58+
bool match(Operation *op) { return op->getName().getStringRef() == name; }
59+
60+
StringRef name;
61+
};
62+
63+
/// The matcher that matches operations that have the specified attribute name.
64+
struct AttrOpMatcher {
65+
AttrOpMatcher(StringRef attrName) : attrName(attrName) {}
66+
bool match(Operation *op) { return op->hasAttr(attrName); }
67+
68+
StringRef attrName;
69+
};
70+
5571
/// The matcher that matches operations that have the `ConstantLike` trait, and
5672
/// binds the folded attribute value.
5773
template <typename AttrT>
@@ -83,6 +99,29 @@ struct constant_op_binder {
8399
}
84100
};
85101

102+
/// The matcher that matches operations that have the specified attribute
103+
/// name, and binds the attribute value.
104+
template <typename AttrT>
105+
struct AttrOpBinder {
106+
/// Creates a matcher instance that binds the attribute value to
107+
/// bind_value if match succeeds.
108+
AttrOpBinder(StringRef attrName, AttrT *bindValue)
109+
: attrName(attrName), bindValue(bindValue) {}
110+
/// Creates a matcher instance that doesn't bind if match succeeds.
111+
AttrOpBinder(StringRef attrName) : attrName(attrName), bindValue(nullptr) {}
112+
113+
bool match(Operation *op) {
114+
if (auto attr = op->getAttrOfType<AttrT>(attrName)) {
115+
if (bindValue)
116+
*bindValue = attr;
117+
return true;
118+
}
119+
return false;
120+
}
121+
StringRef attrName;
122+
AttrT *bindValue;
123+
};
124+
86125
/// The matcher that matches a constant scalar / vector splat / tensor splat
87126
/// float operation and binds the constant float value.
88127
struct constant_float_op_binder {
@@ -249,13 +288,30 @@ inline detail::constant_op_matcher m_Constant() {
249288
return detail::constant_op_matcher();
250289
}
251290

291+
/// Matches a named attribute operation.
292+
inline detail::AttrOpMatcher m_Attr(StringRef attrName) {
293+
return detail::AttrOpMatcher(attrName);
294+
}
295+
296+
/// Matches a named operation.
297+
inline detail::NameOpMatcher m_Op(StringRef opName) {
298+
return detail::NameOpMatcher(opName);
299+
}
300+
252301
/// Matches a value from a constant foldable operation and writes the value to
253302
/// bind_value.
254303
template <typename AttrT>
255304
inline detail::constant_op_binder<AttrT> m_Constant(AttrT *bind_value) {
256305
return detail::constant_op_binder<AttrT>(bind_value);
257306
}
258307

308+
/// Matches a named attribute operation and writes the value to bind_value.
309+
template <typename AttrT>
310+
inline detail::AttrOpBinder<AttrT> m_Attr(StringRef attrName,
311+
AttrT *bindValue) {
312+
return detail::AttrOpBinder<AttrT>(attrName, bindValue);
313+
}
314+
259315
/// Matches a constant scalar / vector splat / tensor splat float (both positive
260316
/// and negative) zero.
261317
inline detail::constant_float_predicate_matcher m_AnyZeroFloat() {

mlir/test/IR/test-matchers.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,14 @@ func.func @test2(%a: f32) -> f32 {
4141
// CHECK-LABEL: test2
4242
// CHECK: Pattern add(add(a, constant), a) matched and bound constant to: 1.000000e+00
4343
// CHECK: Pattern add(add(a, constant), a) matched
44+
45+
func.func @test3(%a: f32) -> f32 {
46+
%0 = "test.name"() {value = 1.0 : f32} : () -> f32
47+
%1 = arith.addf %a, %0: f32
48+
%2 = arith.mulf %a, %1 fastmath<fast>: f32
49+
return %2: f32
50+
}
51+
52+
// CHECK-LABEL: test3
53+
// CHECK: Pattern mul(*, add(*, m_Op("test.name"))) matched
54+
// CHECK: Pattern m_Attr("fastmath") matched and bound value to: fast

mlir/test/lib/IR/TestMatchers.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,13 +148,30 @@ void test2(FunctionOpInterface f) {
148148
llvm::outs() << "Pattern add(add(a, constant), a) matched\n";
149149
}
150150

151+
void test3(FunctionOpInterface f) {
152+
arith::FastMathFlagsAttr fastMathAttr;
153+
auto p = m_Op<arith::MulFOp>(m_Any(),
154+
m_Op<arith::AddFOp>(m_Any(), m_Op("test.name")));
155+
auto p1 = m_Attr("fastmath", &fastMathAttr);
156+
157+
// Last operation that is not the terminator.
158+
Operation *lastOp = f.getFunctionBody().front().back().getPrevNode();
159+
if (p.match(lastOp))
160+
llvm::outs() << "Pattern mul(*, add(*, m_Op(\"test.name\"))) matched\n";
161+
if (p1.match(lastOp))
162+
llvm::outs() << "Pattern m_Attr(\"fastmath\") matched and bound value to: "
163+
<< fastMathAttr.getValue() << "\n";
164+
}
165+
151166
void TestMatchers::runOnOperation() {
152167
auto f = getOperation();
153168
llvm::outs() << f.getName() << "\n";
154169
if (f.getName() == "test1")
155170
test1(f);
156171
if (f.getName() == "test2")
157172
test2(f);
173+
if (f.getName() == "test3")
174+
test3(f);
158175
}
159176

160177
namespace mlir {

0 commit comments

Comments
 (0)