Skip to content

Commit 18105b5

Browse files
add new class
1 parent f4878cb commit 18105b5

File tree

10 files changed

+147
-179
lines changed

10 files changed

+147
-179
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 8 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ LogicalResult oneToOneRewrite(
4040
/// during the entire pattern lifetime.
4141
class ConvertToLLVMPattern : public ConversionPattern {
4242
public:
43+
using SplitMatchAndRewrite =
44+
detail::ConversionSplitMatchAndRewriteImpl<ConvertToLLVMPattern>;
45+
4346
ConvertToLLVMPattern(StringRef rootOpName, MLIRContext *context,
4447
const LLVMTypeConverter &typeConverter,
4548
PatternBenefit benefit = 1);
@@ -142,9 +145,12 @@ class ConvertToLLVMPattern : public ConversionPattern {
142145
template <typename SourceOp>
143146
class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
144147
public:
148+
using OperationT = SourceOp;
145149
using OpAdaptor = typename SourceOp::Adaptor;
146150
using OneToNOpAdaptor =
147151
typename SourceOp::template GenericAdaptor<ArrayRef<ValueRange>>;
152+
using SplitMatchAndRewrite = detail::ConversionSplitMatchAndRewriteImpl<
153+
ConvertOpToLLVMPattern<SourceOp>>;
148154

149155
explicit ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter,
150156
PatternBenefit benefit = 1)
@@ -153,19 +159,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
153159
benefit) {}
154160

155161
/// Wrappers around the RewritePattern methods that pass the derived op type.
156-
void rewrite(Operation *op, ArrayRef<Value> operands,
157-
ConversionPatternRewriter &rewriter) const final {
158-
auto sourceOp = cast<SourceOp>(op);
159-
rewrite(sourceOp, OpAdaptor(operands, sourceOp), rewriter);
160-
}
161-
void rewrite(Operation *op, ArrayRef<ValueRange> operands,
162-
ConversionPatternRewriter &rewriter) const final {
163-
auto sourceOp = cast<SourceOp>(op);
164-
rewrite(sourceOp, OneToNOpAdaptor(operands, sourceOp), rewriter);
165-
}
166-
LogicalResult match(Operation *op) const final {
167-
return match(cast<SourceOp>(op));
168-
}
169162
LogicalResult
170163
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
171164
ConversionPatternRewriter &rewriter) const final {
@@ -180,28 +173,12 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
180173
rewriter);
181174
}
182175

183-
/// Rewrite and Match methods that operate on the SourceOp type. These must be
176+
/// Methods that operate on the SourceOp type. One of these must be
184177
/// overridden by the derived pattern class.
185-
virtual LogicalResult match(SourceOp op) const {
186-
llvm_unreachable("must override match or matchAndRewrite");
187-
}
188-
virtual void rewrite(SourceOp op, OpAdaptor adaptor,
189-
ConversionPatternRewriter &rewriter) const {
190-
llvm_unreachable("must override rewrite or matchAndRewrite");
191-
}
192-
virtual void rewrite(SourceOp op, OneToNOpAdaptor adaptor,
193-
ConversionPatternRewriter &rewriter) const {
194-
SmallVector<Value> oneToOneOperands =
195-
getOneToOneAdaptorOperands(adaptor.getOperands());
196-
rewrite(op, OpAdaptor(oneToOneOperands, adaptor), rewriter);
197-
}
198178
virtual LogicalResult
199179
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
200180
ConversionPatternRewriter &rewriter) const {
201-
if (failed(match(op)))
202-
return failure();
203-
rewrite(op, adaptor, rewriter);
204-
return success();
181+
llvm_unreachable("matchAndRewrite is not implemented");
205182
}
206183
virtual LogicalResult
207184
matchAndRewrite(SourceOp op, OneToNOpAdaptor adaptor,
@@ -212,7 +189,6 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
212189
}
213190

214191
private:
215-
using ConvertToLLVMPattern::match;
216192
using ConvertToLLVMPattern::matchAndRewrite;
217193
};
218194

mlir/include/mlir/IR/PatternMatch.h

