Skip to content

Commit 62707ce

Browse files
committed
Add callback functions for ModuleToObject
1 parent 9571cc2 commit 62707ce

File tree

6 files changed

+185
-11
lines changed

6 files changed

+185
-11
lines changed

mlir/include/mlir/Dialect/GPU/IR/CompilationInterfaces.h

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#define MLIR_DIALECT_GPU_IR_COMPILATIONINTERFACES_H
1515

1616
#include "mlir/IR/Attributes.h"
17+
#include "llvm/IR/Module.h"
1718

1819
namespace llvm {
1920
class IRBuilderBase;
@@ -52,7 +53,11 @@ class TargetOptions {
5253
StringRef toolkitPath = {}, ArrayRef<std::string> linkFiles = {},
5354
StringRef cmdOptions = {},
5455
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
55-
function_ref<SymbolTable *()> getSymbolTableCallback = {});
56+
function_ref<SymbolTable *()> getSymbolTableCallback = {},
57+
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
58+
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
59+
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
60+
function_ref<void(StringRef)> isaCallback = {});
5661

5762
/// Returns the typeID.
5863
TypeID getTypeID() const;
@@ -80,6 +85,22 @@ class TargetOptions {
8085
/// table.
8186
SymbolTable *getSymbolTable() const;
8287

88+
/// Returns the callback invoked with the initial LLVM IR for the device
89+
/// module.
90+
function_ref<void(llvm::Module &)> getInitialLlvmIRCallback() const;
91+
92+
/// Returns the callback invoked with LLVM IR for the device module
93+
/// after linking the device libraries.
94+
function_ref<void(llvm::Module &)> getLinkedLlvmIRCallback() const;
95+
96+
/// Returns the callback invoked with LLVM IR for the device module after
97+
/// LLVM optimizations but before codegen.
98+
function_ref<void(llvm::Module &)> getOptimizedLlvmIRCallback() const;
99+
100+
/// Returns the callback invoked with the target ISA for the device,
101+
/// for example PTX assembly.
102+
function_ref<void(StringRef)> getISACallback() const;
103+
83104
/// Returns the default compilation target: `CompilationTarget::Fatbin`.
84105
static CompilationTarget getDefaultCompilationTarget();
85106

@@ -90,7 +111,11 @@ class TargetOptions {
90111
TypeID typeID, StringRef toolkitPath = {},
91112
ArrayRef<std::string> linkFiles = {}, StringRef cmdOptions = {},
92113
CompilationTarget compilationTarget = getDefaultCompilationTarget(),
93-
function_ref<SymbolTable *()> getSymbolTableCallback = {});
114+
function_ref<SymbolTable *()> getSymbolTableCallback = {},
115+
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
116+
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
117+
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
118+
function_ref<void(StringRef)> isaCallback = {});
94119

95120
/// Path to the target toolkit.
96121
std::string toolkitPath;
@@ -109,6 +134,21 @@ class TargetOptions {
109134
/// being serialized.
110135
function_ref<SymbolTable *()> getSymbolTableCallback;
111136

137+
/// Callback invoked with the initial LLVM IR for the device module.
138+
function_ref<void(llvm::Module &)> initialLlvmIRCallback;
139+
140+
/// Callback invoked with LLVM IR for the device module after
141+
/// linking the device libraries.
142+
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
143+
144+
/// Callback invoked with LLVM IR for the device module after
145+
/// LLVM optimizations but before codegen.
146+
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
147+
148+
/// Callback invoked with the target ISA for the device,
149+
/// for example PTX assembly.
150+
function_ref<void(StringRef)> isaCallback;
151+
112152
private:
113153
TypeID typeID;
114154
};

