-
Notifications
You must be signed in to change notification settings - Fork 14.4k
[mlir][llvm] Op interface LLVM converter #143922
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Adds a utility pattern which decreases amount of required boilerplate in LLVM export legalization patterns targeting op interface instances.
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir-amx Author: Adam Siemieniuk (adam-smnk) ChangesAdds a utility conversion class for rewriting op interface instances targeting LLVM dialect. Full diff: https://github.com/llvm/llvm-project/pull/143922.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 7e946495e3e7f..503a2a7e6f0cd 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -92,6 +92,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
PatternBenefit benefit = 1);
protected:
+ /// See `ConversionPattern::ConversionPattern` for information on the other
+ /// available constructors.
+ using ConversionPattern::ConversionPattern;
+
/// Returns the LLVM dialect.
LLVM::LLVMDialect &getDialect() const;
@@ -234,6 +238,47 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
using ConvertToLLVMPattern::matchAndRewrite;
};
+/// Utility class for operation conversions targeting the LLVM dialect that
+/// allows for matching and rewriting against an instance of an OpInterface
+/// class.
+template <typename SourceOp>
+class ConvertOpInterfaceToLLVMPattern : public ConvertToLLVMPattern {
+public:
+ explicit ConvertOpInterfaceToLLVMPattern(
+ const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
+ : ConvertToLLVMPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
+ SourceOp::getInterfaceID(), benefit,
+ &typeConverter.getContext()) {}
+
+ /// Wrappers around the RewritePattern methods that pass the derived op type.
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+ }
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+ }
+
+ /// Methods that operate on the SourceOp type. One of these must be
+ /// overridden by the derived pattern class.
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ llvm_unreachable("matchAndRewrite is not implemented");
+ }
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const {
+ return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
+
+private:
+ using ConvertToLLVMPattern::matchAndRewrite;
+};
+
/// Generic implementation of one-to-one conversion from "SourceOp" to
/// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
/// Upholds a convention that multi-result operations get converted into an
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 37aebc9fab3eb..06e5f7c2196d2 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -24,27 +24,18 @@ namespace {
/// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics.
struct AMXIntrinsicOpConversion
- : public OpInterfaceConversionPattern<amx::AMXIntrinsicOp> {
- using OpInterfaceConversionPattern<
- amx::AMXIntrinsicOp>::OpInterfaceConversionPattern;
-
- AMXIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
- PatternBenefit benefit = 1)
- : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
- benefit),
- typeConverter(typeConverter) {}
+ : public ConvertOpInterfaceToLLVMPattern<amx::AMXIntrinsicOp> {
+ using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
LogicalResult
matchAndRewrite(amx::AMXIntrinsicOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
+ const LLVMTypeConverter &typeConverter = *getTypeConverter();
return LLVM::detail::intrinsicRewrite(
op, rewriter.getStringAttr(op.getIntrinsicName()),
op.getIntrinsicOperands(operands, typeConverter, rewriter),
typeConverter, rewriter);
}
-
-private:
- const LLVMTypeConverter &typeConverter;
};
} // namespace
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index b2fc2f3f40e8c..8e062488f58c8 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -23,27 +23,18 @@ namespace {
/// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics.
struct X86IntrinsicOpConversion
- : public OpInterfaceConversionPattern<x86vector::X86IntrinsicOp> {
- using OpInterfaceConversionPattern<
- x86vector::X86IntrinsicOp>::OpInterfaceConversionPattern;
-
- X86IntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
- PatternBenefit benefit = 1)
- : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
- benefit),
- typeConverter(typeConverter) {}
+ : public ConvertOpInterfaceToLLVMPattern<x86vector::X86IntrinsicOp> {
+ using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
LogicalResult
matchAndRewrite(x86vector::X86IntrinsicOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
+ const LLVMTypeConverter &typeConverter = *getTypeConverter();
return LLVM::detail::intrinsicRewrite(
op, rewriter.getStringAttr(op.getIntrinsicName()),
op.getIntrinsicOperands(operands, typeConverter, rewriter),
typeConverter, rewriter);
}
-
-private:
- const LLVMTypeConverter &typeConverter;
};
} // namespace
|
@llvm/pr-subscribers-mlir Author: Adam Siemieniuk (adam-smnk) ChangesAdds a utility conversion class for rewriting op interface instances targeting LLVM dialect. Full diff: https://github.com/llvm/llvm-project/pull/143922.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
index 7e946495e3e7f..503a2a7e6f0cd 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
@@ -92,6 +92,10 @@ class ConvertToLLVMPattern : public ConversionPattern {
PatternBenefit benefit = 1);
protected:
+ /// See `ConversionPattern::ConversionPattern` for information on the other
+ /// available constructors.
+ using ConversionPattern::ConversionPattern;
+
/// Returns the LLVM dialect.
LLVM::LLVMDialect &getDialect() const;
@@ -234,6 +238,47 @@ class ConvertOpToLLVMPattern : public ConvertToLLVMPattern {
using ConvertToLLVMPattern::matchAndRewrite;
};
+/// Utility class for operation conversions targeting the LLVM dialect that
+/// allows for matching and rewriting against an instance of an OpInterface
+/// class.
+template <typename SourceOp>
+class ConvertOpInterfaceToLLVMPattern : public ConvertToLLVMPattern {
+public:
+ explicit ConvertOpInterfaceToLLVMPattern(
+ const LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
+ : ConvertToLLVMPattern(typeConverter, Pattern::MatchInterfaceOpTypeTag(),
+ SourceOp::getInterfaceID(), benefit,
+ &typeConverter.getContext()) {}
+
+ /// Wrappers around the RewritePattern methods that pass the derived op type.
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+ }
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ return matchAndRewrite(cast<SourceOp>(op), operands, rewriter);
+ }
+
+ /// Methods that operate on the SourceOp type. One of these must be
+ /// overridden by the derived pattern class.
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const {
+ llvm_unreachable("matchAndRewrite is not implemented");
+ }
+ virtual LogicalResult
+ matchAndRewrite(SourceOp op, ArrayRef<ValueRange> operands,
+ ConversionPatternRewriter &rewriter) const {
+ return matchAndRewrite(op, getOneToOneAdaptorOperands(operands), rewriter);
+ }
+
+private:
+ using ConvertToLLVMPattern::matchAndRewrite;
+};
+
/// Generic implementation of one-to-one conversion from "SourceOp" to
/// "TargetOp" where the latter belongs to the LLVM dialect or an equivalent.
/// Upholds a convention that multi-result operations get converted into an
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 37aebc9fab3eb..06e5f7c2196d2 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -24,27 +24,18 @@ namespace {
/// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics.
struct AMXIntrinsicOpConversion
- : public OpInterfaceConversionPattern<amx::AMXIntrinsicOp> {
- using OpInterfaceConversionPattern<
- amx::AMXIntrinsicOp>::OpInterfaceConversionPattern;
-
- AMXIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
- PatternBenefit benefit = 1)
- : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
- benefit),
- typeConverter(typeConverter) {}
+ : public ConvertOpInterfaceToLLVMPattern<amx::AMXIntrinsicOp> {
+ using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
LogicalResult
matchAndRewrite(amx::AMXIntrinsicOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
+ const LLVMTypeConverter &typeConverter = *getTypeConverter();
return LLVM::detail::intrinsicRewrite(
op, rewriter.getStringAttr(op.getIntrinsicName()),
op.getIntrinsicOperands(operands, typeConverter, rewriter),
typeConverter, rewriter);
}
-
-private:
- const LLVMTypeConverter &typeConverter;
};
} // namespace
diff --git a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
index b2fc2f3f40e8c..8e062488f58c8 100644
--- a/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp
@@ -23,27 +23,18 @@ namespace {
/// Generic one-to-one conversion of simply mappable operations into calls
/// to their respective LLVM intrinsics.
struct X86IntrinsicOpConversion
- : public OpInterfaceConversionPattern<x86vector::X86IntrinsicOp> {
- using OpInterfaceConversionPattern<
- x86vector::X86IntrinsicOp>::OpInterfaceConversionPattern;
-
- X86IntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
- PatternBenefit benefit = 1)
- : OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
- benefit),
- typeConverter(typeConverter) {}
+ : public ConvertOpInterfaceToLLVMPattern<x86vector::X86IntrinsicOp> {
+ using ConvertOpInterfaceToLLVMPattern::ConvertOpInterfaceToLLVMPattern;
LogicalResult
matchAndRewrite(x86vector::X86IntrinsicOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
+ const LLVMTypeConverter &typeConverter = *getTypeConverter();
return LLVM::detail::intrinsicRewrite(
op, rewriter.getStringAttr(op.getIntrinsicName()),
op.getIntrinsicOperands(operands, typeConverter, rewriter),
typeConverter, rewriter);
}
-
-private:
- const LLVMTypeConverter &typeConverter;
};
} // namespace
|
Adds a utility conversion class for rewriting op interface instances targeting LLVM dialect.
Adds a utility conversion class for rewriting op interface instances targeting LLVM dialect.
Adds a utility conversion class for rewriting op interface instances targeting LLVM dialect.