Skip to content

Commit f64b3bb

Browse files
authored
[mlir][llvm] Op interface LLVM converter (llvm#143922)
Adds a utility conversion class for rewriting op interface instances targeting LLVM dialect.
1 parent 4e80a03 commit f64b3bb

File tree

3 files changed

+51
-24
lines changed

3 files changed

+51
-24
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
9292
PatternBenefit benefit = 1);
9393

9494
protected:
95+
/// See `ConversionPattern::ConversionPattern` for information on the other
96+
/// available constructors.
97+
using ConversionPattern::ConversionPattern;
98+
9599
/// Returns the LLVM dialect.
96100
LLVM::LLVMDialect &getDialect() const;
97101

@@ -234,6 +238,47 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
234238
using ConvertToLLVMPattern::matchAndRewrite;
235239
};
236240

241+
/// Utility class for operation conversions targeting the LLVM dialect that
242+
/// allows for matching and rewriting against an instance of an OpInterface
243+
/// class.
244+
template <typename SourceOp>
245+
class ConvertOpInterfaceToLLVMPattern : public ConvertToLLVMPattern {
246+
public:
247+
explicit ConvertOpInterfaceToLLVMPattern(
248+
const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
249+
: ConvertToLLVMPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
250+
SourceOp::getInterfaceID(), benefit,
251+
&typeConverter.getContext()) {}
252+
253+
/// Wrappers around the RewritePattern methods that pass the derived op type.
254+
LogicalResult
255+
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
256+
ConversionPatternRewriter &rewriter) const final {
257+
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
258+
}
259+
LogicalResult
260+
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
261+
ConversionPatternRewriter &rewriter) const final {
262+
return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
263+
}
264+
265+
/// Methods that operate on the SourceOp type. One of these must be
266+
/// overridden by the derived pattern class.
267+
virtual LogicalResult
268+
matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
269+
ConversionPatternRewriter &rewriter) const {
270+
llvm_unreachable("matchAndRewrite is not implemented");
271+
}
272+
virtual LogicalResult
273+
matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
274+
ConversionPatternRewriter &rewriter) const {
275+
return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
276+
}
277+
278+
private:
279+
using ConvertToLLVMPattern::matchAndRewrite;
280+
};
281+
237282
/// Generic implementation of one-to-one conversion from "SourceOp" to
238283
/// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
239284
/// Upholds a convention that multi-result operations get converted into an

mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,27 +24,18 @@ namespace {
2424
/// Generic one-to-one conversion of simply mappable operations into calls
2525
/// to their respective LLVM intrinsics.
2626
struct AMXIntrinsicOpConversion
27-
: public OpInterfaceConversionPattern<amx::AMXIntrinsicOp> {
28-
using OpInterfaceConversionPattern<
29-
amx::AMXIntrinsicOp>::OpInterfaceConversionPattern;
30-
31-
AMXIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
32-
PatternBenefit benefit = 1)
33-
: OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
34-
benefit),
35-
typeConverter(typeConverter) {}
27+
: public ConvertOpInterfaceToLLVMPattern<amx::AMXIntrinsicOp> {
28+
using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
3629

3730
LogicalResult
3831
matchAndRewrite(amx::AMXIntrinsicOp op, ArrayRef<Value> operands,
3932
ConversionPatternRewriter &rewriter) const override {
33+
const LLVMTypeConverter &typeConverter = *getTypeConverter();
4034
return LLVM::detail::intrinsicRewrite(
4135
op, rewriter.getStringAttr(op.getIntrinsicName()),
4236
op.getIntrinsicOperands(operands, typeConverter, rewriter),
4337
typeConverter, rewriter);
4438
}
45-
46-
private:
47-
const LLVMTypeConverter &typeConverter;
4839
};
4940

5041
} // namespace

mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,27 +23,18 @@ namespace {
2323
/// Generic one-to-one conversion of simply mappable operations into calls
2424
/// to their respective LLVM intrinsics.
2525
struct X86IntrinsicOpConversion
26-
: public OpInterfaceConversionPattern<x86vector::X86IntrinsicOp> {
27-
using OpInterfaceConversionPattern<
28-
x86vector::X86IntrinsicOp>::OpInterfaceConversionPattern;
29-
30-
X86IntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
31-
PatternBenefit benefit = 1)
32-
: OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
33-
benefit),
34-
typeConverter(typeConverter) {}
26+
: public ConvertOpInterfaceToLLVMPattern<x86vector::X86IntrinsicOp> {
27+
using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
3528

3629
LogicalResult
3730
matchAndRewrite(x86vector::X86IntrinsicOp op, ArrayRef<Value> operands,
3831
ConversionPatternRewriter &rewriter) const override {
32+
const LLVMTypeConverter &typeConverter = *getTypeConverter();
3933
return LLVM::detail::intrinsicRewrite(
4034
op, rewriter.getStringAttr(op.getIntrinsicName()),
4135
op.getIntrinsicOperands(operands, typeConverter, rewriter),
4236
typeConverter, rewriter);
4337
}
44-
45-
private:
46-
const LLVMTypeConverter &typeConverter;
4738
};
4839

4940
} // namespace

0 commit comments

Comments
 (0)