Skip to content

[MLIR] Add callback functions for ModuleToObject #116007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H

#include "mlir/IR/Attributes.h"
#include "llvm/IR/Module.h"

namespace llvm {
class IRBuilderBase;
Expand Down Expand Up @@ -52,7 +53,11 @@ class TargetOptions {
StringRef toolkitPath = {}, ArrayRef<std::string> linkFiles = {},
StringRef cmdOptions = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {});
function_ref<SymbolTable *()> getSymbolTableCallback = {},
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<void(StringRef)> isaCallback = {});

/// Returns the typeID.
TypeID getTypeID() const;
Expand Down Expand Up @@ -80,6 +85,22 @@ class TargetOptions {
/// table.
SymbolTable *getSymbolTable() const;

/// Returns the callback invoked with the initial LLVM IR for the device
/// module.
function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;

/// Returns the callback invoked with LLVM IR for the device module
/// after linking the device libraries.
function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;

/// Returns the callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;

/// Returns the callback invoked with the target ISA for the device,
/// for example PTX assembly.
function_ref<void(StringRef)> getISACallback() const;

/// Returns the default compilation target: `CompilationTarget::Fatbin`.
static CompilationTarget getDefaultCompilationTarget();

Expand All @@ -90,7 +111,11 @@ class TargetOptions {
TypeID typeID, StringRef toolkitPath = {},
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
function_ref<SymbolTable *()> getSymbolTableCallback = {});
function_ref<SymbolTable *()> getSymbolTableCallback = {},
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<void(StringRef)> isaCallback = {});

/// Path to the target toolkit.
std::string toolkitPath;
Expand All @@ -109,6 +134,21 @@ class TargetOptions {
/// being serialized.
function_ref<SymbolTable *()> getSymbolTableCallback;

/// Callback invoked with the initial LLVM IR for the device module.
function_ref<void(llvm::Module &)> initialLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;

/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
function_ref<void(StringRef)> isaCallback;

private:
TypeID typeID;
};
Expand Down
24 changes: 22 additions & 2 deletions mlir/include/mlir/Target/LLVM/ModuleToObject.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,13 @@ class ModuleTranslation;
/// operations being transformed must be translatable into LLVM IR.
class ModuleToObject {
public:
ModuleToObject(Operation &module, StringRef triple, StringRef chip,
StringRef features = {}, int optLevel = 3);
ModuleToObject(
Operation &module, StringRef triple, StringRef chip,
StringRef features = {}, int optLevel = 3,
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
function_ref<void(StringRef)> isaCallback = {});
virtual ~ModuleToObject();

/// Returns the operation being serialized.
Expand Down Expand Up @@ -114,6 +119,21 @@ class ModuleToObject {
/// Optimization level.
int optLevel;

/// Callback invoked with the initial LLVM IR for the device module.
function_ref<void(llvm::Module &)> initialLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// linking the device libraries.
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;

/// Callback invoked with LLVM IR for the device module after
/// LLVM optimizations but before codegen.
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;

/// Callback invoked with the target ISA for the device,
/// for example PTX assembly.
function_ref<void(StringRef)> isaCallback;

private:
/// The TargetMachine created for the given Triple, if available.
/// Accessible through `getOrCreateTargetMachine()`.
Expand Down
41 changes: 37 additions & 4 deletions mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2302,17 +2302,31 @@ KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
TargetOptions::TargetOptions(
StringRef toolkitPath, ArrayRef<std::string> linkFiles,
StringRef cmdOptions, CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback)
function_ref<SymbolTable *()> getSymbolTableCallback,
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
cmdOptions, compilationTarget, getSymbolTableCallback) {}
cmdOptions, compilationTarget, getSymbolTableCallback,
initialLlvmIRCallback, linkedLlvmIRCallback,
optimizedLlvmIRCallback, isaCallback) {}

TargetOptions::TargetOptions(
TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
StringRef cmdOptions, CompilationTarget compilationTarget,
function_ref<SymbolTable *()> getSymbolTableCallback)
function_ref<SymbolTable *()> getSymbolTableCallback,
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
: toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
getSymbolTableCallback(getSymbolTableCallback),
initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
optimizedLlvmIRCallback(optimizedLlvmIRCallback),
isaCallback(isaCallback), typeID(typeID) {}

TypeID TargetOptions::getTypeID() const { return typeID; }

Expand All @@ -2326,6 +2340,25 @@ SymbolTable *TargetOptions::getSymbolTable() const {
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
}

