@@ -234,41 +234,50 @@ class Pattern {
234
234
// RewritePattern
235
235
// ===----------------------------------------------------------------------===//
236
236
237
- // / RewritePattern is the common base class for all DAG to DAG replacements.
238
- // / There are two possible usages of this class:
239
- // / * Multi-step RewritePattern with "match" and "rewrite"
240
- // / - By overloading the "match" and "rewrite" functions, the user can
241
- // / separate the concerns of matching and rewriting.
242
- // / * Single-step RewritePattern with "matchAndRewrite"
243
- // / - By overloading the "matchAndRewrite" function, the user can perform
244
- // / the rewrite in the same call as the match.
245
- // /
246
- class RewritePattern : public Pattern {
247
- public:
248
- virtual ~RewritePattern () = default ;
237
+ namespace detail {
238
+ // / Helper class that derives from a RewritePattern class and provides separate
239
+ // / `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
240
+ template <typename PatternT>
241
+ class SplitMatchAndRewriteImpl : public PatternT {
242
+ using PatternT::PatternT;
249
243
250
244
// / Rewrite the IR rooted at the specified operation with the result of
251
245
// / this pattern, generating any new operations with the specified
252
- // / builder. If an unexpected error is encountered (an internal
253
- // / compiler error), it is emitted through the normal MLIR diagnostic
254
- // / hooks and the IR is left in a valid state.
255
- virtual void rewrite (Operation *op, PatternRewriter &rewriter) const ;
246
+ // / rewriter.
247
+ virtual void rewrite (typename PatternT::OperationT op,
248
+ PatternRewriter &rewriter) const = 0;
256
249
257
250
// / Attempt to match against code rooted at the specified operation,
258
251
// / which is the same operation code as getRootKind().
259
- virtual LogicalResult match (Operation * op) const ;
252
+ virtual LogicalResult match (typename PatternT::OperationT op) const = 0 ;
260
253
261
- // / Attempt to match against code rooted at the specified operation,
262
- // / which is the same operation code as getRootKind(). If successful, this
263
- // / function will automatically perform the rewrite.
264
- virtual LogicalResult matchAndRewrite (Operation *op,
265
- PatternRewriter &rewriter) const {
254
+ LogicalResult matchAndRewrite (typename PatternT::OperationT op,
255
+ PatternRewriter &rewriter) const final {
266
256
if (succeeded (match (op))) {
267
257
rewrite (op, rewriter);
268
258
return success ();
269
259
}
270
260
return failure ();
271
261
}
262
+ };
263
+ } // namespace detail
264
+
265
+ // / RewritePattern is the common base class for all DAG to DAG replacements.
266
+ // / By overloading the "matchAndRewrite" function, the user can perform the
267
+ // / rewrite in the same call as the match.
268
+ // /
269
+ class RewritePattern : public Pattern {
270
+ public:
271
+ using OperationT = Operation *;
272
+ using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
273
+
274
+ virtual ~RewritePattern () = default ;
275
+
276
+ // / Attempt to match against code rooted at the specified operation,
277
+ // / which is the same operation code as getRootKind(). If successful, this
278
+ // / function will automatically perform the rewrite.
279
+ virtual LogicalResult matchAndRewrite (Operation *op,
280
+ PatternRewriter &rewriter) const = 0;
272
281
273
282
// / This method provides a convenient interface for creating and initializing
274
283
// / derived rewrite patterns of the given type `T`.
@@ -317,36 +326,19 @@ namespace detail {
317
326
// / class or Interface.
318
327
template <typename SourceOp>
319
328
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
329
+ using OperationT = SourceOp;
320
330
using RewritePattern::RewritePattern;
321
331
322
- // / Wrappers around the RewritePattern methods that pass the derived op type.
323
- void rewrite (Operation *op, PatternRewriter &rewriter) const final {
324
- rewrite (cast<SourceOp>(op), rewriter);
325
- }
326
- LogicalResult match (Operation *op) const final {
327
- return match (cast<SourceOp>(op));
328
- }
332
+ // / Wrapper around the RewritePattern method that passes the derived op type.
329
333
LogicalResult matchAndRewrite (Operation *op,
330
334
PatternRewriter &rewriter) const final {
331
335
return matchAndRewrite (cast<SourceOp>(op), rewriter);
332
336
}
333
337
334
- // / Rewrite and Match methods that operate on the SourceOp type. These must be
335
- // / overridden by the derived pattern class.
336
- virtual void rewrite (SourceOp op, PatternRewriter &rewriter) const {
337
- llvm_unreachable (" must override rewrite or matchAndRewrite" );
338
- }
339
- virtual LogicalResult match (SourceOp op) const {
340
- llvm_unreachable (" must override match or matchAndRewrite" );
341
- }
338
+ // / Method that operates on the SourceOp type. Must be overridden by the
339
+ // / derived pattern class.
342
340
virtual LogicalResult matchAndRewrite (SourceOp op,
343
- PatternRewriter &rewriter) const {
344
- if (succeeded (match (op))) {
345
- rewrite (op, rewriter);
346
- return success ();
347
- }
348
- return failure ();
349
- }
341
+ PatternRewriter &rewriter) const = 0;
350
342
};
351
343
} // namespace detail
352
344
@@ -356,6 +348,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
356
348
template <typename SourceOp>
357
349
struct OpRewritePattern
358
350
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
351
+ using SplitMatchAndRewrite =
352
+ detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
353
+
359
354
// / Patterns must specify the root operation name they match against, and can
360
355
// / also specify the benefit of the pattern matching and a list of generated
361
356
// / ops.
@@ -371,6 +366,9 @@ struct OpRewritePattern
371
366
template <typename SourceOp>
372
367
struct OpInterfaceRewritePattern
373
368
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
369
+ using SplitMatchAndRewrite =
370
+ detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
371
+
374
372
OpInterfaceRewritePattern (MLIRContext *context, PatternBenefit benefit = 1 )
375
373
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
376
374
Pattern::MatchInterfaceOpTypeTag (), SourceOp::getInterfaceID(),
0 commit comments