[MLIR] Add callback functions for ModuleToObject #116007

Merged: 1 commit into llvm:main on Nov 19, 2024

MikaOvO (Contributor) commented Nov 13, 2024

In the ModuleToObject flow, users may want to add callback functions that are invoked with the LLVM IR or the ISA, for debugging or other purposes.

llvmbot (Member) commented Nov 13, 2024

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir-llvm

Author: Zichen Lu (MikaOvO)

Changes

In the ModuleToObject flow, users may want to add callback functions that are invoked with the LLVM IR or the ISA, for debugging or other purposes.


Full diff: https://github.com/llvm/llvm-project/pull/116007.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h (+42-2)
  • (modified) mlir/include/mlir/Target/LLVM/ModuleToObject.h (+22-2)
  • (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+37-4)
  • (modified) mlir/lib/Target/LLVM/ModuleToObject.cpp (+22-3)
  • (modified) mlir/lib/Target/LLVM/NVVM/Target.cpp (+3)
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 6d7cb5ca7a7f81..d4b16a1de8eddc 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -14,6 +14,7 @@
 #define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H
 
 #include "mlir/IR/Attributes.h"
+#include "llvm/IR/Module.h"
 
 namespace llvm {
 class IRBuilderBase;
@@ -52,7 +53,11 @@ class TargetOptions {
       StringRef toolkitPath = {}, ArrayRef<std::string> linkFiles = {},
       StringRef cmdOptions = {},
       CompilationTarget compilationTarget = getDefaultCompilationTarget(),
-      function_ref<SymbolTable *()> getSymbolTableCallback = {});
+      function_ref<SymbolTable *()> getSymbolTableCallback = {},
+      function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+      function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+      function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+      function_ref<void(StringRef)> isaCallback = {});
 
   /// Returns the typeID.
   TypeID getTypeID() const;
@@ -80,6 +85,22 @@ class TargetOptions {
   /// table.
   SymbolTable *getSymbolTable() const;
 
+  /// Returns the callback invoked with the initial LLVM IR for the device
+  /// module.
+  function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
+
+  /// Returns the callback invoked with LLVM IR for the device module
+  /// after linking the device libraries.
+  function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
+
+  /// Returns the callback invoked with LLVM IR for the device module after
+  /// LLVM optimizations but before codegen.
+  function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
+
+  /// Returns the callback invoked with the target ISA for the device,
+  /// for example PTX assembly.
+  function_ref<void(StringRef)> getISACallback() const;
+
   /// Returns the default compilation target: `CompilationTarget::Fatbin`.
   static CompilationTarget getDefaultCompilationTarget();
 
@@ -90,7 +111,11 @@ class TargetOptions {
       TypeID typeID, StringRef toolkitPath = {},
       ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
       CompilationTarget compilationTarget = getDefaultCompilationTarget(),
-      function_ref<SymbolTable *()> getSymbolTableCallback = {});
+      function_ref<SymbolTable *()> getSymbolTableCallback = {},
+      function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+      function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+      function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+      function_ref<void(StringRef)> isaCallback = {});
 
   /// Path to the target toolkit.
   std::string toolkitPath;
@@ -109,6 +134,21 @@ class TargetOptions {
   /// being serialized.
   function_ref<SymbolTable *()> getSymbolTableCallback;
 
+  /// Callback invoked with the initial LLVM IR for the device module.
+  function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+
+  /// Callback invoked with LLVM IR for the device module after
+  /// linking the device libraries.
+  function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+
+  /// Callback invoked with LLVM IR for the device module after
+  /// LLVM optimizations but before codegen.
+  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+
+  /// Callback invoked with the target ISA for the device,
+  /// for example PTX assembly.
+  function_ref<void(StringRef)> isaCallback;
+
 private:
   TypeID typeID;
 };
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index e40d7e9a43dd6b..07fc55b41ae9c5 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -29,8 +29,13 @@ class ModuleTranslation;
 /// operations being transformed must be translatable into LLVM IR.
 class ModuleToObject {
 public:
-  ModuleToObject(Operation &module, StringRef triple, StringRef chip,
-                 StringRef features = {}, int optLevel = 3);
+  ModuleToObject(
+      Operation &module, StringRef triple, StringRef chip,
+      StringRef features = {}, int optLevel = 3,
+      function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+      function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+      function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+      function_ref<void(StringRef)> isaCallback = {});
   virtual ~ModuleToObject();
 
   /// Returns the operation being serialized.
@@ -114,6 +119,21 @@ class ModuleToObject {
   /// Optimization level.
   int optLevel;
 
+  /// Callback invoked with the initial LLVM IR for the device module.
+  function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+
+  /// Callback invoked with LLVM IR for the device module after
+  /// linking the device libraries.
+  function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+
+  /// Callback invoked with LLVM IR for the device module after
+  /// LLVM optimizations but before codegen.
+  function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+
+  /// Callback invoked with the target ISA for the device,
+  /// for example PTX assembly.
+  function_ref<void(StringRef)> isaCallback;
+
 private:
   /// The TargetMachine created for the given Triple, if available.
   /// Accessible through `getOrCreateTargetMachine()`.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 956877497d9338..d62ea72dcea2f6 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2302,17 +2302,31 @@ KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
 TargetOptions::TargetOptions(
     StringRef toolkitPath, ArrayRef<std::string> linkFiles,
     StringRef cmdOptions, CompilationTarget compilationTarget,
-    function_ref<SymbolTable *()> getSymbolTableCallback)
+    function_ref<SymbolTable *()> getSymbolTableCallback,
+    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+    function_ref<void(StringRef)> isaCallback)
     : TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
-                    cmdOptions, compilationTarget, getSymbolTableCallback) {}
+                    cmdOptions, compilationTarget, getSymbolTableCallback,
+                    initialLlvmIRCallback, linkedLlvmIRCallback,
+                    optimizedLlvmIRCallback, isaCallback) {}
 
 TargetOptions::TargetOptions(
     TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
     StringRef cmdOptions, CompilationTarget compilationTarget,
-    function_ref<SymbolTable *()> getSymbolTableCallback)
+    function_ref<SymbolTable *()> getSymbolTableCallback,
+    function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+    function_ref<void(StringRef)> isaCallback)
     : toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
       cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
-      getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
+      getSymbolTableCallback(getSymbolTableCallback),
+      initialLlvmIRCallback(initialLlvmIRCallback),
+      linkedLlvmIRCallback(linkedLlvmIRCallback),
+      optimizedLlvmIRCallback(optimizedLlvmIRCallback),
+      isaCallback(isaCallback), typeID(typeID) {}
 
 TypeID TargetOptions::getTypeID() const { return typeID; }
 
@@ -2326,6 +2340,25 @@ SymbolTable *TargetOptions::getSymbolTable() const {
   return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
 }
 
+function_ref<void(llvm::Module &)>
+TargetOptions::getInitialLlvmIRCallback() const {
+  return initialLlvmIRCallback;
+}
+
+function_ref<void(llvm::Module &)>
+TargetOptions::getLinkedLlvmIRCallback() const {
+  return linkedLlvmIRCallback;
+}
+
+function_ref<void(llvm::Module &)>
+TargetOptions::getOptimizedLlvmIRCallback() const {
+  return optimizedLlvmIRCallback;
+}
+
+function_ref<void(StringRef)> TargetOptions::getISACallback() const {
+  return isaCallback;
+}
+
 CompilationTarget TargetOptions::getCompilationTarget() const {
   return compilationTarget;
 }
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 77391341adaad2..8a4fa3970951c9 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -34,10 +34,17 @@
 using namespace mlir;
 using namespace mlir::LLVM;
 
-ModuleToObject::ModuleToObject(Operation &module, StringRef triple,
-                               StringRef chip, StringRef features, int optLevel)
+ModuleToObject::ModuleToObject(
+    Operation &module, StringRef triple, StringRef chip, StringRef features,
+    int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+    function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+    function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+    function_ref<void(StringRef)> isaCallback)
     : module(module), triple(triple), chip(chip), features(features),
-      optLevel(optLevel) {}
+      optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
+      linkedLlvmIRCallback(linkedLlvmIRCallback),
+      optimizedLlvmIRCallback(optimizedLlvmIRCallback),
+      isaCallback(isaCallback) {}
 
 ModuleToObject::~ModuleToObject() = default;
 
@@ -215,6 +222,10 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
   }
   setDataLayoutAndTriple(*llvmModule);
 
+  if (initialLlvmIRCallback) {
+    initialLlvmIRCallback(*llvmModule);
+  }
+
   // Link bitcode files.
   handleModulePreLink(*llvmModule);
   {
@@ -227,10 +238,18 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
     handleModulePostLink(*llvmModule);
   }
 
+  if (linkedLlvmIRCallback) {
+    linkedLlvmIRCallback(*llvmModule);
+  }
+
   // Optimize the module.
   if (failed(optimizeModule(*llvmModule, optLevel)))
     return std::nullopt;
 
+  if (optimizedLlvmIRCallback) {
+    optimizedLlvmIRCallback(*llvmModule);
+  }
+
   // Return the serialized object.
   return moduleToObject(*llvmModule);
 }
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 69602af8563aa0..d47fb856cb7222 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -572,6 +572,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
     getOperation().emitError() << "Failed translating the module to ISA.";
     return std::nullopt;
   }
+  if (isaCallback) {
+    isaCallback(serializedISA.value());
+  }
 #define DEBUG_TYPE "serialize-to-isa"
   LLVM_DEBUG({
     llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n";

llvmbot (Member) commented Nov 13, 2024

@llvm/pr-subscribers-mlir


Dinistro (Contributor) left a comment

This looks reasonable but I'm wondering if this could in any way be tested?

Comment on lines 225 to 227
  if (initialLlvmIRCallback) {
    initialLlvmIRCallback(*llvmModule);
  }

Contributor:

Suggested change
-  if (initialLlvmIRCallback) {
-    initialLlvmIRCallback(*llvmModule);
-  }
+  if (initialLlvmIRCallback)
+    initialLlvmIRCallback(*llvmModule);

The same applies to other cases.

MikaOvO (Contributor Author):

Fixed, thanks!

@joker-eph joker-eph requested a review from fabianmcg November 13, 2024 11:00
fabianmcg (Contributor) left a comment

The primary in-tree indirect user of ModuleToObject is gpu-module-to-binary, and this change doesn't touch that pass, so this contribution doesn't appear to have any uses upstream. I have the following questions:

  • How is this going to be used?
  • Why is this needed?

Also, this requires unit tests.

joker-eph (Collaborator) commented

Yes, we can probably use gpu-module-to-binary to test this feature.

Note that downstream, we have uses of ModuleToObject and other components that are totally independent of gpu-module-to-binary or the pass infra.

fabianmcg (Contributor) commented Nov 13, 2024

I don't think gpu-module-to-binary is worth modifying for testing this. I brought it up because I was mainly concerned about why this was good for upstream, and gpu-module-to-binary is its primary indirect user. I was unaware of any uses downstream (outside GPU compilation).

> Note that downstream, we have uses of ModuleToObject and other components totally independently of gpu-module-to-binary or the pass infra.

Are these uses through the target attributes (NVVM, ROCDL)? If not, downstream can always subclass ModuleToObject. That's why most of its methods are virtual, in case anyone needs to customize it. See for example https://github.com/ROCm/rocMLIR/blob/develop/mlir/lib/Target/Target.cpp#L72-L91

Also, FWIW the main reason for my block is the lack of unit tests.

@@ -572,6 +572,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
     getOperation().emitError() << "Failed translating the module to ISA.";
     return std::nullopt;
   }
+  if (isaCallback) {

Member:

We also have dump-sass. Should we add a callback for that as well?

MikaOvO (Contributor Author):

Could we dump the SASS result ourselves after getting it?

I think the reason for adding callback functions here is that we cannot obtain the intermediate results (LLVM IR, etc.) by subclassing SerializeGPUModuleBase. We could override ModuleToObject::run(), but that would duplicate code, which is not a good solution.

Member:

Sure we can 😀 But this means code replication.

This path calls nvdisasm on the cubin and prints the SASS. I think it's better to leverage this mechanism, and all the downstream compilers use it.

Collaborator:

I don't quite get it? The callback would return the same thing as the serializer?

MikaOvO (Contributor Author):

> This means code replication. This path calls nvdisasm on cubin and print sass. I think it's better to leverage this mechanism, and all the downstream compilers use it.

I think you are right. But as I currently understand it, SASS is a concept that exists only for NVVM, so we cannot name a callback sassCallBack. Similarly, we cannot name a callback ptxCallback, but we can name it isaCallback.

grypp (Member) commented Nov 14, 2024

We can integrate this work with the GPUToNVVMPipeline to leverage existing tests.

fwiw, I implemented something similar for this case. I wasn't aware of this PR, and I closed my PR (linked below), but it demonstrates how to handle the wiring between tests and GPUToNVVMPipeline:
https://github.com/llvm/llvm-project/pull/116199/files

fabianmcg (Contributor) commented

@MikaOvO could you add some unit tests here https://github.com/llvm/llvm-project/tree/main/mlir/unittests/Target/LLVM ? It could be in SerializeNVVMTarget.cpp

MikaOvO (Contributor Author) commented Nov 14, 2024

> @MikaOvO could you add some unit tests here https://github.com/llvm/llvm-project/tree/main/mlir/unittests/Target/LLVM ? It could be in SerializeNVVMTarget.cpp

Yes, I'm trying to do this today, thanks!

MikaOvO (Contributor Author) commented Nov 14, 2024

Hi all. Thanks for your comments! Please let me answer your questions based on my understanding.

> How is this going to be used?

We can use it like this:

    auto callback = [](llvm::Module &module) {
      // do something like module.dump();
    };
    auto callback2 = [](llvm::StringRef isa) {
      // do something like llvm::outs() << isa;
    };

    gpu::TargetOptions options(..., callback, {}, {}, callback2);

> Why is this needed?

ModuleToObject::run() translates the module to LLVM IR, links, optimizes, and serializes in one continuous process. If we want to dump the LLVM IR for debugging without callback functions, we may need to subclass ModuleToObject and override run(), which duplicates code.
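The hook placement described above can be sketched without any LLVM dependency. This is only an illustration under stated assumptions: `FakeModule`, `StageCallbacks`, `runPipeline`, and `recordStages` are hypothetical stand-ins, and `std::function` is used in place of `llvm::function_ref`. The one thing taken from the diff is the order of the hooks: initial, linked, optimized, then ISA.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical stand-in for llvm::Module: just a string describing its state.
struct FakeModule {
  std::string ir;
};

// One optional observer per stage; an empty std::function means "not set".
struct StageCallbacks {
  std::function<void(FakeModule &)> initial;
  std::function<void(FakeModule &)> linked;
  std::function<void(FakeModule &)> optimized;
  std::function<void(const std::string &)> isa;
};

// Fixed pipeline with a guarded hook after each stage, in the same order as
// the hooks added to ModuleToObject::run() in this PR.
std::string runPipeline(const StageCallbacks &cb) {
  FakeModule m{"translated"};
  if (cb.initial)
    cb.initial(m);
  m.ir += "+linked";
  if (cb.linked)
    cb.linked(m);
  m.ir += "+optimized";
  if (cb.optimized)
    cb.optimized(m);
  std::string isa = "isa(" + m.ir + ")";
  if (cb.isa)
    cb.isa(isa);
  return isa;
}

// Registers three of the four hooks and records what each one observed.
std::vector<std::string> recordStages() {
  std::vector<std::string> seen;
  StageCallbacks cb;
  cb.initial = [&](FakeModule &m) { seen.push_back(m.ir); };
  cb.optimized = [&](FakeModule &m) { seen.push_back(m.ir); };
  cb.isa = [&](const std::string &s) { seen.push_back(s); };
  std::string result = runPipeline(cb);
  assert(!result.empty());
  return seen;
}
```

The caller observes each intermediate state without copying the body of the pipeline, which is the duplication the comment above wants to avoid. Hooks left unset are simply skipped.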

@MikaOvO MikaOvO force-pushed the Implement_serialize_object_callbacks branch from a54b04b to a6129d8 Compare November 14, 2024 14:44
MikaOvO (Contributor Author) commented Nov 14, 2024

I have already added a unit test in SerializeNVVMTarget.cpp.

fabianmcg (Contributor) left a comment

Thank you for the tests!

@MikaOvO MikaOvO force-pushed the Implement_serialize_object_callbacks branch from a6129d8 to bec45ed Compare November 18, 2024 09:16
Comment on lines 197 to 201
  std::ostringstream oss;
  llvm::raw_os_ostream ros(oss);
  ros.flush();
  module.print(ros, nullptr);
  optimizedLLVMIR = oss.str();

Collaborator:

The ros.flush() is inserted before the module.print(), that does not seem right...

This should work instead:

Suggested change
-  std::ostringstream oss;
-  llvm::raw_os_ostream ros(oss);
-  ros.flush();
-  module.print(ros, nullptr);
-  optimizedLLVMIR = oss.str();
+  llvm::raw_string_ostream rso(optimizedLLVMIR);
+  module.print(rso, nullptr);
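Why the ordering matters can be shown with a toy buffered writer. This is a sketch under stated assumptions: `BufferedWriter` is a hypothetical stand-in that only models "writes are held in a buffer until flush()"; it is not LLVM's `raw_os_ostream`, and it has no flush-on-destruction.

```cpp
#include <cassert>
#include <string>

// Hypothetical buffered adaptor: text written to it stays in a private
// buffer and reaches the destination string only when flush() is called.
class BufferedWriter {
public:
  explicit BufferedWriter(std::string &dest) : dest(dest) {}
  void write(const std::string &text) { pending += text; }
  void flush() {
    dest += pending;
    pending.clear();
  }

private:
  std::string &dest;
  std::string pending;
};

// Same order as the flagged test code: flush first, then write, then read.
std::string flushThenWrite() {
  std::string dest;
  BufferedWriter writer(dest);
  writer.flush();
  writer.write("define void @kernel()");
  return dest; // the text is still sitting in the writer's buffer
}

// Corrected order: write first, flush, then read the destination.
std::string writeThenFlush() {
  std::string dest;
  BufferedWriter writer(dest);
  writer.write("define void @kernel()");
  writer.flush();
  return dest;
}
```

With this model, flushing before the write leaves the destination empty at the point it is read, while flushing after the write makes the text visible.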

Collaborator:

Why didn't you just apply the diff I wrote? It's 2 lines instead of 5...

MikaOvO (Contributor Author):

Sorry, let me fix it.

MikaOvO (Contributor Author):

Fixed, thanks! (Sorry, I didn't read the code carefully.)

@MikaOvO MikaOvO force-pushed the Implement_serialize_object_callbacks branch from bec45ed to f67139d Compare November 19, 2024 01:28
@MikaOvO MikaOvO force-pushed the Implement_serialize_object_callbacks branch from f67139d to 62707ce Compare November 19, 2024 01:35
@joker-eph joker-eph merged commit 2153672 into llvm:main Nov 19, 2024
8 checks passed
llvm-ci (Collaborator) commented Nov 19, 2024

LLVM Buildbot has detected a new failure on builder mlir-nvidia-gcc7 running on mlir-nvidia while building mlir at step 6 "test-build-check-mlir-build-only-check-mlir".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/6557

Here is the relevant piece of the build log for the reference
Step 6 (test-build-check-mlir-build-only-check-mlir) failure: test (failure)
******************** TEST 'MLIR-Unit :: Target/LLVM/./MLIRTargetLLVMTests/3/6' FAILED ********************
Script(shard):
--
GTEST_OUTPUT=json:/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/tools/mlir/unittests/Target/LLVM/./MLIRTargetLLVMTests-MLIR-Unit-3746479-3-6.json GTEST_SHUFFLE=0 GTEST_TOTAL_SHARDS=6 GTEST_SHARD_INDEX=3 /vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/tools/mlir/unittests/Target/LLVM/./MLIRTargetLLVMTests
--

Script:
--
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.obj/tools/mlir/unittests/Target/LLVM/./MLIRTargetLLVMTests --gtest_filter=MLIRTargetLLVMNVVM.CallbackInvokedWithLLVMIRAndISA
--
/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp:210: Failure
Value of: !initialLLVMIR.empty()
  Actual: false
Expected: true


/vol/worker/mlir-nvidia/mlir-nvidia-gcc7/llvm.src/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp:210
Value of: !initialLLVMIR.empty()
  Actual: false
Expected: true



********************


joker-eph (Collaborator) commented

I had to revert: the callback isn't invoked.
I think we should split the test into one test per callback. That way the tests for the early callbacks are not skipped when ptxas is missing, so they run more widely.

MikaOvO (Contributor Author) commented Nov 19, 2024

> I had to revert: the callback isn't invoked. I think we should split the test with one per callback: that would make it so that we don't skip the test for the early callbacks when ptxas is missing, making the test runs more widely.

OK, thanks!

joker-eph pushed a commit that referenced this pull request Nov 20, 2024

Here is the [merged PR](#116007) which caused a failure and [was
reverted](#116811).

Thanks to @joker-eph for the help, I fixed it (the callback functions
were not passed when constructing `ModuleToObject` in
`mlir/lib/Target/LLVM/NVVM/Target.cpp`) and split out the unit tests
that don't need `ptxas` from the original test so that they run more widely.
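The kind of bug described here, where callbacks are accepted but never handed to the base-class constructor, compiles silently whenever the base parameters have defaults. The sketch below is one reading of the commit message, not the actual code from Target.cpp: `BaseSerializer`, `ForgetfulSerializer`, `ForwardingSerializer`, and `countInvocations` are hypothetical names, and `std::function` stands in for `llvm::function_ref`.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>

using IRCallback = std::function<void(const std::string &)>;

// Stand-in base class whose callback parameter defaults to "empty".
class BaseSerializer {
public:
  explicit BaseSerializer(std::string name, IRCallback onIR = {})
      : name(std::move(name)), onIR(std::move(onIR)) {}

  void run() const {
    if (onIR)
      onIR("ir for " + name);
  }

private:
  std::string name;
  IRCallback onIR;
};

// Bug shape: takes a callback but does not forward it. This still compiles
// because the base constructor falls back to its default argument.
class ForgetfulSerializer : public BaseSerializer {
public:
  ForgetfulSerializer(std::string name, IRCallback)
      : BaseSerializer(std::move(name)) {}
};

// Fix shape: forward the callback to the base constructor.
class ForwardingSerializer : public BaseSerializer {
public:
  ForwardingSerializer(std::string name, IRCallback onIR)
      : BaseSerializer(std::move(name), std::move(onIR)) {}
};

// Counts how many times the callback fires for a given serializer type.
template <typename Serializer>
int countInvocations() {
  int calls = 0;
  Serializer serializer("kernel", [&](const std::string &) { ++calls; });
  serializer.run();
  return calls;
}
```

A test that asserts the callback actually ran, like the `!initialLLVMIR.empty()` check in the buildbot log above, is what exposes the missing forwarding; the compiler does not.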