function_ref<void(llvm::Module &)>
TargetOptions::getInitialLlvmIRCallback() const {
return initialLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
TargetOptions::getLinkedLlvmIRCallback() const {
return linkedLlvmIRCallback;
}

function_ref<void(llvm::Module &)>
TargetOptions::getOptimizedLlvmIRCallback() const {
return optimizedLlvmIRCallback;
}

function_ref<void(StringRef)> TargetOptions::getISACallback() const {
return isaCallback;
}

CompilationTarget TargetOptions::getCompilationTarget() const {
return compilationTarget;
}
Expand Down
22 changes: 19 additions & 3 deletions mlir/lib/Target/LLVM/ModuleToObject.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,17 @@
using namespace mlir;
using namespace mlir::LLVM;

ModuleToObject::ModuleToObject(Operation &module, StringRef triple,
StringRef chip, StringRef features, int optLevel)
ModuleToObject::ModuleToObject(
Operation &module, StringRef triple, StringRef chip, StringRef features,
int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
function_ref<void(StringRef)> isaCallback)
: module(module), triple(triple), chip(chip), features(features),
optLevel(optLevel) {}
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
linkedLlvmIRCallback(linkedLlvmIRCallback),
optimizedLlvmIRCallback(optimizedLlvmIRCallback),
isaCallback(isaCallback) {}

ModuleToObject::~ModuleToObject() = default;

Expand Down Expand Up @@ -215,6 +222,9 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
}
setDataLayoutAndTriple(*llvmModule);

if (initialLlvmIRCallback)
initialLlvmIRCallback(*llvmModule);

// Link bitcode files.
handleModulePreLink(*llvmModule);
{
Expand All @@ -227,10 +237,16 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
handleModulePostLink(*llvmModule);
}

if (linkedLlvmIRCallback)
linkedLlvmIRCallback(*llvmModule);

// Optimize the module.
if (failed(optimizeModule(*llvmModule, optLevel)))
return std::nullopt;

if (optimizedLlvmIRCallback)
optimizedLlvmIRCallback(*llvmModule);

// Return the serialized object.
return moduleToObject(*llvmModule);
}
3 changes: 3 additions & 0 deletions mlir/lib/Target/LLVM/NVVM/Target.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
getOperation().emitError() << "Failed translating the module to ISA.";
return std::nullopt;
}
if (isaCallback)
isaCallback(serializedISA.value());

#define DEBUG_TYPE "serialize-to-isa"
LLVM_DEBUG({
llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n";
Expand Down
62 changes: 62 additions & 0 deletions mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,65 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToBinary)) {
ASSERT_TRUE(!object->empty());
}
}

// Test callback functions invoked with LLVM IR and ISA.
TEST_F(MLIRTargetLLVMNVVM,
SKIP_WITHOUT_NVPTX(CallbackInvokedWithLLVMIRAndISA)) {
if (!hasPtxas())
GTEST_SKIP() << "PTXAS compiler not found, skipping test.";

MLIRContext context(registry);

OwningOpRef<ModuleOp> module =
parseSourceString<ModuleOp>(moduleStr, &context);
ASSERT_TRUE(!!module);

NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);

auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
ASSERT_TRUE(!!serializer);

std::string initialLLVMIR;
auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
llvm::raw_string_ostream ros(initialLLVMIR);
module.print(ros, nullptr);
};

std::string linkedLLVMIR;
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
llvm::raw_string_ostream ros(linkedLLVMIR);
module.print(ros, nullptr);
};

std::string optimizedLLVMIR;
auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
llvm::raw_string_ostream ros(optimizedLLVMIR);
module.print(ros, nullptr);
};

std::string isaResult;
auto isaCallback = [&isaResult](llvm::StringRef isa) {
isaResult = isa.str();
};

gpu::TargetOptions options({}, {}, {}, gpu::CompilationTarget::Binary, {},
initialCallback, linkedCallback, optimizedCallback,
isaCallback);

for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
std::optional<SmallVector<char, 0>> object =
serializer.serializeToObject(gpuModule, options);

ASSERT_TRUE(object != std::nullopt);
ASSERT_TRUE(!object->empty());
ASSERT_TRUE(!initialLLVMIR.empty());
ASSERT_TRUE(!linkedLLVMIR.empty());
ASSERT_TRUE(!optimizedLLVMIR.empty());
ASSERT_TRUE(!isaResult.empty());

initialLLVMIR.clear();
linkedLLVMIR.clear();
optimizedLLVMIR.clear();
isaResult.clear();
}
}
Loading