@@ -52,6 +52,22 @@ struct constant_op_matcher {
52
52
bool match (Operation *op) { return op->hasTrait <OpTrait::ConstantLike>(); }
53
53
};
54
54
55
+ // / The matcher that matches operations that have the specified op name.
56
+ struct NameOpMatcher {
57
+ NameOpMatcher (StringRef name) : name(name) {}
58
+ bool match (Operation *op) { return op->getName ().getStringRef () == name; }
59
+
60
+ StringRef name;
61
+ };
62
+
63
+ // / The matcher that matches operations that have the specified attribute name.
64
+ struct AttrOpMatcher {
65
+ AttrOpMatcher (StringRef attrName) : attrName(attrName) {}
66
+ bool match (Operation *op) { return op->hasAttr (attrName); }
67
+
68
+ StringRef attrName;
69
+ };
70
+
55
71
// / The matcher that matches operations that have the `ConstantLike` trait, and
56
72
// / binds the folded attribute value.
57
73
template <typename AttrT>
@@ -83,6 +99,29 @@ struct constant_op_binder {
83
99
}
84
100
};
85
101
102
+ // / The matcher that matches operations that have the specified attribute
103
+ // / name, and binds the attribute value.
104
+ template <typename AttrT>
105
+ struct AttrOpBinder {
106
+ // / Creates a matcher instance that binds the attribute value to
107
+ // / bind_value if match succeeds.
108
+ AttrOpBinder (StringRef attrName, AttrT *bindValue)
109
+ : attrName(attrName), bindValue(bindValue) {}
110
+ // / Creates a matcher instance that doesn't bind if match succeeds.
111
+ AttrOpBinder (StringRef attrName) : attrName(attrName), bindValue(nullptr ) {}
112
+
113
+ bool match (Operation *op) {
114
+ if (auto attr = op->getAttrOfType <AttrT>(attrName)) {
115
+ if (bindValue)
116
+ *bindValue = attr;
117
+ return true ;
118
+ }
119
+ return false ;
120
+ }
121
+ StringRef attrName;
122
+ AttrT *bindValue;
123
+ };
124
+
86
125
// / The matcher that matches a constant scalar / vector splat / tensor splat
87
126
// / float operation and binds the constant float value.
88
127
struct constant_float_op_binder {
@@ -249,13 +288,30 @@ inline detail::constant_op_matcher m_Constant() {
249
288
return detail::constant_op_matcher ();
250
289
}
251
290
291
+ // / Matches a named attribute operation.
292
+ inline detail::AttrOpMatcher m_Attr (StringRef attrName) {
293
+ return detail::AttrOpMatcher (attrName);
294
+ }
295
+
296
+ // / Matches a named operation.
297
+ inline detail::NameOpMatcher m_Op (StringRef opName) {
298
+ return detail::NameOpMatcher (opName);
299
+ }
300
+
252
301
// / Matches a value from a constant foldable operation and writes the value to
253
302
// / bind_value.
254
303
template <typename AttrT>
255
304
inline detail::constant_op_binder<AttrT> m_Constant (AttrT *bind_value) {
256
305
return detail::constant_op_binder<AttrT>(bind_value);
257
306
}
258
307
308
+ // / Matches a named attribute operation and writes the value to bind_value.
309
+ template <typename AttrT>
310
+ inline detail::AttrOpBinder<AttrT> m_Attr (StringRef attrName,
311
+ AttrT *bindValue) {
312
+ return detail::AttrOpBinder<AttrT>(attrName, bindValue);
313
+ }
314
+
259
315
// / Matches a constant scalar / vector splat / tensor splat float (both positive
260
316
// / and negative) zero.
261
317
inline detail::constant_float_predicate_matcher m_AnyZeroFloat () {
0 commit comments