Skip to content

Commit 69040d5

Browse files
author
Stephan Herhut
committed
[MLIR] Allow for multiple gpu modules during translation.
This change makes ModuleTranslation thread-safe by locking on the LLVMContext. Furthermore, we now clone the LLVM module into a new context when compiling to PTX, similar to what the OrcJit does. Differential Revision: https://reviews.llvm.org/D78207
1 parent da20740 commit 69040d5

File tree

11 files changed

+90
-23
lines changed

11 files changed

+90
-23
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@
3232
namespace llvm {
3333
class Type;
3434
class LLVMContext;
35+
namespace sys {
36+
template <bool mt_only>
37+
class SmartMutex;
38+
} // end namespace sys
3539
} // end namespace llvm
3640

3741
namespace mlir {
@@ -216,6 +220,12 @@ Value createGlobalString(Location loc, OpBuilder &builder, StringRef name,
216220
/// function confirms that the Operation has the desired properties.
217221
bool satisfiesLLVMModule(Operation *op);
218222

223+
/// Clones the given module into the provided context. This is implemented by
224+
/// transforming the module into bitcode and then reparsing the bitcode in the
225+
/// provided context.
226+
std::unique_ptr<llvm::Module>
227+
cloneModuleIntoNewContext(llvm::LLVMContext *context, llvm::Module *module);
228+
219229
} // end namespace LLVM
220230
} // end namespace mlir
221231

mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def LLVM_Dialect : Dialect {
2424
~LLVMDialect();
2525
llvm::LLVMContext &getLLVMContext();
2626
llvm::Module &getLLVMModule();
27+
llvm::sys::SmartMutex<true> &getLLVMContextMutex();
2728

2829
private:
2930
friend LLVMType;

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,14 +106,15 @@ class ModuleTranslation {
106106
/// Original and translated module.
107107
Operation *mlirModule;
108108
std::unique_ptr<llvm::Module> llvmModule;
109-
110109
/// A converter for translating debug information.
111110
std::unique_ptr<detail::DebugTranslation> debugTranslation;
112111

113112
/// Builder for LLVM IR generation of OpenMP constructs.
114113
std::unique_ptr<llvm::OpenMPIRBuilder> ompBuilder;
115114
/// Precomputed pointer to OpenMP dialect.
116115
const Dialect *ompDialect;
116+
/// Pointer to the LLVM dialect.
117+
LLVMDialect *llvmDialect;
117118

118119
/// Mappings between llvm.mlir.global definitions and corresponding globals.
119120
DenseMap<Operation *, llvm::GlobalValue *> globalsMapping;

mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"
1616

1717
#include "mlir/Dialect/GPU/GPUDialect.h"
18+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1819
#include "mlir/IR/Attributes.h"
1920
#include "mlir/IR/Builders.h"
2021
#include "mlir/IR/Function.h"
@@ -98,12 +99,19 @@ std::string GpuKernelToCubinPass::translateModuleToPtx(
9899
llvm::Module &module, llvm::TargetMachine &target_machine) {
99100
std::string ptx;
100101
{
102+
// Clone the llvm module into a new context to enable concurrent compilation
103+
// with multiple threads.
104+
// TODO(zinenko): Reevaluate model of ownership of LLVMContext in
105+
// LLVMDialect.
106+
llvm::LLVMContext llvmContext;
107+
auto clone = LLVM::cloneModuleIntoNewContext(&llvmContext, &module);
108+
101109
llvm::raw_string_ostream stream(ptx);
102110
llvm::buffer_ostream pstream(stream);
103111
llvm::legacy::PassManager codegen_passes;
104112
target_machine.addPassesToEmitFile(codegen_passes, pstream, nullptr,
105113
llvm::CGFT_AssemblyFile);
106-
codegen_passes.run(module);
114+
codegen_passes.run(*clone);
107115
}
108116

109117
return ptx;

mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ class GpuLaunchFuncToCudaCallsPass
116116
void addParamToList(OpBuilder &builder, Location loc, Value param, Value list,
117117
unsigned pos, Value one);
118118
Value setupParamsArray(gpu::LaunchFuncOp launchOp, OpBuilder &builder);
119-
Value generateKernelNameConstant(StringRef name, Location loc,
120-
OpBuilder &builder);
119+
Value generateKernelNameConstant(StringRef moduleName, StringRef name,
120+
Location loc, OpBuilder &builder);
121121
void translateGpuLaunchCalls(mlir::gpu::LaunchFuncOp launchOp);
122122

123123
public:
@@ -345,12 +345,13 @@ Value GpuLaunchFuncToCudaCallsPass::setupParamsArray(gpu::LaunchFuncOp launchOp,
345345
// %2 = llvm.getelementptr %0[%1, %1] : !llvm<"i8*">
346346
// }
347347
Value GpuLaunchFuncToCudaCallsPass::generateKernelNameConstant(
348-
StringRef name, Location loc, OpBuilder &builder) {
348+
StringRef moduleName, StringRef name, Location loc, OpBuilder &builder) {
349349
// Make sure the trailing zero is included in the constant.
350350
std::vector<char> kernelName(name.begin(), name.end());
351351
kernelName.push_back('\0');
352352

353-
std::string globalName = std::string(llvm::formatv("{0}_kernel_name", name));
353+
std::string globalName =
354+
std::string(llvm::formatv("{0}_{1}_kernel_name", moduleName, name));
354355
return LLVM::createGlobalString(
355356
loc, builder, globalName, StringRef(kernelName.data(), kernelName.size()),
356357
LLVM::Linkage::Internal, llvmDialect);
@@ -415,7 +416,8 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
415416
// the kernel function.
416417
auto cuOwningModuleRef =
417418
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
418-
auto kernelName = generateKernelNameConstant(launchOp.kernel(), loc, builder);
419+
auto kernelName = generateKernelNameConstant(launchOp.getKernelModuleName(),
420+
launchOp.kernel(), loc, builder);
419421
auto cuFunction = allocatePointer(builder, loc);
420422
auto cuModuleGetFunction =
421423
getOperation().lookupSymbol<LLVM::LLVMFuncOp>(cuModuleGetFunctionName);

mlir/lib/Dialect/LLVMIR/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ add_mlir_dialect_library(MLIRLLVMIR
1313
target_link_libraries(MLIRLLVMIR
1414
PUBLIC
1515
LLVMAsmParser
16+
LLVMBitReader
17+
LLVMBitWriter
1618
LLVMCore
1719
LLVMSupport
1820
LLVMFrontendOpenMP

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
#include "llvm/ADT/StringSwitch.h"
2222
#include "llvm/AsmParser/Parser.h"
23+
#include "llvm/Bitcode/BitcodeReader.h"
24+
#include "llvm/Bitcode/BitcodeWriter.h"
2325
#include "llvm/IR/Attributes.h"
2426
#include "llvm/IR/Function.h"
2527
#include "llvm/IR/Type.h"
@@ -1682,6 +1684,9 @@ LLVMDialect::~LLVMDialect() {}
16821684

16831685
llvm::LLVMContext &LLVMDialect::getLLVMContext() { return impl->llvmContext; }
16841686
llvm::Module &LLVMDialect::getLLVMModule() { return impl->module; }
1687+
llvm::sys::SmartMutex<true> &LLVMDialect::getLLVMContextMutex() {
1688+
return impl->mutex;
1689+
}
16851690

16861691
/// Parse a type registered to this dialect.
16871692
Type LLVMDialect::parseType(DialectAsmParser &parser) const {
@@ -1971,3 +1976,16 @@ bool mlir::LLVM::satisfiesLLVMModule(Operation *op) {
19711976
return op->hasTrait<OpTrait::SymbolTable>() &&
19721977
op->hasTrait<OpTrait::IsIsolatedFromAbove>();
19731978
}
1979+
1980+
std::unique_ptr<llvm::Module>
1981+
mlir::LLVM::cloneModuleIntoNewContext(llvm::LLVMContext *context,
1982+
llvm::Module *module) {
1983+
SmallVector<char, 1> buffer;
1984+
{
1985+
llvm::raw_svector_ostream os(buffer);
1986+
WriteBitcodeToFile(*module, os);
1987+
}
1988+
llvm::MemoryBufferRef bufferRef(StringRef(buffer.data(), buffer.size()),
1989+
"cloned module buffer");
1990+
return cantFail(parseBitcodeFile(bufferRef, *context));
1991+
}

mlir/lib/ExecutionEngine/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ target_link_libraries(MLIRExecutionEngine
1717
PUBLIC
1818
MLIRLLVMIR
1919
MLIRTargetLLVMIR
20-
LLVMBitReader
21-
LLVMBitWriter
2220
LLVMExecutionEngine
2321
LLVMObject
2422
LLVMOrcJIT

mlir/lib/ExecutionEngine/ExecutionEngine.cpp

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@
1111
//
1212
//===----------------------------------------------------------------------===//
1313
#include "mlir/ExecutionEngine/ExecutionEngine.h"
14+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1415
#include "mlir/IR/Function.h"
1516
#include "mlir/IR/Module.h"
1617
#include "mlir/Support/FileUtilities.h"
1718
#include "mlir/Target/LLVMIR.h"
1819

19-
#include "llvm/Bitcode/BitcodeReader.h"
20-
#include "llvm/Bitcode/BitcodeWriter.h"
2120
#include "llvm/ExecutionEngine/JITEventListener.h"
2221
#include "llvm/ExecutionEngine/ObjectCache.h"
2322
#include "llvm/ExecutionEngine/Orc/CompileUtils.h"
@@ -211,17 +210,8 @@ Expected<std::unique_ptr<ExecutionEngine>> ExecutionEngine::create(
211210
// Clone module in a new LLVMContext since translateModuleToLLVMIR buries
212211
// ownership too deeply.
213212
// TODO(zinenko): Reevaluate model of ownership of LLVMContext in LLVMDialect.
214-
SmallVector<char, 1> buffer;
215-
{
216-
llvm::raw_svector_ostream os(buffer);
217-
WriteBitcodeToFile(*llvmModule, os);
218-
}
219-
llvm::MemoryBufferRef bufferRef(StringRef(buffer.data(), buffer.size()),
220-
"cloned module buffer");
221-
auto expectedModule = parseBitcodeFile(bufferRef, *ctx);
222-
if (!expectedModule)
223-
return expectedModule.takeError();
224-
std::unique_ptr<Module> deserModule = std::move(*expectedModule);
213+
std::unique_ptr<Module> deserModule =
214+
LLVM::cloneModuleIntoNewContext(ctx.get(), llvmModule.get());
225215
auto dataLayout = deserModule->getDataLayout();
226216

227217
// Callback to create the object layer with symbol resolution to current

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,8 @@ ModuleTranslation::ModuleTranslation(Operation *module,
301301
debugTranslation(
302302
std::make_unique<DebugTranslation>(module, *this->llvmModule)),
303303
ompDialect(
304-
module->getContext()->getRegisteredDialect<omp::OpenMPDialect>()) {
304+
module->getContext()->getRegisteredDialect<omp::OpenMPDialect>()),
305+
llvmDialect(module->getContext()->getRegisteredDialect<LLVMDialect>()) {
305306
assert(satisfiesLLVMModule(mlirModule) &&
306307
"mlirModule should honor LLVM's module semantics.");
307308
}
@@ -495,6 +496,9 @@ LogicalResult ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
495496
/// Create named global variables that correspond to llvm.mlir.global
496497
/// definitions.
497498
LogicalResult ModuleTranslation::convertGlobals() {
499+
// Lock access to the llvm context.
500+
llvm::sys::SmartScopedLock<true> scopedLock(
501+
llvmDialect->getLLVMContextMutex());
498502
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
499503
llvm::Type *type = op.getType().getUnderlyingType();
500504
llvm::Constant *cst = llvm::UndefValue::get(type);
@@ -754,6 +758,9 @@ LogicalResult ModuleTranslation::checkSupportedModuleOps(Operation *m) {
754758
}
755759

756760
LogicalResult ModuleTranslation::convertFunctions() {
761+
// Lock access to the llvm context.
762+
llvm::sys::SmartScopedLock<true> scopedLock(
763+
llvmDialect->getLLVMContextMutex());
757764
// Declare all functions first because there may be function calls that form a
758765
// call graph with cycles.
759766
for (auto function : getModuleBody(mlirModule).getOps<LLVMFuncOp>()) {
@@ -798,6 +805,8 @@ std::unique_ptr<llvm::Module>
798805
ModuleTranslation::prepareLLVMModule(Operation *m) {
799806
auto *dialect = m->getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
800807
assert(dialect && "LLVM dialect must be registered");
808+
// Lock the LLVM context as we might create new types here.
809+
llvm::sys::SmartScopedLock<true> scopedLock(dialect->getLLVMContextMutex());
801810

802811
auto llvmModule = llvm::CloneModule(dialect->getLLVMModule());
803812
if (!llvmModule)
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// RUN: mlir-cuda-runner %s --print-ir-after-all --shared-libs=%cuda_wrapper_library_dir/libcuda-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s --dump-input=always
2+
3+
// CHECK: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
4+
func @main() {
5+
%arg = alloc() : memref<13xi32>
6+
%dst = memref_cast %arg : memref<13xi32> to memref<?xi32>
7+
%one = constant 1 : index
8+
%sx = dim %dst, 0 : memref<?xi32>
9+
call @mcuMemHostRegisterMemRef1dInt32(%dst) : (memref<?xi32>) -> ()
10+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
11+
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
12+
%t0 = index_cast %tx : index to i32
13+
store %t0, %dst[%tx] : memref<?xi32>
14+
gpu.terminator
15+
}
16+
gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %one, %grid_y = %one, %grid_z = %one)
17+
threads(%tx, %ty, %tz) in (%block_x = %sx, %block_y = %one, %block_z = %one) {
18+
%t0 = index_cast %tx : index to i32
19+
store %t0, %dst[%tx] : memref<?xi32>
20+
gpu.terminator
21+
}
22+
%U = memref_cast %dst : memref<?xi32> to memref<*xi32>
23+
call @print_memref_i32(%U) : (memref<*xi32>) -> ()
24+
return
25+
}
26+
27+
func @mcuMemHostRegisterMemRef1dInt32(%ptr : memref<?xi32>)
28+
func @print_memref_i32(%ptr : memref<*xi32>)

0 commit comments

Comments
 (0)