-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir] Adopt ConvertToLLVMPatternInterface
GpuToLLVMConversionPass to align with convert-to-llvm
#73761
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir Author: Mehdi Amini (joker-eph) ChangesThis is a follow-up to the introduction of Full diff: https://github.com/llvm/llvm-project/pull/73761.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
index 2eddf52d7abc520..73deef49c4175d3 100644
--- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
@@ -22,6 +22,10 @@ namespace mlir {
/// implementing `ConvertToLLVMPatternInterface`.
std::unique_ptr<Pass> createConvertToLLVMPass();
+/// Register the extension that will load dependent dialects for LLVM
+/// conversion. This is useful to implement a pass similar to "convert-to-llvm".
+void registerConvertToLLVMDependentDialectLoading(DialectRegistry ®istry);
+
} // namespace mlir
#endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index a90e557b1fdbd9c..6135117348a5b86 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -124,6 +124,11 @@ class ConvertToLLVMPass
} // namespace
+void mlir::registerConvertToLLVMDependentDialectLoading(
+ DialectRegistry ®istry) {
+ registry.addExtensions<LoadDependentDialectExtension>();
+}
+
std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
return std::make_unique<ConvertToLLVMPass>();
}
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 2da97c20e9c984e..fec15dceeba9693 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -18,6 +18,7 @@
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
@@ -33,11 +34,14 @@
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"
+#define DEBUG_TYPE "gpu-to-llvm"
+
namespace mlir {
#define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
@@ -48,12 +52,14 @@ using namespace mlir;
static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";
namespace {
-
class GpuToLLVMConversionPass
: public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
public:
using Base::Base;
-
+ void getDependentDialects(DialectRegistry ®istry) const final {
+ Base::getDependentDialects(registry);
+ registerConvertToLLVMDependentDialectLoading(registry);
+ }
// Run the dialect converter on the module.
void runOnOperation() override;
};
@@ -580,14 +586,24 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
} // namespace
void GpuToLLVMConversionPass::runOnOperation() {
- LowerToLLVMOptions options(&getContext());
+ MLIRContext *context = &getContext();
+ SymbolTable symbolTable = SymbolTable(getOperation());
+ LowerToLLVMOptions options(context);
options.useBarePtrCallConv = hostBarePtrCallConv;
+ RewritePatternSet patterns(context);
+ ConversionTarget target(*context);
+ target.addLegalDialect<LLVM::LLVMDialect>();
+ LLVMTypeConverter converter(context, options);
+
+ // Populate all patterns from all dialects that implement the
+ // `ConvertToLLVMPatternInterface` interface.
+ for (Dialect *dialect : context->getLoadedDialects()) {
+ auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+ if (!iface)
+ continue;
+ iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
+ }
- LLVMTypeConverter converter(&getContext(), options);
- RewritePatternSet patterns(&getContext());
- LLVMConversionTarget target(getContext());
-
- SymbolTable symbolTable = SymbolTable(getOperation());
// Preserve GPU modules if they have target attributes.
target.addDynamicallyLegalOp<gpu::GPUModuleOp>(
[](gpu::GPUModuleOp module) -> bool {
@@ -605,11 +621,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
!module.getTargetsAttr().empty());
});
- mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns);
- mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
+ // These aren't covered by the ConvertToLLVMPatternInterface right now.
populateVectorToLLVMConversionPatterns(converter, patterns);
populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
- populateFuncToLLVMConversionPatterns(converter, patterns);
populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
target);
populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
…to align with `convert-to-llvm` This is a follow-up to the introduction of `convert-to-llvm`: it is supposed to be a unifying pass through the `ConvertToLLVMPatternInterface`, but some specific conversion (like the GPU target) aren't vanilla LLVM target. Instead they need extra customizations that are specific to LLVM-on-GPUs and our custom runtime wrappers. This change make the GpuToLLVMConversionPass just as pluggable as the `convert-to-llvm` by using the same mechanism.
9ce6d7f
to
be17e0b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good
@@ -605,11 +621,9 @@ void GpuToLLVMConversionPass::runOnOperation() { | |||
!module.getTargetsAttr().empty()); | |||
}); | |||
|
|||
mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns); | |||
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns); | |||
// These aren't covered by the ConvertToLLVMPatternInterface right now. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out of curiosity - why these passes are deleted and not covered?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I didn't follow what you mean by deleted?
The populateXXXXToLLVMConversionPatterns
that I didn't remove are the one for which we haven't implemented the ConvertToLLVMPatternInterface
.
The one I removed (populateControlFlowToLLVMConversionPatterns
and populateFuncToLLVMConversionPatterns
) will be handled line 604 above (and only when these dialects are loaded: we only pay what we may need :) )
This is a follow-up to the introduction of
convert-to-llvm
: it is supposed to be a unifying pass through theConvertToLLVMPatternInterface
, but some specific conversion (like the GPU target) aren't vanilla LLVM target. Instead they need extra customizations that are specific to LLVM-on-GPUs and our custom runtime wrappers.This change make the GpuToLLVMConversionPass just as pluggable as the
convert-to-llvm
by using the same mechanism.