Skip to content

[mlir] Adopt ConvertToLLVMPatternInterface GpuToLLVMConversionPass to align with convert-to-llvm #73761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 29, 2023

Conversation

joker-eph
Copy link
Collaborator

This is a follow-up to the introduction of convert-to-llvm: it is supposed to be a unifying pass through the ConvertToLLVMPatternInterface, but some specific conversion (like the GPU target) aren't vanilla LLVM target. Instead they need extra customizations that are specific to LLVM-on-GPUs and our custom runtime wrappers.
This change make the GpuToLLVMConversionPass just as pluggable as the convert-to-llvm by using the same mechanism.

@llvmbot
Copy link
Member

llvmbot commented Nov 29, 2023

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Mehdi Amini (joker-eph)

Changes

This is a follow-up to the introduction of convert-to-llvm: it is supposed to be a unifying pass through the ConvertToLLVMPatternInterface, but some specific conversion (like the GPU target) aren't vanilla LLVM target. Instead they need extra customizations that are specific to LLVM-on-GPUs and our custom runtime wrappers.
This change make the GpuToLLVMConversionPass just as pluggable as the convert-to-llvm by using the same mechanism.


Full diff: https://github.com/llvm/llvm-project/pull/73761.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h (+4)
  • (modified) mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp (+5)
  • (modified) mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp (+25-11)
diff --git a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
index 2eddf52d7abc520..73deef49c4175d3 100644
--- a/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/ConvertToLLVM/ToLLVMPass.h
@@ -22,6 +22,10 @@ namespace mlir {
 /// implementing `ConvertToLLVMPatternInterface`.
 std::unique_ptr<Pass> createConvertToLLVMPass();
 
+/// Register the extension that will load dependent dialects for LLVM
+/// conversion. This is useful to implement a pass similar to "convert-to-llvm".
+void registerConvertToLLVMDependentDialectLoading(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_CONVERTTOLLVM_TOLLVM_PASS_H
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index a90e557b1fdbd9c..6135117348a5b86 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -124,6 +124,11 @@ class ConvertToLLVMPass
 
 } // namespace
 
+void mlir::registerConvertToLLVMDependentDialectLoading(
+    DialectRegistry &registry) {
+  registry.addExtensions<LoadDependentDialectExtension>();
+}
+
 std::unique_ptr<Pass> mlir::createConvertToLLVMPass() {
   return std::make_unique<ConvertToLLVMPass>();
 }
diff --git a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
index 2da97c20e9c984e..fec15dceeba9693 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/AsyncToLLVM/AsyncToLLVM.h"
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
 #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
@@ -33,11 +34,14 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/Conversion/ConvertToLLVM/ToLLVMPass.h"
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Error.h"
 #include "llvm/Support/FormatVariadic.h"
 
+#define DEBUG_TYPE "gpu-to-llvm"
+
 namespace mlir {
 #define GEN_PASS_DEF_GPUTOLLVMCONVERSIONPASS
 #include "mlir/Conversion/Passes.h.inc"
@@ -48,12 +52,14 @@ using namespace mlir;
 static constexpr const char *kGpuBinaryStorageSuffix = "_gpubin_cst";
 
 namespace {
-
 class GpuToLLVMConversionPass
     : public impl::GpuToLLVMConversionPassBase<GpuToLLVMConversionPass> {
 public:
   using Base::Base;
-
+  void getDependentDialects(DialectRegistry &registry) const final {
+    Base::getDependentDialects(registry);
+    registerConvertToLLVMDependentDialectLoading(registry);
+  }
   // Run the dialect converter on the module.
   void runOnOperation() override;
 };
@@ -580,14 +586,24 @@ DECLARE_CONVERT_OP_TO_GPU_RUNTIME_CALL_PATTERN(SetCsrPointersOp)
 } // namespace
 
 void GpuToLLVMConversionPass::runOnOperation() {
-  LowerToLLVMOptions options(&getContext());
+  MLIRContext *context = &getContext();
+  SymbolTable symbolTable = SymbolTable(getOperation());
+  LowerToLLVMOptions options(context);
   options.useBarePtrCallConv = hostBarePtrCallConv;
+  RewritePatternSet patterns(context);
+  ConversionTarget target(*context);
+  target.addLegalDialect<LLVM::LLVMDialect>();
+  LLVMTypeConverter converter(context, options);
+
+  // Populate all patterns from all dialects that implement the
+  // `ConvertToLLVMPatternInterface` interface.
+  for (Dialect *dialect : context->getLoadedDialects()) {
+    auto iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
+    if (!iface)
+      continue;
+    iface->populateConvertToLLVMConversionPatterns(target, converter, patterns);
+  }
 
-  LLVMTypeConverter converter(&getContext(), options);
-  RewritePatternSet patterns(&getContext());
-  LLVMConversionTarget target(getContext());
-
-  SymbolTable symbolTable = SymbolTable(getOperation());
   // Preserve GPU modules if they have target attributes.
   target.addDynamicallyLegalOp<gpu::GPUModuleOp>(
       [](gpu::GPUModuleOp module) -> bool {
@@ -605,11 +621,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
                 !module.getTargetsAttr().empty());
       });
 
-  mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns);
-  mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
+  // These aren't covered by the ConvertToLLVMPatternInterface right now.
   populateVectorToLLVMConversionPatterns(converter, patterns);
   populateFinalizeMemRefToLLVMConversionPatterns(converter, patterns);
-  populateFuncToLLVMConversionPatterns(converter, patterns);
   populateAsyncStructuralTypeConversionsAndLegality(converter, patterns,
                                                     target);
   populateGpuToLLVMConversionPatterns(converter, patterns, gpuBinaryAnnotation,

Copy link

github-actions bot commented Nov 29, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

…to align with `convert-to-llvm`

This is a follow-up to the introduction of `convert-to-llvm`: it is supposed
to be a unifying pass through the `ConvertToLLVMPatternInterface`, but some
specific conversion (like the GPU target) aren't vanilla LLVM target. Instead
they need extra customizations that are specific to LLVM-on-GPUs and our
custom runtime wrappers.
This change make the GpuToLLVMConversionPass just as pluggable as the
`convert-to-llvm` by using the same mechanism.
@joker-eph joker-eph force-pushed the GpuToLLVMConversionPass branch from 9ce6d7f to be17e0b Compare November 29, 2023 10:19
Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good

@@ -605,11 +621,9 @@ void GpuToLLVMConversionPass::runOnOperation() {
!module.getTargetsAttr().empty());
});

mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
// These aren't covered by the ConvertToLLVMPatternInterface right now.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

out of curiosity - why these passes are deleted and not covered?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I didn't follow what you mean by deleted?

The populateXXXXToLLVMConversionPatterns that I didn't remove are the one for which we haven't implemented the ConvertToLLVMPatternInterface.
The one I removed (populateControlFlowToLLVMConversionPatterns and populateFuncToLLVMConversionPatterns) will be handled line 604 above (and only when these dialects are loaded: we only pay what we may need :) )

@joker-eph joker-eph merged commit 9e7b6f4 into llvm:main Nov 29, 2023
@joker-eph joker-eph deleted the GpuToLLVMConversionPass branch November 29, 2023 19:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants