Skip to content

Commit 2da417e

Browse files
[mlir][GPU] gpu.printf: Do not emit duplicate format strings (#110504)
Even if the same format string is used multiple times, emit just one `LLVM::GlobalOp`.
1 parent 8897dd6 commit 2da417e

File tree

2 files changed

+56
-57
lines changed

2 files changed

+56
-57
lines changed

mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp

Lines changed: 49 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,34 @@ static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
340340
return stringConstName;
341341
}
342342

343+
/// Create a global that contains the given format string. If a global with
344+
/// the same format string exists already in the module, return that global.
///
/// `str` is copied and a trailing NUL byte is appended before it is stored,
/// so the global's array type and string value both include the terminator.
/// An existing global is only reused when it is constant and its type, value,
/// alignment (an unset alignment is compared as 0) and address space all
/// match the requested ones. Otherwise a new constant global with internal
/// linkage is created at the start of the module body, named via
/// getUniqueFormatGlobalName().
345+
static LLVM::GlobalOp getOrCreateFormatStringConstant(
346+
OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
347+
StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
348+
llvm::SmallString<20> formatString(str);
349+
formatString.push_back('\0'); // Null terminate for C
350+
// The array length is taken after the terminator was appended.
auto globalType =
351+
LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
352+
StringAttr attr = b.getStringAttr(formatString);
353+
354+
// Try to find an existing global. Every property that is passed to the
// create call below (other than the name and linkage) must match.
355+
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
356+
if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
357+
globalOp.getValueAttr() == attr &&
358+
globalOp.getAlignment().value_or(0) == alignment &&
359+
globalOp.getAddrSpace() == addrSpace)
360+
return globalOp;
361+
362+
// Not found: create a new global at the start of the module body. The
// guard scopes the insertion point change to this function.
363+
OpBuilder::InsertionGuard guard(b);
364+
b.setInsertionPointToStart(moduleOp.getBody());
365+
SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
366+
return b.create<LLVM::GlobalOp>(loc, globalType,
367+
/*isConstant=*/true, LLVM::Linkage::Internal,
368+
name, attr, alignment, addrSpace);
369+
}
370+
343371
template <typename T>
344372
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
345373
ConversionPatternRewriter &rewriter,
@@ -391,33 +419,20 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
391419
auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
392420
Value printfDesc = printfBeginCall.getResult();
393421

394-
// Get a unique global name for the format.
395-
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
396-
397-
llvm::SmallString<20> formatString(adaptor.getFormat());
398-
formatString.push_back('\0'); // Null terminate for C
399-
size_t formatStringSize = formatString.size_in_bytes();
400-
401-
auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
402-
LLVM::GlobalOp global;
403-
{
404-
ConversionPatternRewriter::InsertionGuard guard(rewriter);
405-
rewriter.setInsertionPointToStart(moduleOp.getBody());
406-
global = rewriter.create<LLVM::GlobalOp>(
407-
loc, globalType,
408-
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
409-
rewriter.getStringAttr(formatString));
410-
}
422+
// Create the global op or find an existing one.
423+
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
424+
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
411425

412426
// Get a pointer to the format string's first element and pass it to printf()
413427
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
414428
loc,
415429
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
416430
global.getSymNameAttr());
417-
Value stringStart = rewriter.create<LLVM::GEPOp>(
418-
loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
419-
Value stringLen =
420-
rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
431+
Value stringStart =
432+
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
433+
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
434+
Value stringLen = rewriter.create<LLVM::ConstantOp>(
435+
loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
421436

422437
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
423438
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
@@ -486,30 +501,19 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
486501
LLVM::LLVMFuncOp printfDecl =
487502
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
488503

489-
// Get a unique global name for the format.
490-
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
491-
492-
llvm::SmallString<20> formatString(adaptor.getFormat());
493-
formatString.push_back('\0'); // Null terminate for C
494-
auto globalType =
495-
LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
496-
LLVM::GlobalOp global;
497-
{
498-
ConversionPatternRewriter::InsertionGuard guard(rewriter);
499-
rewriter.setInsertionPointToStart(moduleOp.getBody());
500-
global = rewriter.create<LLVM::GlobalOp>(
501-
loc, globalType,
502-
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
503-
rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
504-
}
504+
// Create the global op or find an existing one.
505+
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
506+
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
507+
addressSpace);
505508

506509
// Get a pointer to the format string's first element
507510
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
508511
loc,
509512
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
510513
global.getSymNameAttr());
511-
Value stringStart = rewriter.create<LLVM::GEPOp>(
512-
loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
514+
Value stringStart =
515+
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
516+
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
513517

514518
// Construct arguments and function call
515519
auto argsRange = adaptor.getArgs();
@@ -541,27 +545,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
541545
LLVM::LLVMFuncOp vprintfDecl =
542546
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
543547

544-
// Get a unique global name for the format.
545-
SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
546-
547-
llvm::SmallString<20> formatString(adaptor.getFormat());
548-
formatString.push_back('\0'); // Null terminate for C
549-
auto globalType =
550-
LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
551-
LLVM::GlobalOp global;
552-
{
553-
ConversionPatternRewriter::InsertionGuard guard(rewriter);
554-
rewriter.setInsertionPointToStart(moduleOp.getBody());
555-
global = rewriter.create<LLVM::GlobalOp>(
556-
loc, globalType,
557-
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
558-
rewriter.getStringAttr(formatString), /*allignment=*/0);
559-
}
548+
// Create the global op or find an existing one.
549+
LLVM::GlobalOp global = getOrCreateFormatStringConstant(
550+
rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
560551

561552
// Get a pointer to the format string's first element
562553
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
563-
Value stringStart = rewriter.create<LLVM::GEPOp>(
564-
loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
554+
Value stringStart =
555+
rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
556+
globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
565557
SmallVector<Type> types;
566558
SmallVector<Value> args;
567559
// Promote and pack the arguments into a stack allocation.

mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,13 @@ gpu.module @test_module_29 {
610610
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
611611
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
612612
gpu.printf "Hello, world\n"
613+
614+
// Make sure that the same global is reused.
615+
// CHECK: %[[FORMATSTR2:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
616+
// CHECK: %[[FORMATSTART2:.*]] = llvm.getelementptr %[[FORMATSTR2]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
617+
// CHECK: llvm.call @vprintf(%[[FORMATSTART2]], %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
618+
gpu.printf "Hello, world\n"
619+
613620
gpu.return
614621
}
615622

0 commit comments

Comments
 (0)