-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Add callback functions for ModuleToObject #116007
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Add callback functions for ModuleToObject #116007
Conversation
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir-llvm Author: Zichen Lu (MikaOvO) Changes: In the ModuleToObject flow, users may want to add some callback functions invoked with LLVM IR/ISA for debugging or other purposes. Full diff: https://github.com/llvm/llvm-project/pull/116007.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 6d7cb5ca7a7f81..d4b16a1de8eddc 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H
#include "mlir/IR/Attributes.h"
+#include "llvm/IR/Module.h"
namespace llvm {
class IRBuilderBase;
@@ -52,7 +53,11 @@ class TargetOptions {
StringRef toolkitPath = {}, ArrayRef<std::string> linkFiles = {},
StringRef cmdOptions = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
- function_ref<SymbolTable *()> getSymbolTableCallback = {});
+ function_ref<SymbolTable *()> getSymbolTableCallback = {},
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<void(StringRef)> isaCallback = {});
/// Returns the typeID.
TypeID getTypeID() const;
@@ -80,6 +85,22 @@ class TargetOptions {
/// table.
SymbolTable *getSymbolTable() const;
+ /// Returns the callback invoked with the initial LLVM IR for the device
+ /// module.
+ function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
+
+ /// Returns the callback invoked with LLVM IR for the device module
+ /// after linking the device libraries.
+ function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
+
+ /// Returns the callback invoked with LLVM IR for the device module after
+ /// LLVM optimizations but before codegen.
+ function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
+
+ /// Returns the callback invoked with the target ISA for the device,
+ /// for example PTX assembly.
+ function_ref<void(StringRef)> getISACallback() const;
+
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
@@ -90,7 +111,11 @@ class TargetOptions {
TypeID typeID, StringRef toolkitPath = {},
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
- function_ref<SymbolTable *()> getSymbolTableCallback = {});
+ function_ref<SymbolTable *()> getSymbolTableCallback = {},
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<void(StringRef)> isaCallback = {});
/// Path to the target toolkit.
std::string toolkitPath;
@@ -109,6 +134,21 @@ class TargetOptions {
/// being serialized.
function_ref<SymbolTable *()> getSymbolTableCallback;
+ /// Callback invoked with the initial LLVM IR for the device module.
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// linking the device libraries.
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// LLVM optimizations but before codegen.
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+
+ /// Callback invoked with the target ISA for the device,
+ /// for example PTX assembly.
+ function_ref<void(StringRef)> isaCallback;
+
private:
TypeID typeID;
};
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index e40d7e9a43dd6b..07fc55b41ae9c5 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -29,8 +29,13 @@ class ModuleTranslation;
/// operations being transformed must be translatable into LLVM IR.
class ModuleToObject {
public:
- ModuleToObject(Operation &module, StringRef triple, StringRef chip,
- StringRef features = {}, int optLevel = 3);
+ ModuleToObject(
+ Operation &module, StringRef triple, StringRef chip,
+ StringRef features = {}, int optLevel = 3,
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<void(StringRef)> isaCallback = {});
virtual ~ModuleToObject();
/// Returns the operation being serialized.
@@ -114,6 +119,21 @@ class ModuleToObject {
/// Optimization level.
int optLevel;
+ /// Callback invoked with the initial LLVM IR for the device module.
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// linking the device libraries.
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// LLVM optimizations but before codegen.
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+
+ /// Callback invoked with the target ISA for the device,
+ /// for example PTX assembly.
+ function_ref<void(StringRef)> isaCallback;
+
private:
/// The TargetMachine created for the given Triple, if available.
/// Accessible through `getOrCreateTargetMachine()`.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 956877497d9338..d62ea72dcea2f6 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2302,17 +2302,31 @@ KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
TargetOptions::TargetOptions(
StringRef toolkitPath, ArrayRef<std::string> linkFiles,
StringRef cmdOptions, CompilationTarget compilationTarget,
- function_ref<SymbolTable *()> getSymbolTableCallback)
+ function_ref<SymbolTable *()> getSymbolTableCallback,
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<void(StringRef)> isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
- cmdOptions, compilationTarget, getSymbolTableCallback) {}
+ cmdOptions, compilationTarget, getSymbolTableCallback,
+ initialLlvmIRCallback, linkedLlvmIRCallback,
+ optimizedLlvmIRCallback, isaCallback) {}
TargetOptions::TargetOptions(
TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
StringRef cmdOptions, CompilationTarget compilationTarget,
- function_ref<SymbolTable *()> getSymbolTableCallback)
+ function_ref<SymbolTable *()> getSymbolTableCallback,
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<void(StringRef)> isaCallback)
: toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
- getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
+ getSymbolTableCallback(getSymbolTableCallback),
+ initialLlvmIRCallback(initialLlvmIRCallback),
+ linkedLlvmIRCallback(linkedLlvmIRCallback),
+ optimizedLlvmIRCallback(optimizedLlvmIRCallback),
+ isaCallback(isaCallback), typeID(typeID) {}
TypeID TargetOptions::getTypeID() const { return typeID; }
@@ -2326,6 +2340,25 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}
+function_ref<void(llvm::Module &)>
+TargetOptions::getInitialLlvmIRCallback() const {
+ return initialLlvmIRCallback;
+}
+
+function_ref<void(llvm::Module &)>
+TargetOptions::getLinkedLlvmIRCallback() const {
+ return linkedLlvmIRCallback;
+}
+
+function_ref<void(llvm::Module &)>
+TargetOptions::getOptimizedLlvmIRCallback() const {
+ return optimizedLlvmIRCallback;
+}
+
+function_ref<void(StringRef)> TargetOptions::getISACallback() const {
+ return isaCallback;
+}
+
CompilationTarget TargetOptions::getCompilationTarget() const {
return compilationTarget;
}
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 77391341adaad2..8a4fa3970951c9 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -34,10 +34,17 @@
using namespace mlir;
using namespace mlir::LLVM;
-ModuleToObject::ModuleToObject(Operation &module, StringRef triple,
- StringRef chip, StringRef features, int optLevel)
+ModuleToObject::ModuleToObject(
+ Operation &module, StringRef triple, StringRef chip, StringRef features,
+ int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<void(StringRef)> isaCallback)
: module(module), triple(triple), chip(chip), features(features),
- optLevel(optLevel) {}
+ optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
+ linkedLlvmIRCallback(linkedLlvmIRCallback),
+ optimizedLlvmIRCallback(optimizedLlvmIRCallback),
+ isaCallback(isaCallback) {}
ModuleToObject::~ModuleToObject() = default;
@@ -215,6 +222,10 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
setDataLayoutAndTriple(*llvmModule);
+ if (initialLlvmIRCallback) {
+ initialLlvmIRCallback(*llvmModule);
+ }
+
// Link bitcode files.
handleModulePreLink(*llvmModule);
{
@@ -227,10 +238,18 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
handleModulePostLink(*llvmModule);
}
+ if (linkedLlvmIRCallback) {
+ linkedLlvmIRCallback(*llvmModule);
+ }
+
// Optimize the module.
if (failed(optimizeModule(*llvmModule, optLevel)))
return std::nullopt;
+ if (optimizedLlvmIRCallback) {
+ optimizedLlvmIRCallback(*llvmModule);
+ }
+
// Return the serialized object.
return moduleToObject(*llvmModule);
}
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 69602af8563aa0..d47fb856cb7222 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -572,6 +572,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
getOperation().emitError() << "Failed translating the module to ISA.";
return std::nullopt;
}
+ if (isaCallback) {
+ isaCallback(serializedISA.value());
+ }
#define DEBUG_TYPE "serialize-to-isa"
LLVM_DEBUG({
llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n";
|
@llvm/pr-subscribers-mlir Author: Zichen Lu (MikaOvO) Changes: In the ModuleToObject flow, users may want to add some callback functions invoked with LLVM IR/ISA for debugging or other purposes. Full diff: https://github.com/llvm/llvm-project/pull/116007.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 6d7cb5ca7a7f81..d4b16a1de8eddc 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H
#include "mlir/IR/Attributes.h"
+#include "llvm/IR/Module.h"
namespace llvm {
class IRBuilderBase;
@@ -52,7 +53,11 @@ class TargetOptions {
StringRef toolkitPath = {}, ArrayRef<std::string> linkFiles = {},
StringRef cmdOptions = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
- function_ref<SymbolTable *()> getSymbolTableCallback = {});
+ function_ref<SymbolTable *()> getSymbolTableCallback = {},
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<void(StringRef)> isaCallback = {});
/// Returns the typeID.
TypeID getTypeID() const;
@@ -80,6 +85,22 @@ class TargetOptions {
/// table.
SymbolTable *getSymbolTable() const;
+ /// Returns the callback invoked with the initial LLVM IR for the device
+ /// module.
+ function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
+
+ /// Returns the callback invoked with LLVM IR for the device module
+ /// after linking the device libraries.
+ function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
+
+ /// Returns the callback invoked with LLVM IR for the device module after
+ /// LLVM optimizations but before codegen.
+ function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
+
+ /// Returns the callback invoked with the target ISA for the device,
+ /// for example PTX assembly.
+ function_ref<void(StringRef)> getISACallback() const;
+
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
@@ -90,7 +111,11 @@ class TargetOptions {
TypeID typeID, StringRef toolkitPath = {},
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
- function_ref<SymbolTable *()> getSymbolTableCallback = {});
+ function_ref<SymbolTable *()> getSymbolTableCallback = {},
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<void(StringRef)> isaCallback = {});
/// Path to the target toolkit.
std::string toolkitPath;
@@ -109,6 +134,21 @@ class TargetOptions {
/// being serialized.
function_ref<SymbolTable *()> getSymbolTableCallback;
+ /// Callback invoked with the initial LLVM IR for the device module.
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// linking the device libraries.
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// LLVM optimizations but before codegen.
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+
+ /// Callback invoked with the target ISA for the device,
+ /// for example PTX assembly.
+ function_ref<void(StringRef)> isaCallback;
+
private:
TypeID typeID;
};
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index e40d7e9a43dd6b..07fc55b41ae9c5 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -29,8 +29,13 @@ class ModuleTranslation;
/// operations being transformed must be translatable into LLVM IR.
class ModuleToObject {
public:
- ModuleToObject(Operation &module, StringRef triple, StringRef chip,
- StringRef features = {}, int optLevel = 3);
+ ModuleToObject(
+ Operation &module, StringRef triple, StringRef chip,
+ StringRef features = {}, int optLevel = 3,
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<void(StringRef)> isaCallback = {});
virtual ~ModuleToObject();
/// Returns the operation being serialized.
@@ -114,6 +119,21 @@ class ModuleToObject {
/// Optimization level.
int optLevel;
+ /// Callback invoked with the initial LLVM IR for the device module.
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// linking the device libraries.
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// LLVM optimizations but before codegen.
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+
+ /// Callback invoked with the target ISA for the device,
+ /// for example PTX assembly.
+ function_ref<void(StringRef)> isaCallback;
+
private:
/// The TargetMachine created for the given Triple, if available.
/// Accessible through `getOrCreateTargetMachine()`.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 956877497d9338..d62ea72dcea2f6 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2302,17 +2302,31 @@ KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
TargetOptions::TargetOptions(
StringRef toolkitPath, ArrayRef<std::string> linkFiles,
StringRef cmdOptions, CompilationTarget compilationTarget,
- function_ref<SymbolTable *()> getSymbolTableCallback)
+ function_ref<SymbolTable *()> getSymbolTableCallback,
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<void(StringRef)> isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
- cmdOptions, compilationTarget, getSymbolTableCallback) {}
+ cmdOptions, compilationTarget, getSymbolTableCallback,
+ initialLlvmIRCallback, linkedLlvmIRCallback,
+ optimizedLlvmIRCallback, isaCallback) {}
TargetOptions::TargetOptions(
TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
StringRef cmdOptions, CompilationTarget compilationTarget,
- function_ref<SymbolTable *()> getSymbolTableCallback)
+ function_ref<SymbolTable *()> getSymbolTableCallback,
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<void(StringRef)> isaCallback)
: toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
- getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
+ getSymbolTableCallback(getSymbolTableCallback),
+ initialLlvmIRCallback(initialLlvmIRCallback),
+ linkedLlvmIRCallback(linkedLlvmIRCallback),
+ optimizedLlvmIRCallback(optimizedLlvmIRCallback),
+ isaCallback(isaCallback), typeID(typeID) {}
TypeID TargetOptions::getTypeID() const { return typeID; }
@@ -2326,6 +2340,25 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}
+function_ref<void(llvm::Module &)>
+TargetOptions::getInitialLlvmIRCallback() const {
+ return initialLlvmIRCallback;
+}
+
+function_ref<void(llvm::Module &)>
+TargetOptions::getLinkedLlvmIRCallback() const {
+ return linkedLlvmIRCallback;
+}
+
+function_ref<void(llvm::Module &)>
+TargetOptions::getOptimizedLlvmIRCallback() const {
+ return optimizedLlvmIRCallback;
+}
+
+function_ref<void(StringRef)> TargetOptions::getISACallback() const {
+ return isaCallback;
+}
+
CompilationTarget TargetOptions::getCompilationTarget() const {
return compilationTarget;
}
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 77391341adaad2..8a4fa3970951c9 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -34,10 +34,17 @@
using namespace mlir;
using namespace mlir::LLVM;
-ModuleToObject::ModuleToObject(Operation &module, StringRef triple,
- StringRef chip, StringRef features, int optLevel)
+ModuleToObject::ModuleToObject(
+ Operation &module, StringRef triple, StringRef chip, StringRef features,
+ int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<void(StringRef)> isaCallback)
: module(module), triple(triple), chip(chip), features(features),
- optLevel(optLevel) {}
+ optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
+ linkedLlvmIRCallback(linkedLlvmIRCallback),
+ optimizedLlvmIRCallback(optimizedLlvmIRCallback),
+ isaCallback(isaCallback) {}
ModuleToObject::~ModuleToObject() = default;
@@ -215,6 +222,10 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
setDataLayoutAndTriple(*llvmModule);
+ if (initialLlvmIRCallback) {
+ initialLlvmIRCallback(*llvmModule);
+ }
+
// Link bitcode files.
handleModulePreLink(*llvmModule);
{
@@ -227,10 +238,18 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
handleModulePostLink(*llvmModule);
}
+ if (linkedLlvmIRCallback) {
+ linkedLlvmIRCallback(*llvmModule);
+ }
+
// Optimize the module.
if (failed(optimizeModule(*llvmModule, optLevel)))
return std::nullopt;
+ if (optimizedLlvmIRCallback) {
+ optimizedLlvmIRCallback(*llvmModule);
+ }
+
// Return the serialized object.
return moduleToObject(*llvmModule);
}
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 69602af8563aa0..d47fb856cb7222 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -572,6 +572,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
getOperation().emitError() << "Failed translating the module to ISA.";
return std::nullopt;
}
+ if (isaCallback) {
+ isaCallback(serializedISA.value());
+ }
#define DEBUG_TYPE "serialize-to-isa"
LLVM_DEBUG({
llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n";
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks reasonable but I'm wondering if this could in any way be tested?
if (initialLlvmIRCallback) { | ||
initialLlvmIRCallback(*llvmModule); | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (initialLlvmIRCallback) { | |
initialLlvmIRCallback(*llvmModule); | |
} | |
if (initialLlvmIRCallback) | |
initialLlvmIRCallback(*llvmModule); |
The same applies to other cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed, thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The primary indirect user of ModuleToObject
in tree is gpu-module-to-binary
and this change doesn't touch that pass, so this contribution doesn't appear to have any uses upstream. I have the following questions:
- How is this going to be used?
- Why is this needed?
Also, this requires unit tests.
Yes we probably can use Note that downstream, we have uses of |
I don't think
Are these uses through the target attributes (NVVM, ROCDL)? If not, downstream can always subclass Also, FWIW the main reason for my block is the lack of unit tests. |
mlir/lib/Target/LLVM/NVVM/Target.cpp
Outdated
@@ -572,6 +572,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) { | |||
getOperation().emitError() << "Failed translating the module to ISA."; | |||
return std::nullopt; | |||
} | |||
if (isaCallback) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have also dump-sass
. Should we also add a callback for that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we dump the sass result ourselves after getting it?
I think the reason for adding callback functions here is that we cannot obtain the intermediate results (LLVM IR, etc.) by subclassing SerializeGPUModuleBase
(maybe we can override ModuleToObject::run()
, but this will cause code duplication, which is not a good solution).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure we can 😀 But this means code replication.
This path calls nvdisasm on cubin and prints sass. I think it's better to leverage this mechanism, and all the downstream compilers use it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't quite get it? The callback would return the same thing as the serializer?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This means code replication. This path calls nvdisasm on cubin and prints sass. I think it's better to leverage this mechanism, and all the downstream compilers use it.
I think you are right. But in my current understanding, sass
is a concept that only exists in nvvm
. So we can not name a callback sassCallBack
. Similarly, we cannot name a callback ptxCallback
, but we can name it isaCallback
.
we can integrate this work with the fwiw, I implemented something similar for this case. I wasn't aware of this PR, and I closed my PR (linked below), but it demonstrates how to handle the wiring between tests and |
@MikaOvO could you add some unit tests here https://github.com/llvm/llvm-project/tree/main/mlir/unittests/Target/LLVM ? It could be in SerializeNVVMTarget.cpp |
Yes, I'm trying to do this today, thanks! |
Hi all. Thanks for your comments! Please let me answer your questions based on my understanding.
We can use it like this: auto callback = [](llvm::Module &module) {
// do something like module.dump();
};
auto callback2 = [](llvm::StringRef isa) {
// do something like llvm::outs() << isa;
};
gpu::TargetOptions options(..., callback, {}, {}, callback2);
In |
a54b04b
to
a6129d8
Compare
I have already added a unit test in |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the tests!
a6129d8
to
bec45ed
Compare
std::ostringstream oss; | ||
llvm::raw_os_ostream ros(oss); | ||
ros.flush(); | ||
module.print(ros, nullptr); | ||
optimizedLLVMIR = oss.str(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The ros.flush()
is inserted before the module.print()
, that does not seem right...
This should work instead:
std::ostringstream oss; | |
llvm::raw_os_ostream ros(oss); | |
ros.flush(); | |
module.print(ros, nullptr); | |
optimizedLLVMIR = oss.str(); | |
llvm::raw_string_ostream rso(optimizedLLVMIR); | |
module.print(rso, nullptr); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why didn't you just apply the diff I wrote? It's 2 lines instead of 5...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, let me fix it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed, thanks! (Sorry, I didn't look at the code carefully.)
bec45ed
to
f67139d
Compare
f67139d
to
62707ce
Compare
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/6557 Here is the relevant piece of the build log for the reference
|
This reverts commit 2153672.
Reverts #116007 Bot is broken.
I had to revert: the callback isn't invoked. |
OK, thanks! |
Here is the [merged MR](#116007) which caused a failure and [was reverted](#116811). Thanks to @joker-eph for the help, I fixed it (the `ModuleToObject` was not being constructed with the callback functions in `mlir/lib/Target/LLVM/NVVM/Target.cpp`) and split the unit tests which don't need `ptxas` out of the original test so that they run more widely.
In ModuleToObject flow, users may want to add some callback functions invoked with LLVM IR/ISA for debugging or other purposes.