-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][fix] Add callback functions for ModuleToObject #116916
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][fix] Add callback functions for ModuleToObject #116916
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm Author: Zichen Lu (MikaOvO) ChangesHere is the merged MR which caused a failure and was reverted. Thanks to @joker-eph for the help, I fix it (miss constructing Full diff: https://github.com/llvm/llvm-project/pull/116916.diff 7 Files Affected:
diff --git a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
index 6d7cb5ca7a7f81..d4b16a1de8eddc 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
+++ b/mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
@@ -14,6 +14,7 @@
#define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H
#include "mlir/IR/Attributes.h"
+#include "llvm/IR/Module.h"
namespace llvm {
class IRBuilderBase;
@@ -52,7 +53,11 @@ class TargetOptions {
StringRef toolkitPath = {}, ArrayRef<std::string> linkFiles = {},
StringRef cmdOptions = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
- function_ref<SymbolTable *()> getSymbolTableCallback = {});
+ function_ref<SymbolTable *()> getSymbolTableCallback = {},
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<void(StringRef)> isaCallback = {});
/// Returns the typeID.
TypeID getTypeID() const;
@@ -80,6 +85,22 @@ class TargetOptions {
/// table.
SymbolTable *getSymbolTable() const;
+ /// Returns the callback invoked with the initial LLVM IR for the device
+ /// module.
+ function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
+
+ /// Returns the callback invoked with LLVM IR for the device module
+ /// after linking the device libraries.
+ function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
+
+ /// Returns the callback invoked with LLVM IR for the device module after
+ /// LLVM optimizations but before codegen.
+ function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
+
+ /// Returns the callback invoked with the target ISA for the device,
+ /// for example PTX assembly.
+ function_ref<void(StringRef)> getISACallback() const;
+
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();
@@ -90,7 +111,11 @@ class TargetOptions {
TypeID typeID, StringRef toolkitPath = {},
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
- function_ref<SymbolTable *()> getSymbolTableCallback = {});
+ function_ref<SymbolTable *()> getSymbolTableCallback = {},
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<void(StringRef)> isaCallback = {});
/// Path to the target toolkit.
std::string toolkitPath;
@@ -109,6 +134,21 @@ class TargetOptions {
/// being serialized.
function_ref<SymbolTable *()> getSymbolTableCallback;
+ /// Callback invoked with the initial LLVM IR for the device module.
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// linking the device libraries.
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// LLVM optimizations but before codegen.
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+
+ /// Callback invoked with the target ISA for the device,
+ /// for example PTX assembly.
+ function_ref<void(StringRef)> isaCallback;
+
private:
TypeID typeID;
};
diff --git a/mlir/include/mlir/Target/LLVM/ModuleToObject.h b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
index e40d7e9a43dd6b..07fc55b41ae9c5 100644
--- a/mlir/include/mlir/Target/LLVM/ModuleToObject.h
+++ b/mlir/include/mlir/Target/LLVM/ModuleToObject.h
@@ -29,8 +29,13 @@ class ModuleTranslation;
/// operations being transformed must be translatable into LLVM IR.
class ModuleToObject {
public:
- ModuleToObject(Operation &module, StringRef triple, StringRef chip,
- StringRef features = {}, int optLevel = 3);
+ ModuleToObject(
+ Operation &module, StringRef triple, StringRef chip,
+ StringRef features = {}, int optLevel = 3,
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
+ function_ref<void(StringRef)> isaCallback = {});
virtual ~ModuleToObject();
/// Returns the operation being serialized.
@@ -114,6 +119,21 @@ class ModuleToObject {
/// Optimization level.
int optLevel;
+ /// Callback invoked with the initial LLVM IR for the device module.
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// linking the device libraries.
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
+
+ /// Callback invoked with LLVM IR for the device module after
+ /// LLVM optimizations but before codegen.
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
+
+ /// Callback invoked with the target ISA for the device,
+ /// for example PTX assembly.
+ function_ref<void(StringRef)> isaCallback;
+
private:
/// The TargetMachine created for the given Triple, if available.
/// Accessible through `getOrCreateTargetMachine()`.
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 956877497d9338..d62ea72dcea2f6 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -2302,17 +2302,31 @@ KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
TargetOptions::TargetOptions(
StringRef toolkitPath, ArrayRef<std::string> linkFiles,
StringRef cmdOptions, CompilationTarget compilationTarget,
- function_ref<SymbolTable *()> getSymbolTableCallback)
+ function_ref<SymbolTable *()> getSymbolTableCallback,
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<void(StringRef)> isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
- cmdOptions, compilationTarget, getSymbolTableCallback) {}
+ cmdOptions, compilationTarget, getSymbolTableCallback,
+ initialLlvmIRCallback, linkedLlvmIRCallback,
+ optimizedLlvmIRCallback, isaCallback) {}
TargetOptions::TargetOptions(
TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
StringRef cmdOptions, CompilationTarget compilationTarget,
- function_ref<SymbolTable *()> getSymbolTableCallback)
+ function_ref<SymbolTable *()> getSymbolTableCallback,
+ function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<void(StringRef)> isaCallback)
: toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
- getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
+ getSymbolTableCallback(getSymbolTableCallback),
+ initialLlvmIRCallback(initialLlvmIRCallback),
+ linkedLlvmIRCallback(linkedLlvmIRCallback),
+ optimizedLlvmIRCallback(optimizedLlvmIRCallback),
+ isaCallback(isaCallback), typeID(typeID) {}
TypeID TargetOptions::getTypeID() const { return typeID; }
@@ -2326,6 +2340,25 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}
+function_ref<void(llvm::Module &)>
+TargetOptions::getInitialLlvmIRCallback() const {
+ return initialLlvmIRCallback;
+}
+
+function_ref<void(llvm::Module &)>
+TargetOptions::getLinkedLlvmIRCallback() const {
+ return linkedLlvmIRCallback;
+}
+
+function_ref<void(llvm::Module &)>
+TargetOptions::getOptimizedLlvmIRCallback() const {
+ return optimizedLlvmIRCallback;
+}
+
+function_ref<void(StringRef)> TargetOptions::getISACallback() const {
+ return isaCallback;
+}
+
CompilationTarget TargetOptions::getCompilationTarget() const {
return compilationTarget;
}
diff --git a/mlir/lib/Target/LLVM/ModuleToObject.cpp b/mlir/lib/Target/LLVM/ModuleToObject.cpp
index 77391341adaad2..3f5b3d5e31864b 100644
--- a/mlir/lib/Target/LLVM/ModuleToObject.cpp
+++ b/mlir/lib/Target/LLVM/ModuleToObject.cpp
@@ -34,10 +34,17 @@
using namespace mlir;
using namespace mlir::LLVM;
-ModuleToObject::ModuleToObject(Operation &module, StringRef triple,
- StringRef chip, StringRef features, int optLevel)
+ModuleToObject::ModuleToObject(
+ Operation &module, StringRef triple, StringRef chip, StringRef features,
+ int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
+ function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
+ function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
+ function_ref<void(StringRef)> isaCallback)
: module(module), triple(triple), chip(chip), features(features),
- optLevel(optLevel) {}
+ optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
+ linkedLlvmIRCallback(linkedLlvmIRCallback),
+ optimizedLlvmIRCallback(optimizedLlvmIRCallback),
+ isaCallback(isaCallback) {}
ModuleToObject::~ModuleToObject() = default;
@@ -215,6 +222,9 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
setDataLayoutAndTriple(*llvmModule);
+ if (initialLlvmIRCallback)
+ initialLlvmIRCallback(*llvmModule);
+
// Link bitcode files.
handleModulePreLink(*llvmModule);
{
@@ -227,10 +237,16 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
handleModulePostLink(*llvmModule);
}
+ if (linkedLlvmIRCallback)
+ linkedLlvmIRCallback(*llvmModule);
+
// Optimize the module.
if (failed(optimizeModule(*llvmModule, optLevel)))
return std::nullopt;
+ if (optimizedLlvmIRCallback)
+ optimizedLlvmIRCallback(*llvmModule);
+
// Return the serialized object.
return moduleToObject(*llvmModule);
}
diff --git a/mlir/lib/Target/LLVM/NVVM/Target.cpp b/mlir/lib/Target/LLVM/NVVM/Target.cpp
index 69602af8563aa0..bca26e3a0e84a9 100644
--- a/mlir/lib/Target/LLVM/NVVM/Target.cpp
+++ b/mlir/lib/Target/LLVM/NVVM/Target.cpp
@@ -86,7 +86,11 @@ SerializeGPUModuleBase::SerializeGPUModuleBase(
Operation &module, NVVMTargetAttr target,
const gpu::TargetOptions &targetOptions)
: ModuleToObject(module, target.getTriple(), target.getChip(),
- target.getFeatures(), target.getO()),
+ target.getFeatures(), target.getO(),
+ targetOptions.getInitialLlvmIRCallback(),
+ targetOptions.getLinkedLlvmIRCallback(),
+ targetOptions.getOptimizedLlvmIRCallback(),
+ targetOptions.getISACallback()),
target(target), toolkitPath(targetOptions.getToolkitPath()),
fileList(targetOptions.getLinkFiles()) {
@@ -572,6 +576,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
getOperation().emitError() << "Failed translating the module to ISA.";
return std::nullopt;
}
+ if (isaCallback)
+ isaCallback(serializedISA.value());
+
#define DEBUG_TYPE "serialize-to-isa"
LLVM_DEBUG({
llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n";
diff --git a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
index 642aa045178095..eb0c5358ab3530 100644
--- a/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
@@ -156,3 +156,65 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToBinary)) {
ASSERT_TRUE(!object->empty());
}
}
+
+// Test callback functions invoked with LLVM IR and ISA.
+TEST_F(MLIRTargetLLVMNVVM,
+ SKIP_WITHOUT_NVPTX(CallbackInvokedWithLLVMIRAndISA)) {
+ if (!hasPtxas())
+ GTEST_SKIP() << "PTXAS compiler not found, skipping test.";
+
+ MLIRContext context(registry);
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+
+ NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
+
+ auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
+ ASSERT_TRUE(!!serializer);
+
+ std::string initialLLVMIR;
+ auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+ llvm::raw_string_ostream ros(initialLLVMIR);
+ module.print(ros, nullptr);
+ };
+
+ std::string linkedLLVMIR;
+ auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+ llvm::raw_string_ostream ros(linkedLLVMIR);
+ module.print(ros, nullptr);
+ };
+
+ std::string optimizedLLVMIR;
+ auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+ llvm::raw_string_ostream ros(optimizedLLVMIR);
+ module.print(ros, nullptr);
+ };
+
+ std::string isaResult;
+ auto isaCallback = [&isaResult](llvm::StringRef isa) {
+ isaResult = isa.str();
+ };
+
+ gpu::TargetOptions options({}, {}, {}, gpu::CompilationTarget::Binary, {},
+ initialCallback, linkedCallback, optimizedCallback,
+ isaCallback);
+
+ for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
+ std::optional<SmallVector<char, 0>> object =
+ serializer.serializeToObject(gpuModule, options);
+
+ ASSERT_TRUE(object != std::nullopt);
+ ASSERT_TRUE(!object->empty());
+ ASSERT_TRUE(!initialLLVMIR.empty());
+ ASSERT_TRUE(!linkedLLVMIR.empty());
+ ASSERT_TRUE(!optimizedLLVMIR.empty());
+ ASSERT_TRUE(!isaResult.empty());
+
+ initialLLVMIR.clear();
+ linkedLLVMIR.clear();
+ optimizedLLVMIR.clear();
+ isaResult.clear();
+ }
+}
diff --git a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
index 0d4277ed2fdfdc..e8f2d4bd4b3f84 100644
--- a/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
+++ b/mlir/unittests/Target/LLVM/SerializeToLLVMBitcode.cpp
@@ -105,7 +105,9 @@ TargetAttrImpl::serializeToObject(Attribute attribute, Operation *module,
// Set a dummy attr to be retrieved by `createObject`.
module->setAttr("serialize_attr", UnitAttr::get(module->getContext()));
std::string targetTriple = llvm::sys::getProcessTriple();
- LLVM::ModuleToObject serializer(*module, targetTriple, "", "");
+ LLVM::ModuleToObject serializer(
+ *module, targetTriple, "", "", 3, options.getInitialLlvmIRCallback(),
+ options.getLinkedLlvmIRCallback(), options.getOptimizedLlvmIRCallback());
return serializer.run();
}
@@ -153,3 +155,91 @@ TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(TargetAttrAPI)) {
// `serializeToObject`.
ASSERT_TRUE(properties.contains("serialize_attr"));
}
+
+// Test callback function invoked with initial LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithInitialLLVMIR)) {
+ MLIRContext context(registry);
+ context.loadAllAvailableDialects();
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ Builder builder(&context);
+ IntegerAttr target = builder.getI32IntegerAttr(0);
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+ std::string initialLLVMIR;
+ auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
+ llvm::raw_string_ostream ros(initialLLVMIR);
+ module.print(ros, nullptr);
+ };
+
+ gpu::TargetOptions opts(
+ {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), {},
+ initialCallback);
+ std::optional<SmallVector<char, 0>> serializedBinary =
+ targetAttr.serializeToObject(*module, opts);
+
+ ASSERT_TRUE(serializedBinary != std::nullopt);
+ ASSERT_TRUE(!serializedBinary->empty());
+ ASSERT_TRUE(!initialLLVMIR.empty());
+}
+
+// Test callback function invoked with linked LLVM IR
+TEST_F(MLIRTargetLLVM, SKIP_WITHOUT_NATIVE(CallbackInvokedWithLinkedLLVMIR)) {
+ MLIRContext context(registry);
+ context.loadAllAvailableDialects();
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ Builder builder(&context);
+ IntegerAttr target = builder.getI32IntegerAttr(0);
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+ std::string linkedLLVMIR;
+ auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
+ llvm::raw_string_ostream ros(linkedLLVMIR);
+ module.print(ros, nullptr);
+ };
+
+ gpu::TargetOptions opts(
+ {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), {},
+ {}, linkedCallback);
+ std::optional<SmallVector<char, 0>> serializedBinary =
+ targetAttr.serializeToObject(*module, opts);
+
+ ASSERT_TRUE(serializedBinary != std::nullopt);
+ ASSERT_TRUE(!serializedBinary->empty());
+ ASSERT_TRUE(!linkedLLVMIR.empty());
+}
+
+// Test callback function invoked with optimized LLVM IR
+TEST_F(MLIRTargetLLVM,
+ SKIP_WITHOUT_NATIVE(CallbackInvokedWithOptimizedLLVMIR)) {
+ MLIRContext context(registry);
+ context.loadAllAvailableDialects();
+
+ OwningOpRef<ModuleOp> module =
+ parseSourceString<ModuleOp>(moduleStr, &context);
+ ASSERT_TRUE(!!module);
+ Builder builder(&context);
+ IntegerAttr target = builder.getI32IntegerAttr(0);
+ auto targetAttr = dyn_cast<gpu::TargetAttrInterface>(target);
+
+ std::string optimizedLLVMIR;
+ auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
+ llvm::raw_string_ostream ros(optimizedLLVMIR);
+ module.print(ros, nullptr);
+ };
+
+ gpu::TargetOptions opts(
+ {}, {}, {}, mlir::gpu::TargetOptions::getDefaultCompilationTarget(), {},
+ {}, {}, optimizedCallback);
+ std::optional<SmallVector<char, 0>> serializedBinary =
+ targetAttr.serializeToObject(*module, opts);
+
+ ASSERT_TRUE(serializedBinary != std::nullopt);
+ ASSERT_TRUE(!serializedBinary->empty());
+ ASSERT_TRUE(!optimizedLLVMIR.empty());
+}
\ No newline at end of file
|
2200764
to
2175a3f
Compare
Here is the merged MR which caused a failure and was reverted.
Thanks to @joker-eph for the help, I fix it (miss constructing
ModuleObject
with callback functions inmlir/lib/Target/LLVM/NVVM/Target.cpp
) and split unit tests from origin test which don't needptxas
to make the test runs more widely.