mlir/include/mlir/Target/LLVM/ModuleToObject.h

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,13 @@ class ModuleTranslation;
2929
/// operations being transformed must be translatable into LLVM IR.
3030
class ModuleToObject {
3131
public:
32-
ModuleToObject(Operation &module, StringRef triple, StringRef chip,
33-
StringRef features = {}, int optLevel = 3);
32+
ModuleToObject(
33+
Operation &module, StringRef triple, StringRef chip,
34+
StringRef features = {}, int optLevel = 3,
35+
function_ref<void(llvm::Module &)> initialLlvmIRCallback = {},
36+
function_ref<void(llvm::Module &)> linkedLlvmIRCallback = {},
37+
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback = {},
38+
function_ref<void(StringRef)> isaCallback = {});
3439
virtual ~ModuleToObject();
3540

3641
/// Returns the operation being serialized.
@@ -114,6 +119,21 @@ class ModuleToObject {
114119
/// Optimization level.
115120
int optLevel;
116121

122+
/// Callback invoked with the initial LLVM IR for the device module.
123+
function_ref<void(llvm::Module &)> initialLlvmIRCallback;
124+
125+
/// Callback invoked with LLVM IR for the device module after
126+
/// linking the device libraries.
127+
function_ref<void(llvm::Module &)> linkedLlvmIRCallback;
128+
129+
/// Callback invoked with LLVM IR for the device module after
130+
/// LLVM optimizations but before codegen.
131+
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback;
132+
133+
/// Callback invoked with the target ISA for the device,
134+
/// for example PTX assembly.
135+
function_ref<void(StringRef)> isaCallback;
136+
117137
private:
118138
/// The TargetMachine created for the given Triple, if available.
119139
/// Accessible through `getOrCreateTargetMachine()`.

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2302,17 +2302,31 @@ KernelMetadataAttr KernelTableAttr::lookup(StringAttr key) const {
23022302
TargetOptions::TargetOptions(
23032303
StringRef toolkitPath, ArrayRef<std::string> linkFiles,
23042304
StringRef cmdOptions, CompilationTarget compilationTarget,
2305-
function_ref<SymbolTable *()> getSymbolTableCallback)
2305+
function_ref<SymbolTable *()> getSymbolTableCallback,
2306+
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2307+
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2308+
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2309+
function_ref<void(StringRef)> isaCallback)
23062310
: TargetOptions(TypeID::get<TargetOptions>(), toolkitPath, linkFiles,
2307-
cmdOptions, compilationTarget, getSymbolTableCallback) {}
2311+
cmdOptions, compilationTarget, getSymbolTableCallback,
2312+
initialLlvmIRCallback, linkedLlvmIRCallback,
2313+
optimizedLlvmIRCallback, isaCallback) {}
23082314

23092315
TargetOptions::TargetOptions(
23102316
TypeID typeID, StringRef toolkitPath, ArrayRef<std::string> linkFiles,
23112317
StringRef cmdOptions, CompilationTarget compilationTarget,
2312-
function_ref<SymbolTable *()> getSymbolTableCallback)
2318+
function_ref<SymbolTable *()> getSymbolTableCallback,
2319+
function_ref<void(llvm::Module &)> initialLlvmIRCallback,
2320+
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
2321+
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
2322+
function_ref<void(StringRef)> isaCallback)
23132323
: toolkitPath(toolkitPath.str()), linkFiles(linkFiles),
23142324
cmdOptions(cmdOptions.str()), compilationTarget(compilationTarget),
2315-
getSymbolTableCallback(getSymbolTableCallback), typeID(typeID) {}
2325+
getSymbolTableCallback(getSymbolTableCallback),
2326+
initialLlvmIRCallback(initialLlvmIRCallback),
2327+
linkedLlvmIRCallback(linkedLlvmIRCallback),
2328+
optimizedLlvmIRCallback(optimizedLlvmIRCallback),
2329+
isaCallback(isaCallback), typeID(typeID) {}
23162330

23172331
TypeID TargetOptions::getTypeID() const { return typeID; }
23182332

@@ -2326,6 +2340,25 @@ SymbolTable *TargetOptions::getSymbolTable() const {
23262340
return getSymbolTableCallback ? getSymbolTableCallback() : nullptr;
23272341
}
23282342

2343+
function_ref<void(llvm::Module &)>
2344+
TargetOptions::getInitialLlvmIRCallback() const {
2345+
return initialLlvmIRCallback;
2346+
}
2347+
2348+
function_ref<void(llvm::Module &)>
2349+
TargetOptions::getLinkedLlvmIRCallback() const {
2350+
return linkedLlvmIRCallback;
2351+
}
2352+
2353+
function_ref<void(llvm::Module &)>
2354+
TargetOptions::getOptimizedLlvmIRCallback() const {
2355+
return optimizedLlvmIRCallback;
2356+
}
2357+
2358+
function_ref<void(StringRef)> TargetOptions::getISACallback() const {
2359+
return isaCallback;
2360+
}
2361+
23292362
CompilationTarget TargetOptions::getCompilationTarget() const {
23302363
return compilationTarget;
23312364
}

mlir/lib/Target/LLVM/ModuleToObject.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,17 @@
3434
using namespace mlir;
3535
using namespace mlir::LLVM;
3636

37-
ModuleToObject::ModuleToObject(Operation &module, StringRef triple,
38-
StringRef chip, StringRef features, int optLevel)
37+
ModuleToObject::ModuleToObject(
38+
Operation &module, StringRef triple, StringRef chip, StringRef features,
39+
int optLevel, function_ref<void(llvm::Module &)> initialLlvmIRCallback,
40+
function_ref<void(llvm::Module &)> linkedLlvmIRCallback,
41+
function_ref<void(llvm::Module &)> optimizedLlvmIRCallback,
42+
function_ref<void(StringRef)> isaCallback)
3943
: module(module), triple(triple), chip(chip), features(features),
40-
optLevel(optLevel) {}
44+
optLevel(optLevel), initialLlvmIRCallback(initialLlvmIRCallback),
45+
linkedLlvmIRCallback(linkedLlvmIRCallback),
46+
optimizedLlvmIRCallback(optimizedLlvmIRCallback),
47+
isaCallback(isaCallback) {}
4148