Lines changed: 42 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -234,41 +234,50 @@ class Pattern {
234234
// RewritePattern
235235
//===----------------------------------------------------------------------===//
236236

237-
/// RewritePattern is the common base class for all DAG to DAG replacements.
238-
/// There are two possible usages of this class:
239-
/// * Multi-step RewritePattern with "match" and "rewrite"
240-
/// - By overloading the "match" and "rewrite" functions, the user can
241-
/// separate the concerns of matching and rewriting.
242-
/// * Single-step RewritePattern with "matchAndRewrite"
243-
/// - By overloading the "matchAndRewrite" function, the user can perform
244-
/// the rewrite in the same call as the match.
245-
///
246-
class RewritePattern : public Pattern {
247-
public:
248-
virtual ~RewritePattern() = default;
237+
namespace detail {
238+
/// Helper class that derives from a RewritePattern class and provides separate
239+
/// `match` and `rewrite` entry points instead of a combined `matchAndRewrite`.
240+
template <typename PatternT>
241+
class SplitMatchAndRewriteImpl : public PatternT {
242+
using PatternT::PatternT;
249243

250244
/// Rewrite the IR rooted at the specified operation with the result of
251245
/// this pattern, generating any new operations with the specified
252-
/// builder. If an unexpected error is encountered (an internal
253-
/// compiler error), it is emitted through the normal MLIR diagnostic
254-
/// hooks and the IR is left in a valid state.
255-
virtual void rewrite(Operation *op, PatternRewriter &rewriter) const;
246+
/// rewriter.
247+
virtual void rewrite(typename PatternT::OperationT op,
248+
PatternRewriter &rewriter) const = 0;
256249

257250
/// Attempt to match against code rooted at the specified operation,
258251
/// which is the same operation code as getRootKind().
259-
virtual LogicalResult match(Operation *op) const;
252+
virtual LogicalResult match(typename PatternT::OperationT op) const = 0;
260253

261-
/// Attempt to match against code rooted at the specified operation,
262-
/// which is the same operation code as getRootKind(). If successful, this
263-
/// function will automatically perform the rewrite.
264-
virtual LogicalResult matchAndRewrite(Operation *op,
265-
PatternRewriter &rewriter) const {
254+
LogicalResult matchAndRewrite(typename PatternT::OperationT op,
255+
PatternRewriter &rewriter) const final {
266256
if (succeeded(match(op))) {
267257
rewrite(op, rewriter);
268258
return success();
269259
}
270260
return failure();
271261
}
262+
};
263+
} // namespace detail
264+
265+
/// RewritePattern is the common base class for all DAG to DAG replacements.
266+
/// By overloading the "matchAndRewrite" function, the user can perform the
267+
/// rewrite in the same call as the match.
268+
///
269+
class RewritePattern : public Pattern {
270+
public:
271+
using OperationT = Operation *;
272+
using SplitMatchAndRewrite = detail::SplitMatchAndRewriteImpl<RewritePattern>;
273+
274+
virtual ~RewritePattern() = default;
275+
276+
/// Attempt to match against code rooted at the specified operation,
277+
/// which is the same operation code as getRootKind(). If successful, this
278+
/// function will automatically perform the rewrite.
279+
virtual LogicalResult matchAndRewrite(Operation *op,
280+
PatternRewriter &rewriter) const = 0;
272281

273282
/// This method provides a convenient interface for creating and initializing
274283
/// derived rewrite patterns of the given type `T`.
@@ -317,36 +326,19 @@ namespace detail {
317326
/// class or Interface.
318327
template <typename SourceOp>
319328
struct OpOrInterfaceRewritePatternBase : public RewritePattern {
329+
using OperationT = SourceOp;
320330
using RewritePattern::RewritePattern;
321331

322-
/// Wrappers around the RewritePattern methods that pass the derived op type.
323-
void rewrite(Operation *op, PatternRewriter &rewriter) const final {
324-
rewrite(cast<SourceOp>(op), rewriter);
325-
}
326-
LogicalResult match(Operation *op) const final {
327-
return match(cast<SourceOp>(op));
328-
}
332+
/// Wrapper around the RewritePattern method that passes the derived op type.
329333
LogicalResult matchAndRewrite(Operation *op,
330334
PatternRewriter &rewriter) const final {
331335
return matchAndRewrite(cast<SourceOp>(op), rewriter);
332336
}
333337

334-
/// Rewrite and Match methods that operate on the SourceOp type. These must be
335-
/// overridden by the derived pattern class.
336-
virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
337-
llvm_unreachable("must override rewrite or matchAndRewrite");
338-
}
339-
virtual LogicalResult match(SourceOp op) const {
340-
llvm_unreachable("must override match or matchAndRewrite");
341-
}
338+
/// Method that operates on the SourceOp type. Must be overridden by the
339+
/// derived pattern class.
342340
virtual LogicalResult matchAndRewrite(SourceOp op,
343-
PatternRewriter &rewriter) const {
344-
if (succeeded(match(op))) {
345-
rewrite(op, rewriter);
346-
return success();
347-
}
348-
return failure();
349-
}
341+
PatternRewriter &rewriter) const = 0;
350342
};
351343
} // namespace detail
352344

@@ -356,6 +348,9 @@ struct OpOrInterfaceRewritePatternBase : public RewritePattern {
356348
template <typename SourceOp>
357349
struct OpRewritePattern
358350
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
351+
using SplitMatchAndRewrite =
352+
detail::SplitMatchAndRewriteImpl<OpRewritePattern<SourceOp>>;
353+
359354
/// Patterns must specify the root operation name they match against, and can
360355
/// also specify the benefit of the pattern matching and a list of generated
361356
/// ops.
@@ -371,6 +366,9 @@ struct OpRewritePattern
371366
template <typename SourceOp>
372367
struct OpInterfaceRewritePattern
373368
: public detail::OpOrInterfaceRewritePatternBase<SourceOp> {
369+
using SplitMatchAndRewrite =
370+
detail::SplitMatchAndRewriteImpl<OpInterfaceRewritePattern<SourceOp>>;
371+
374372
OpInterfaceRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
375373
: detail::OpOrInterfaceRewritePatternBase<SourceOp>(
376374
Pattern::MatchInterfaceOpTypeTag(), SourceOp::getInterfaceID(),

0 commit comments

Comments
 (0)