Skip to content

[mlir][GPU] gpu.printf: Do not emit duplicate format strings #110504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 49 additions & 57 deletions mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,34 @@ static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
return stringConstName;
}

/// Create an global that contains the given format string. If a global with
/// the same format string exists already in the module, return that global.
static LLVM::GlobalOp getOrCreateFormatStringConstant(
OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
llvm::SmallString<20> formatString(str);
formatString.push_back('\0'); // Null terminate for C
auto globalType =
LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
StringAttr attr = b.getStringAttr(formatString);

// Try to find existing global.
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
Copy link
Member

@grypp grypp Sep 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you look at the getDynamicSharedMemorySymbol below, it does something very similar. Can we unify them?

I don't want to block the PR because of this, but I'd like to have this unification at some point.

if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
globalOp.getValueAttr() == attr &&
globalOp.getAlignment().value_or(0) == alignment &&
globalOp.getAddrSpace() == addrSpace)
return globalOp;

// Not found: create new global.
OpBuilder::InsertionGuard guard(b);
b.setInsertionPointToStart(moduleOp.getBody());
SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
return b.create<LLVM::GlobalOp>(loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal,
name, attr, alignment, addrSpace);
}

template <typename T>
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
ConversionPatternRewriter &rewriter,
Expand Down Expand Up @@ -391,33 +419,20 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
Value printfDesc = printfBeginCall.getResult();

// Get a unique global name for the format.
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);

llvm::SmallString<20> formatString(adaptor.getFormat());
formatString.push_back('\0'); // Null terminate for C
size_t formatStringSize = formatString.size_in_bytes();

auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
rewriter.getStringAttr(formatString));
}
// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());

// Get a pointer to the format string's first element and pass it to printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
Value stringStart = rewriter.create<LLVM::GEPOp>(
loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
Value stringLen =
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
Value stringStart =
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
Value stringLen = rewriter.create<LLVM::ConstantOp>(
loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());

Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
Expand Down Expand Up @@ -486,30 +501,19 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
LLVM::LLVMFuncOp printfDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);

// Get a unique global name for the format.
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);

llvm::SmallString<20> formatString(adaptor.getFormat());
formatString.push_back('\0'); // Null terminate for C
auto globalType =
LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
}
// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
addressSpace);

// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
Value stringStart = rewriter.create<LLVM::GEPOp>(
loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
Value stringStart =
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});

// Construct arguments and function call
auto argsRange = adaptor.getArgs();
Expand Down Expand Up @@ -541,27 +545,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
LLVM::LLVMFuncOp vprintfDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);

// Get a unique global name for the format.
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);

llvm::SmallString<20> formatString(adaptor.getFormat());
formatString.push_back('\0'); // Null terminate for C
auto globalType =
LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
loc, globalType,
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
rewriter.getStringAttr(formatString), /*allignment=*/0);
}
// Create the global op or find an existing one.
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());

// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
Value stringStart = rewriter.create<LLVM::GEPOp>(
loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
Value stringStart =
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
SmallVector<Type> types;
SmallVector<Value> args;
// Promote and pack the arguments into a stack allocation.
Expand Down
7 changes: 7 additions & 0 deletions mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,13 @@ gpu.module @test_module_29 {
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
gpu.printf "Hello, world\n"

// Make sure that the same global is reused.
// CHECK: %[[FORMATSTR2:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
// CHECK: %[[FORMATSTART2:.*]] = llvm.getelementptr %[[FORMATSTR2]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
// CHECK: llvm.call @vprintf(%[[FORMATSTART2]], %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
gpu.printf "Hello, world\n"

gpu.return
}

Expand Down
Loading