4249
ModuleToObject::~ModuleToObject() = default;
4350

@@ -215,6 +222,9 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
215222
}
216223
setDataLayoutAndTriple(*llvmModule);
217224

225+
if (initialLlvmIRCallback)
226+
initialLlvmIRCallback(*llvmModule);
227+
218228
// Link bitcode files.
219229
handleModulePreLink(*llvmModule);
220230
{
@@ -227,10 +237,16 @@ std::optional<SmallVector<char, 0>> ModuleToObject::run() {
227237
handleModulePostLink(*llvmModule);
228238
}
229239

240+
if (linkedLlvmIRCallback)
241+
linkedLlvmIRCallback(*llvmModule);
242+
230243
// Optimize the module.
231244
if (failed(optimizeModule(*llvmModule, optLevel)))
232245
return std::nullopt;
233246

247+
if (optimizedLlvmIRCallback)
248+
optimizedLlvmIRCallback(*llvmModule);
249+
234250
// Return the serialized object.
235251
return moduleToObject(*llvmModule);
236252
}

mlir/lib/Target/LLVM/NVVM/Target.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,9 @@ NVPTXSerializer::moduleToObject(llvm::Module &llvmModule) {
572572
getOperation().emitError() << "Failed translating the module to ISA.";
573573
return std::nullopt;
574574
}
575+
if (isaCallback)
576+
isaCallback(serializedISA.value());
577+
575578
#define DEBUG_TYPE "serialize-to-isa"
576579
LLVM_DEBUG({
577580
llvm::dbgs() << "PTX for module: " << getOperation().getNameAttr() << "\n";

mlir/unittests/Target/LLVM/SerializeNVVMTarget.cpp

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,65 @@ TEST_F(MLIRTargetLLVMNVVM, SKIP_WITHOUT_NVPTX(SerializeNVVMToBinary)) {
156156
ASSERT_TRUE(!object->empty());
157157
}
158158
}
159+
160+
// Test callback functions invoked with LLVM IR and ISA.
161+
TEST_F(MLIRTargetLLVMNVVM,
162+
SKIP_WITHOUT_NVPTX(CallbackInvokedWithLLVMIRAndISA)) {
163+
if (!hasPtxas())
164+
GTEST_SKIP() << "PTXAS compiler not found, skipping test.";
165+
166+
MLIRContext context(registry);
167+
168+
OwningOpRef<ModuleOp> module =
169+
parseSourceString<ModuleOp>(moduleStr, &context);
170+
ASSERT_TRUE(!!module);
171+
172+
NVVM::NVVMTargetAttr target = NVVM::NVVMTargetAttr::get(&context);
173+
174+
auto serializer = dyn_cast<gpu::TargetAttrInterface>(target);
175+
ASSERT_TRUE(!!serializer);
176+
177+
std::string initialLLVMIR;
178+
auto initialCallback = [&initialLLVMIR](llvm::Module &module) {
179+
llvm::raw_string_ostream ros(initialLLVMIR);
180+
module.print(ros, nullptr);
181+
};
182+
183+
std::string linkedLLVMIR;
184+
auto linkedCallback = [&linkedLLVMIR](llvm::Module &module) {
185+
llvm::raw_string_ostream ros(linkedLLVMIR);
186+
module.print(ros, nullptr);
187+
};
188+
189+
std::string optimizedLLVMIR;
190+
auto optimizedCallback = [&optimizedLLVMIR](llvm::Module &module) {
191+
llvm::raw_string_ostream ros(optimizedLLVMIR);
192+
module.print(ros, nullptr);
193+
};
194+
195+
std::string isaResult;
196+
auto isaCallback = [&isaResult](llvm::StringRef isa) {
197+
isaResult = isa.str();
198+
};
199+
200+
gpu::TargetOptions options({}, {}, {}, gpu::CompilationTarget::Binary, {},
201+
initialCallback, linkedCallback, optimizedCallback,
202+
isaCallback);
203+
204+
for (auto gpuModule : (*module).getBody()->getOps<gpu::GPUModuleOp>()) {
205+
std::optional<SmallVector<char, 0>> object =
206+
serializer.serializeToObject(gpuModule, options);
207+
208+
ASSERT_TRUE(object != std::nullopt);
209+
ASSERT_TRUE(!object->empty());
210+
ASSERT_TRUE(!initialLLVMIR.empty());
211+
ASSERT_TRUE(!linkedLLVMIR.empty());
212+
ASSERT_TRUE(!optimizedLLVMIR.empty());
213+
ASSERT_TRUE(!isaResult.empty());
214+
215+
initialLLVMIR.clear();
216+
linkedLLVMIR.clear();
217+
optimizedLLVMIR.clear();
218+
isaResult.clear();
219+
}
220+
}

0 commit comments

Comments
 (0)